# Copyright (c) Yuta Saito, Yusuke Narita, and ZOZO Technologies, Inc. All rights reserved.
# Licensed under the Apache 2.0 License.
"""Base Interfaces for Bandit Algorithms."""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Optional
import numpy as np
from sklearn.utils import check_random_state
@dataclass
class BaseContextFreePolicy(metaclass=ABCMeta):
    """Base class for context-free bandit policies.

    Parameters
    ----------
    n_actions: int
        Number of actions.

    len_list: int, default=1
        Length of a list of actions recommended in each impression.
        When Open Bandit Dataset is used, 3 should be set.

    batch_size: int, default=1
        Number of samples used in a batch parameter update.

    random_state: int, default=None
        Controls the random seed in sampling actions.

    """

    n_actions: int
    len_list: int = 1
    batch_size: int = 1
    random_state: Optional[int] = None

    def __post_init__(self) -> None:
        """Validate the hyperparameters and set up the initial policy state.

        Raises
        ------
        AssertionError
            If ``n_actions`` is not an integer larger than 1, or if ``len_list``
            or ``batch_size`` is not a positive integer.

        """
        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept for callers that already catch it.
        # ``isinstance`` is checked first so that non-numeric input yields the
        # message below rather than a TypeError from the ``>`` comparison.
        if not (isinstance(self.n_actions, int) and self.n_actions > 1):
            raise AssertionError(
                f"n_actions must be an integer larger than 1, but {self.n_actions} is given"
            )
        if not (isinstance(self.len_list, int) and self.len_list > 0):
            raise AssertionError(
                f"len_list must be a positive integer, but {self.len_list} is given"
            )
        if not (isinstance(self.batch_size, int) and self.batch_size > 0):
            raise AssertionError(
                f"batch_size must be a positive integer, but {self.batch_size} is given"
            )
        self._reset_state()

    @property
    def policy_type(self) -> str:
        """Type of the bandit policy."""
        return "contextfree"

    def initialize(self) -> None:
        """Reset the trial counter, the random generator and all per-action counters."""
        self._reset_state()

    def _reset_state(self) -> None:
        """(Re)create the mutable state shared by ``__post_init__`` and ``initialize``."""
        self.n_trial = 0
        # Re-seeded from ``random_state`` on every reset.
        self.random_ = check_random_state(self.random_state)
        self.action_counts = np.zeros(self.n_actions, dtype=int)
        # NOTE(review): the ``*_temp`` arrays presumably accumulate updates within
        # a batch of size ``batch_size`` — confirm against subclasses.
        self.action_counts_temp = np.zeros(self.n_actions, dtype=int)
        self.reward_counts_temp = np.zeros(self.n_actions)
        self.reward_counts = np.zeros(self.n_actions)

    @abstractmethod
    def select_action(self) -> np.ndarray:
        """Select a list of actions."""
        raise NotImplementedError

    @abstractmethod
    def update_params(self, action: int, reward: float) -> None:
        """Update policy parameters."""
        raise NotImplementedError
@dataclass
class BaseContextualPolicy(metaclass=ABCMeta):
    """Base class for contextual bandit policies.

    Parameters
    ----------
    dim: int
        Number of dimensions of context vectors.

    n_actions: int
        Number of actions.

    len_list: int, default=1
        Length of a list of actions recommended in each impression.
        When Open Bandit Dataset is used, 3 should be set.

    batch_size: int, default=1
        Number of samples used in a batch parameter update.

    alpha_: float, default=1.
        Prior parameter for the online logistic regression.

    lambda_: float, default=1.
        Regularization hyperparameter for the online logistic regression.

    random_state: int, default=None
        Controls the random seed in sampling actions.

    """

    dim: int
    n_actions: int
    len_list: int = 1
    batch_size: int = 1
    alpha_: float = 1.0
    lambda_: float = 1.0
    random_state: Optional[int] = None

    def __post_init__(self) -> None:
        """Validate the hyperparameters and set up the initial policy state.

        Raises
        ------
        AssertionError
            If ``dim``, ``len_list`` or ``batch_size`` is not a positive integer,
            or if ``n_actions`` is not an integer larger than 1.

        """
        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept for callers that already catch it.
        # ``isinstance`` is checked first so that non-numeric input yields the
        # message below rather than a TypeError from the ``>`` comparison.
        if not (isinstance(self.dim, int) and self.dim > 0):
            raise AssertionError(
                f"dim must be a positive integer, but {self.dim} is given"
            )
        if not (isinstance(self.n_actions, int) and self.n_actions > 1):
            raise AssertionError(
                f"n_actions must be an integer larger than 1, but {self.n_actions} is given"
            )
        if not (isinstance(self.len_list, int) and self.len_list > 0):
            raise AssertionError(
                f"len_list must be a positive integer, but {self.len_list} is given"
            )
        if not (isinstance(self.batch_size, int) and self.batch_size > 0):
            raise AssertionError(
                f"batch_size must be a positive integer, but {self.batch_size} is given"
            )
        self._reset_state()
        # Per-action copies of the scalar hyperparameters. They are built only
        # here; ``initialize`` intentionally leaves them untouched.
        self.alpha_list = self.alpha_ * np.ones(self.n_actions)
        self.lambda_list = self.lambda_ * np.ones(self.n_actions)

    @property
    def policy_type(self) -> str:
        """Type of the bandit policy."""
        return "contextual"

    def initialize(self) -> None:
        """Reset the trial counter, the random generator and the per-action history.

        ``alpha_list`` and ``lambda_list`` are not modified.
        """
        self._reset_state()

    def _reset_state(self) -> None:
        """(Re)create the mutable state shared by ``__post_init__`` and ``initialize``."""
        self.n_trial = 0
        # Re-seeded from ``random_state`` on every reset.
        self.random_ = check_random_state(self.random_state)
        self.action_counts = np.zeros(self.n_actions, dtype=int)
        # One independent (empty) list per action.
        self.reward_lists = [[] for _ in range(self.n_actions)]
        self.context_lists = [[] for _ in range(self.n_actions)]

    @abstractmethod
    def select_action(self, context: np.ndarray) -> np.ndarray:
        """Select a list of actions."""
        raise NotImplementedError

    @abstractmethod
    def update_params(self, action: int, reward: float, context: np.ndarray) -> None:
        """Update policy parameters."""
        raise NotImplementedError
@dataclass
class BaseOfflinePolicyLearner(metaclass=ABCMeta):
    """Base class for off-policy learners.

    Parameters
    -----------
    n_actions: int
        Number of actions.

    len_list: int, default=1
        Length of a list of actions recommended in each impression.
        When Open Bandit Dataset is used, 3 should be set.

    """

    n_actions: int
    len_list: int = 1

    def __post_init__(self) -> None:
        """Validate the hyperparameters.

        Raises
        ------
        AssertionError
            If ``n_actions`` is not an integer larger than 1, or if ``len_list``
            is not a positive integer.

        """
        # Explicit raises instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept for callers that already catch it.
        # ``isinstance`` is checked first so that non-numeric input yields the
        # message below rather than a TypeError from the ``>`` comparison.
        if not (isinstance(self.n_actions, int) and self.n_actions > 1):
            raise AssertionError(
                f"n_actions must be an integer larger than 1, but {self.n_actions} is given"
            )
        if not (isinstance(self.len_list, int) and self.len_list > 0):
            raise AssertionError(
                f"len_list must be a positive integer, but {self.len_list} is given"
            )

    @property
    def policy_type(self) -> str:
        """Type of the bandit policy."""
        return "offline"

    @abstractmethod
    def fit(
        self,
    ) -> None:
        """Fits an offline bandit policy using the given logged bandit feedback data."""
        raise NotImplementedError

    @abstractmethod
    def predict(self, context: np.ndarray) -> np.ndarray:
        """Predict best action for new data.

        Parameters
        -----------
        context: array-like, shape (n_rounds_of_new_data, dim_context)
            Context vectors for new data.

        Returns
        -----------
        action: array-like, shape (n_rounds_of_new_data, n_actions, len_list)
            Action choices by a policy trained by calling the `fit` method.

        """
        raise NotImplementedError