Policies¶
Policy interfaces and reference implementations.
CallablePolicy
dataclass
¶
MLPConfig
dataclass
¶
Configuration for MLP construction.
Estimand
Not applicable.
Assumptions: None.
Inputs:
input_dim: Observation dimension.
action_dim: Number of actions.
hidden_sizes: Hidden layer sizes.
activation: Activation name ("relu" or "tanh").
Outputs: Configuration object.
Failure modes: None.
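A minimal construction sketch using the documented fields. The import path is a placeholder and the type of hidden_sizes (tuple vs. list) is an assumption:

```python
# NOTE: import path is a placeholder; substitute the actual package module.
from policies import MLPConfig

# Two hidden layers for 4-dimensional observations and 3 discrete actions.
config = MLPConfig(
    input_dim=4,
    action_dim=3,
    hidden_sizes=(64, 64),   # assumed to accept a tuple of layer widths
    activation="relu",       # "relu" or "tanh" per the docstring
)
```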
Policy
¶
Abstract policy interface for discrete or continuous action spaces.
Estimand
Not applicable.
Assumptions: None.
Inputs:
observations: Array with shape (n, d) or (n,) representing states.
Outputs:
action_probs: Array with shape (n, a) for discrete policies.
action_density: Array with shape (n,) for continuous policies.
Failure modes: Implementations should raise ValueError if probabilities are invalid.
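To illustrate the contract, here is a hedged sketch of a discrete implementation. The import path is a placeholder, and depending on how the abstract base class is declared, further methods may also need to be overridden:

```python
import numpy as np

# NOTE: import path is a placeholder; substitute the actual package module.
from policies import Policy


class UniformDiscretePolicy(Policy):
    """Toy policy that assigns equal probability to every action."""

    def __init__(self, num_actions: int):
        self.num_actions = num_actions

    def action_probs(self, observations):
        observations = np.asarray(observations)
        n = observations.shape[0]  # works for (n, d) and (n,) inputs
        # Shape (n, a); every row is a valid probability distribution.
        return np.full((n, self.num_actions), 1.0 / self.num_actions)

    def action_prob(self, observations, actions):
        probs = self.action_probs(observations)
        actions = np.asarray(actions)
        return probs[np.arange(actions.shape[0]), actions]

    def log_prob(self, observations, actions):
        return np.log(self.action_prob(observations, actions))

    def sample_action(self, observations, rng):
        probs = self.action_probs(observations)
        # Categorical sampling, one draw per observation.
        return np.array([rng.choice(self.num_actions, p=p) for p in probs])
```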
action_density(observations, actions)
¶
Return action densities for selected actions (continuous).
action_prob(observations, actions)
¶
Return probabilities for selected actions (discrete).
action_probs(observations)
¶
Return action probabilities for each observation (discrete).
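For example, querying the toy UniformDiscretePolicy sketched under Policy above (a hedged illustration, not library output):

```python
import numpy as np

policy = UniformDiscretePolicy(num_actions=3)   # toy policy from the sketch above
obs = np.zeros((5, 4))                          # five 4-dimensional observations
acts = np.array([0, 1, 2, 1, 0])

probs = policy.action_probs(obs)         # shape (5, 3); each row sums to 1
chosen = policy.action_prob(obs, acts)   # shape (5,); probability of each selected action
```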
from_sklearn(model, action_space_n, *, deterministic=None, name=None)
classmethod
¶
Wrap a scikit-learn model as a policy.
If deterministic is True, the wrapper uses model.predict and treats its outputs as actions. Otherwise it uses model.predict_proba to obtain action probabilities.
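A sketch of wrapping a fitted classifier, following the signature shown above. The data is synthetic and the import path is a placeholder:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# NOTE: import path is a placeholder; substitute the actual package module.
from policies import Policy

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))      # observations
y = (X[:, 0] > 0).astype(int)      # binary actions

clf = LogisticRegression().fit(X, y)

# Probabilistic wrapping: action probabilities come from clf.predict_proba.
policy = Policy.from_sklearn(clf, action_space_n=2, name="logreg-policy")

# Deterministic wrapping: clf.predict outputs are treated as actions.
greedy = Policy.from_sklearn(clf, action_space_n=2, deterministic=True)
```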
from_torch(model, action_space_n, *, device='cpu', name=None)
classmethod
¶
Wrap a torch model that outputs action logits.
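A sketch of wrapping a small logit-producing network with the documented signature; the import path is a placeholder:

```python
from torch import nn

# NOTE: import path is a placeholder; substitute the actual package module.
from policies import Policy

# Maps 4-dimensional observations to logits over 3 actions.
model = nn.Sequential(
    nn.Linear(4, 32),
    nn.ReLU(),
    nn.Linear(32, 3),
)

policy = Policy.from_torch(model, action_space_n=3, device="cpu", name="mlp-policy")
```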
log_prob(observations, actions)
¶
Return log-probability or log-density for selected actions.
sample_action(observations, rng)
¶
Sample actions for the given observations (optional; implementations may omit this method).
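A short sampling-and-scoring sketch, again reusing the toy UniformDiscretePolicy from the Policy section; whether a particular policy provides sample_action is implementation-dependent:

```python
import numpy as np

policy = UniformDiscretePolicy(num_actions=3)   # toy policy from the sketch above
obs = np.zeros((4, 4))
rng = np.random.default_rng(42)

actions = policy.sample_action(obs, rng)   # shape (4,); sampled action indices
logp = policy.log_prob(obs, actions)       # log-probabilities of the sampled actions
```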
to_dict()
¶
Return a dictionary representation.
TabularPolicy
¶
Bases: Policy
Tabular policy defined by a probability table.
Estimand
Not applicable.
Assumptions: Actions are discrete and state indices are valid.
Inputs:
table: Array with shape (num_states, num_actions) of action probabilities.
Outputs:
action_probs: Array with shape (n, num_actions).
Failure modes: Raises ValueError if rows do not sum to 1 or contain negative values.
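A construction sketch, assuming the table is the first constructor argument and observations are state indices into its rows; the import path is a placeholder:

```python
import numpy as np

# NOTE: import path is a placeholder; substitute the actual package module.
from policies import TabularPolicy

# Three states, two actions; each row is a valid probability distribution.
table = np.array([
    [0.9, 0.1],
    [0.5, 0.5],
    [0.2, 0.8],
])

policy = TabularPolicy(table)
probs = policy.action_probs(np.array([0, 2, 1]))   # state indices -> shape (3, 2)
```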
TorchMLPPolicy
¶
Bases: Policy
MLP policy implemented in PyTorch.
Estimand
Not applicable.
Assumptions: Observations are numeric features and actions are discrete.
Inputs:
model: torch.nn.Module mapping observations to logits.
action_dim: Number of discrete actions.
device: Torch device string.
Outputs:
action_probs: Array with shape (n, action_dim).
Failure modes: Raises ValueError if logits have the wrong shape.
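A construction sketch based on the documented inputs; the constructor keyword names and import path are assumptions:

```python
import numpy as np
from torch import nn

# NOTE: import path is a placeholder; substitute the actual package module.
from policies import TorchMLPPolicy

# Network mapping 4-dimensional observations to logits over 3 actions.
net = nn.Sequential(
    nn.Linear(4, 64),
    nn.ReLU(),
    nn.Linear(64, 3),
)

policy = TorchMLPPolicy(model=net, action_dim=3, device="cpu")
probs = policy.action_probs(np.zeros((5, 4)))   # expected shape (5, 3)
```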