Policies

Policy interfaces and reference implementations.

BehaviorPolicy dataclass

Bases: Policy

Wrap a policy with metadata about how it was obtained.

CallablePolicy dataclass

Bases: Policy

Wrap a callable that returns actions or action probabilities.

MLPConfig dataclass

Configuration for MLP construction.

Estimand

Not applicable.

Assumptions: None.
Inputs:
- input_dim: Observation dimension.
- action_dim: Number of actions.
- hidden_sizes: Hidden layer sizes.
- activation: Activation name (relu or tanh).
Outputs: Configuration object.
Failure modes: None.
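As an illustration, a configuration object with these fields might look like the following minimal sketch. The class name `MLPConfigSketch` and the `__post_init__` validation are assumptions for illustration, not the library's actual implementation:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class MLPConfigSketch:
    """Illustrative stand-in for MLPConfig, using the fields listed above."""

    input_dim: int
    action_dim: int
    hidden_sizes: Tuple[int, ...] = (64, 64)
    activation: str = "relu"

    def __post_init__(self) -> None:
        # The docs name relu and tanh as the supported activation names.
        if self.activation not in ("relu", "tanh"):
            raise ValueError(f"unsupported activation: {self.activation!r}")


cfg = MLPConfigSketch(input_dim=4, action_dim=2)
```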

Policy

Abstract policy interface for discrete or continuous action spaces.

Estimand

Not applicable.

Assumptions: None.
Inputs:
- observations: Array with shape (n, d) or (n,) representing states.
Outputs:
- action_probs: Array with shape (n, a) for discrete policies.
- action_density: Array with shape (n,) for continuous policies.
Failure modes: Implementations should raise ValueError if probabilities are invalid.
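To make the contract concrete, here is a minimal discrete policy sketch that enforces the shape and validity rules above. `DiscretePolicySketch` and its `probs_fn` constructor argument are hypothetical names for illustration, not part of the library's API:

```python
import numpy as np


class DiscretePolicySketch:
    """Illustrative discrete policy following the contract described above.

    probs_fn maps observations with shape (n, d) to action probabilities
    with shape (n, a); this callable is an assumption for the sketch.
    """

    def __init__(self, probs_fn):
        self.probs_fn = probs_fn

    def action_probs(self, observations):
        p = np.asarray(self.probs_fn(observations), dtype=float)
        # Failure mode from the docs: invalid probabilities raise ValueError.
        if (p < 0).any() or not np.allclose(p.sum(axis=1), 1.0):
            raise ValueError("each row must be a valid probability distribution")
        return p

    def action_prob(self, observations, actions):
        # Probability of each selected action, row by row.
        p = self.action_probs(observations)
        return p[np.arange(len(p)), np.asarray(actions, dtype=int)]

    def log_prob(self, observations, actions):
        return np.log(self.action_prob(observations, actions))


# Usage: a policy that is uniform over two actions.
pol = DiscretePolicySketch(
    lambda obs: np.full((np.asarray(obs).shape[0], 2), 0.5)
)
probs = pol.action_probs(np.zeros((3, 1)))  # shape (3, 2)
```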

action_density(observations, actions)

Return action densities for selected actions (continuous).

action_prob(observations, actions)

Return probabilities for selected actions (discrete).

action_probs(observations)

Return action probabilities for each observation (discrete).

from_sklearn(model, action_space_n, *, deterministic=None, name=None) classmethod

Wrap a scikit-learn model as a policy.

If deterministic is True, uses model.predict and treats outputs as actions. Otherwise uses model.predict_proba for action probabilities.
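The branching described here can be sketched as follows. `StubModel` stands in for a fitted scikit-learn classifier, and `wrap_sklearn` is a simplified, hypothetical version of the classmethod's logic (in particular, one-hot encoding the deterministic predictions is an assumption of this sketch):

```python
import numpy as np


class StubModel:
    """Stand-in for a fitted scikit-learn classifier."""

    def predict_proba(self, X):
        # Always prefer action 1 with probability 0.8.
        n = np.asarray(X).shape[0]
        return np.tile([0.2, 0.8], (n, 1))

    def predict(self, X):
        return np.ones(np.asarray(X).shape[0], dtype=int)


def wrap_sklearn(model, action_space_n, *, deterministic=False):
    """Sketch of the from_sklearn branching described above."""

    def action_probs(observations):
        if deterministic:
            # Treat predict() outputs as actions; one-hot them so the
            # result still has shape (n, action_space_n).
            acts = np.asarray(model.predict(observations), dtype=int)
            return np.eye(action_space_n)[acts]
        return np.asarray(model.predict_proba(observations))

    return action_probs


stochastic = wrap_sklearn(StubModel(), 2)(np.zeros((3, 4)))
deterministic = wrap_sklearn(StubModel(), 2, deterministic=True)(np.zeros((2, 4)))
```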

from_torch(model, action_space_n, *, device='cpu', name=None) classmethod

Wrap a torch model that outputs action logits.

log_prob(observations, actions)

Return log-probability or log-density for selected actions.

sample_action(observations, rng)

Sample actions for observations (optional).

to_dict()

Return a dictionary representation.

StochasticPolicy dataclass

Bases: Policy

Wrap a callable that returns action probabilities.

TabularPolicy

Bases: Policy

Tabular policy defined by a probability table.

Estimand

Not applicable.

Assumptions: Actions are discrete and state indices are valid.
Inputs:
- table: Array with shape (num_states, num_actions) of action probabilities.
Outputs:
- action_probs: Array with shape (n, num_actions).
Failure modes: Raises ValueError if rows do not sum to 1 or contain negative values.
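A sketch of the validation, lookup, and sampling this implies, using NumPy. The function names are illustrative, not the class's actual methods:

```python
import numpy as np


def tabular_action_probs(table, observations):
    """Validate the table and look up rows by state index, per the docs above."""
    table = np.asarray(table, dtype=float)
    # Failure mode from the docs: negative entries or rows not summing to 1.
    if (table < 0).any() or not np.allclose(table.sum(axis=1), 1.0):
        raise ValueError("each row must be a probability distribution")
    states = np.asarray(observations, dtype=int)
    return table[states]


def tabular_sample_actions(table, observations, rng):
    """Sample one action per observation from the looked-up distributions."""
    probs = tabular_action_probs(table, observations)
    return np.array([rng.choice(len(p), p=p) for p in probs])


table = np.array([[0.9, 0.1],
                  [0.3, 0.7]])
probs = tabular_action_probs(table, [0, 1, 0])  # shape (3, 2)
actions = tabular_sample_actions(table, [0, 1, 0], np.random.default_rng(0))
```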

num_actions property

Return the number of actions.

num_states property

Return the number of states.

action_probs(observations)

Return action probabilities for each observation.

sample_action(observations, rng)

Sample actions for observations.

to_dict()

Return a dictionary representation.

TorchMLPPolicy

Bases: Policy

MLP policy implemented in PyTorch.

Estimand

Not applicable.

Assumptions: Observations are numeric features and actions are discrete.
Inputs:
- model: torch.nn.Module mapping observations to logits.
- action_dim: Number of discrete actions.
- device: Torch device string.
Outputs:
- action_probs: Array with shape (n, action_dim).
Failure modes: Raises ValueError if logits have the wrong shape.
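The logits-to-probabilities step can be sketched in NumPy, with an array standing in for the torch forward pass. The function name and the exact shape check are illustrative assumptions:

```python
import numpy as np


def logits_to_action_probs(logits, action_dim):
    """Sketch of turning model logits into action probabilities via softmax."""
    logits = np.asarray(logits, dtype=float)
    # Failure mode from the docs: wrong logit shape raises ValueError.
    if logits.ndim != 2 or logits.shape[1] != action_dim:
        raise ValueError(f"expected logits of shape (n, {action_dim}), got {logits.shape}")
    # Subtract the row max for numerical stability before exponentiating.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


probs = logits_to_action_probs([[2.0, 0.0], [0.0, 0.0]], action_dim=2)
```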

action_probs(observations)

Return action probabilities for each observation.

from_config(config, device='cpu') classmethod

Construct a policy from an MLPConfig.

sample_action(observations, rng)

Sample actions for observations.

UniformPolicy

Bases: Policy

Uniform random policy over discrete actions.
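As a sketch of the behavior (the function name is illustrative): every observation receives probability 1/num_actions for each action.

```python
import numpy as np


def uniform_action_probs(observations, num_actions):
    """Assign equal probability to each discrete action for every observation."""
    n = np.asarray(observations).shape[0]
    return np.full((n, num_actions), 1.0 / num_actions)


probs = uniform_action_probs(np.zeros((3, 4)), num_actions=5)  # shape (3, 5)
```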