Core Interfaces

Core interfaces for CRL.

BanditDataset dataclass

Logged contextual bandit dataset.

Estimand

Not applicable.

Assumptions: None.

Inputs:
- contexts: Array with shape (n, d) or (n,).
- actions: Array with shape (n,) of integer action indices.
- rewards: Array with shape (n,) of observed rewards.
- behavior_action_probs: Array with shape (n,) of propensities for actions.
- action_space_n: Number of discrete actions.
- metadata: Optional dictionary for provenance.

Outputs: Dataset instance with validated fields.

Failure modes: Raises ValueError if shapes mismatch or probabilities are invalid.
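The data contract above can be sketched with plain numpy arrays. The helper `validate_bandit_arrays` below is illustrative only (not part of the library API); it mirrors the documented shape and propensity checks:

```python
import numpy as np

# Hypothetical sketch of the documented bandit data contract; the real
# dataset class performs its own validation on construction.
def validate_bandit_arrays(contexts, actions, rewards, behavior_action_probs, action_space_n):
    n = contexts.shape[0]
    if not (actions.shape == rewards.shape == behavior_action_probs.shape == (n,)):
        raise ValueError("actions, rewards, and behavior_action_probs must have shape (n,)")
    if actions.min() < 0 or actions.max() >= action_space_n:
        raise ValueError("action indices must lie in [0, action_space_n)")
    if np.any(behavior_action_probs <= 0) or np.any(behavior_action_probs > 1):
        raise ValueError("propensities must lie in (0, 1]")
    return True

rng = np.random.default_rng(0)
contexts = rng.normal(size=(100, 5))      # (n, d)
actions = rng.integers(0, 3, size=100)    # (n,) integer action indices
rewards = rng.normal(size=100)            # (n,) observed rewards
probs = np.full(100, 1 / 3)               # (n,) propensities (uniform logging policy)
assert validate_bandit_arrays(contexts, actions, rewards, probs, action_space_n=3)
```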

discount property

Return discount factor for bandits (1.0).

dones property

Bandit data is terminal after each sample.

horizon property

Return horizon length for bandits (1).

next_states property

Bandit data has no next-state field.

num_samples property

Return the number of logged samples.

states property

Alias for contexts to match the core Dataset interface.

describe()

Return summary statistics for the dataset.

fingerprint()

Return a stable fingerprint for the dataset.

from_dataframe(df, *, context_columns, action_column='action', reward_column='reward', behavior_prob_column=None, action_space_n=None, metadata=None) classmethod

Create a LoggedBanditDataset from a pandas DataFrame.

from_dict(data) classmethod

Create a dataset from a serialized dictionary.

from_numpy(*, contexts, actions, rewards, behavior_action_probs=None, action_space_n=None, metadata=None) classmethod

Create a LoggedBanditDataset from numpy arrays.

from_parquet(path, *, context_columns, action_column='action', reward_column='reward', behavior_prob_column=None, action_space_n=None, metadata=None) classmethod

Create a LoggedBanditDataset from a parquet file.

summary()

Alias for describe().

to_dict()

Serialize dataset to a dictionary of arrays.

validate()

Validate shapes and value ranges.

Diagnostics dataclass

Structured diagnostics with standard keys.

to_dict()

Return a dictionary representation.

EstimationReport dataclass

Report returned by estimators.

Estimand

Policy value for the estimator's target policy.

Assumptions: Recorded in the estimand; warnings highlight issues.

Outputs:
- value: Estimated policy value.
- stderr: Estimated standard error, if available.
- ci: Optional confidence interval (low, high).
- diagnostics: Dictionary of diagnostic metrics.
- assumptions_checked: Assumptions required by the estimator.
- assumptions_flagged: Assumptions flagged by diagnostics.
- warnings: List of warning strings.
- metadata: Extra metadata (fit details, configs).

Failure modes: Diagnostics may be None if disabled.
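The report fields above can be pictured as a small dataclass. `ReportSketch` below is a minimal stand-in for illustration; the real report class may differ in field names, defaults, and extra methods:

```python
from dataclasses import dataclass, field, asdict
from typing import Optional

# Illustrative stand-in for the documented report fields.
@dataclass
class ReportSketch:
    value: float
    stderr: Optional[float] = None
    ci: Optional[tuple] = None                  # (low, high)
    diagnostics: Optional[dict] = None          # None if diagnostics are disabled
    warnings: list = field(default_factory=list)

    def to_dict(self):
        return asdict(self)

report = ReportSketch(value=0.42, stderr=0.05, ci=(0.32, 0.52))
print(report.to_dict()["value"])  # 0.42
```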

save_html(path)

Write report contents to an HTML file.

save_json(path)

Write report contents to a JSON file.

to_dataframe()

Return a one-row pandas DataFrame if pandas is available.

to_dict()

Return a pandas-friendly dict representation.

to_html()

Return a self-contained HTML report representation.

to_json()

Return a JSON string representation.

EstimatorReport dataclass

Report returned by estimators.

Estimand

Policy value for the estimator's target policy.

Assumptions: Recorded in the estimand; warnings highlight issues.

Outputs:
- value: Estimated policy value.
- stderr: Estimated standard error, if available.
- ci: Optional confidence interval (low, high).
- diagnostics: Dictionary of diagnostic metrics.
- assumptions_checked: Assumptions required by the estimator.
- assumptions_flagged: Assumptions flagged by diagnostics.
- warnings: List of warning strings.
- metadata: Extra metadata (fit details, configs).

Failure modes: Diagnostics may be None if disabled.

save_html(path)

Write report contents to an HTML file.

save_json(path)

Write report contents to a JSON file.

to_dataframe()

Return a one-row pandas DataFrame if pandas is available.

to_dict()

Return a pandas-friendly dict representation.

to_html()

Return a self-contained HTML report representation.

to_json()

Return a JSON string representation.

OPEEstimator

Bases: ABC

Base class for off-policy evaluation estimators.

Estimand

PolicyValueEstimand.

Assumptions: Each estimator declares required assumptions.

Inputs: Dataset-specific objects such as TrajectoryDataset or LoggedBanditDataset.

Outputs: EstimatorReport with value, diagnostics, and metadata.

Failure modes: Raises ValueError if required assumptions are missing.
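As a concrete instance of what an estimator computes, here is a standalone inverse-propensity-scoring (IPS) value estimate for logged bandit data. This is a hedged sketch of the arithmetic, not the library's estimator class; `target_probs` stands in for the target policy's probabilities pi(a_i | x_i) on the logged actions:

```python
import numpy as np

# IPS estimate: reweight logged rewards by pi(a|x) / b(a|x) and average.
def ips_value(rewards, behavior_action_probs, target_probs):
    weights = target_probs / behavior_action_probs
    return float(np.mean(weights * rewards))

rewards = np.array([1.0, 0.0, 1.0, 1.0])
behavior = np.array([0.5, 0.5, 0.25, 0.25])   # logging propensities b(a|x)
target = np.array([0.5, 0.5, 0.5, 0.5])       # target probabilities pi(a|x)
print(ips_value(rewards, behavior, target))   # 1.25
```

Positivity/overlap matters here: if any `behavior_action_probs` entry approaches zero while the target probability does not, the weights explode, which is exactly the failure mode diagnostics are meant to flag.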

estimate(data) abstractmethod

Estimate policy value from data.

Policy

Bases: Protocol

Protocol for policies with discrete or continuous actions.

action_density(observations, actions)

Return action densities for selected actions (continuous).

action_prob(observations, actions)

Return action probabilities for selected actions (discrete).

action_probs(observations)

Return action probabilities for each observation (discrete).

log_prob(observations, actions)

Return log-probabilities or log-densities for selected actions.

sample_action(observations, rng)

Sample actions for observations (optional).
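A minimal discrete policy satisfying the protocol methods above might look like the following. The `UniformPolicy` class is illustrative (it implements only the discrete methods, omitting `action_density`):

```python
import numpy as np

# A uniform-random discrete policy, sketched against the Policy protocol.
class UniformPolicy:
    def __init__(self, n_actions):
        self.n_actions = n_actions

    def action_probs(self, observations):
        # One probability row per observation.
        n = observations.shape[0]
        return np.full((n, self.n_actions), 1.0 / self.n_actions)

    def action_prob(self, observations, actions):
        # Probability of each selected action.
        return self.action_probs(observations)[np.arange(len(actions)), actions]

    def log_prob(self, observations, actions):
        return np.log(self.action_prob(observations, actions))

    def sample_action(self, observations, rng):
        return rng.integers(0, self.n_actions, size=observations.shape[0])

policy = UniformPolicy(4)
obs = np.zeros((3, 2))
print(policy.action_prob(obs, np.array([0, 1, 2])))  # each entry 0.25
```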

PolicyContrastEstimand dataclass

Contrast between two policy values.

Estimand

V^{pi_treatment} - V^{pi_control}.

Assumptions: Same as PolicyValueEstimand for both policies. Inputs: treatment: Target policy value estimand. control: Control policy value estimand. Outputs: Contrast specification used by estimators or reports. Failure modes: If assumptions differ, the contrast may not be identified.

to_dict()

Return a dictionary representation.

PolicyValueEstimand dataclass

Policy value estimand under intervention.

Estimand

V^pi = E[ sum_t gamma^t R_t | do(A_t ~ pi(· | S_t)) ].

Assumptions: Sequential ignorability, positivity/overlap, and correct data contract.

Inputs:
- policy: Target policy.
- discount: Discount factor.
- horizon: Optional horizon for finite episodes.
- assumptions: AssumptionSet describing identification conditions.

Outputs: Estimand specification used by estimators.

Failure modes: If required assumptions are missing, estimators should refuse to run.
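For intuition, the on-policy Monte Carlo version of this estimand is just the average discounted return over trajectories. The helper below is an illustrative computation of sum_t gamma^t R_t, not a library function:

```python
import numpy as np

# Average discounted return over (n, t) reward trajectories:
# an on-policy Monte Carlo estimate of V^pi.
def discounted_value(rewards, gamma):
    t = rewards.shape[1]
    discounts = gamma ** np.arange(t)       # [1, gamma, gamma^2, ...]
    return float(np.mean(rewards @ discounts))

rewards = np.array([[1.0, 1.0, 1.0],
                    [0.0, 1.0, 0.0]])
print(discounted_value(rewards, gamma=0.9))  # (2.71 + 0.9) / 2 = 1.805
```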

require(names)

Require that assumptions include the specified names.

to_dict()

Return a dictionary representation of the estimand.

TrajectoryDataset dataclass

Logged finite-horizon trajectory dataset.

Estimand

Not applicable.

Assumptions: None.

Inputs:
- observations: Array with shape (n, t, d) or (n, t).
- actions: Array with shape (n, t) of integer action indices.
- rewards: Array with shape (n, t) of rewards.
- next_observations: Array with shape matching observations.
- behavior_action_probs: Array with shape (n, t) of propensities.
- mask: Boolean array with shape (n, t) indicating valid steps.
- discount: Discount factor in [0, 1].
- action_space_n: Number of discrete actions.
- state_space_n: Optional number of discrete states for one-hot features.
- metadata: Optional dictionary for provenance.

Outputs: Dataset instance with validated fields.

Failure modes: Raises ValueError if shapes mismatch or probabilities are invalid.
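The masked (n, t) layout can be sketched in plain numpy. The derivations below (valid-step count, terminal flags from the last valid step) mirror how the `num_steps` and `dones` properties are documented to behave, but are illustrative rather than the library's code:

```python
import numpy as np

n, t, d = 2, 4, 3
observations = np.zeros((n, t, d))            # (n, t, d)
actions = np.zeros((n, t), dtype=int)         # (n, t)
rewards = np.ones((n, t))                     # (n, t)
mask = np.array([[True, True, True, False],   # trajectory 0: 3 valid steps
                 [True, True, False, False]]) # trajectory 1: 2 valid steps

num_steps = int(mask.sum())                   # valid steps across all trajectories
last_valid = mask.sum(axis=1) - 1             # index of last valid step per trajectory
dones = np.zeros((n, t), dtype=bool)
dones[np.arange(n), last_valid] = True        # terminal flag at last valid step
print(num_steps, last_valid)                  # 5 [2 1]
```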

dones property

Infer terminal flags from the mask (last valid step per trajectory).

horizon property

Return the horizon length (max steps).

next_states property

Alias for next_observations to match the core Dataset interface.

num_steps property

Return the number of valid steps (mask True).

num_trajectories property

Return the number of trajectories.

states property

Alias for observations to match the core Dataset interface.

describe()

Return summary statistics for the dataset.

fingerprint()

Return a stable fingerprint for the dataset.

from_dataframe(df, *, episode_id_column='episode_id', timestep_column='timestep', observation_columns, next_observation_columns, action_column='action', reward_column='reward', behavior_prob_column=None, discount=1.0, action_space_n=None, state_space_n=None, metadata=None) classmethod

Create a TrajectoryDataset from a long-form pandas DataFrame.

from_dict(data) classmethod

Create a dataset from a serialized dictionary.

from_numpy(*, observations, actions, rewards, next_observations, mask=None, discount, action_space_n=None, behavior_action_probs=None, state_space_n=None, metadata=None) classmethod

Create a TrajectoryDataset from numpy arrays.

from_parquet(path, *, episode_id_column='episode_id', timestep_column='timestep', observation_columns, next_observation_columns, action_column='action', reward_column='reward', behavior_prob_column=None, discount=1.0, action_space_n=None, state_space_n=None, metadata=None) classmethod

Create a TrajectoryDataset from a parquet file.

summary()

Alias for describe().

to_dict()

Serialize dataset to a dictionary of arrays.

validate()

Validate shapes and value ranges.

TransitionDataset dataclass

Transition dataset with optional episode ids and timesteps.

horizon property

Return horizon inferred from timesteps if available.

num_steps property

Return the number of transitions.

describe()

Return summary statistics for the dataset.

fingerprint()

Return a stable fingerprint for the dataset.

from_dataframe(df, *, state_columns, next_state_columns, action_column='action', reward_column='reward', done_column='done', behavior_prob_column=None, episode_id_column=None, timestep_column=None, discount=1.0, action_space_n=None, metadata=None) classmethod

Create a TransitionDataset from a pandas DataFrame.

from_parquet(path, *, state_columns, next_state_columns, action_column='action', reward_column='reward', done_column='done', behavior_prob_column=None, episode_id_column=None, timestep_column=None, discount=1.0, action_space_n=None, metadata=None) classmethod

Create a TransitionDataset from a parquet file.

summary()

Alias for describe().

to_dict()

Serialize dataset to a dictionary of arrays.

to_trajectory()

Convert transitions to a TrajectoryDataset if episodes are known.
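The conversion above amounts to grouping flat transitions by episode id into padded (n, t) arrays with a validity mask. The `group_by_episode` helper below is an illustrative sketch of that regrouping for a single field (rewards), not the library's implementation:

```python
import numpy as np

# Group flat per-transition rewards into padded (n, t_max) trajectories,
# with a boolean mask marking which padded entries are real steps.
def group_by_episode(episode_ids, rewards):
    ids = np.unique(episode_ids)
    t_max = max(int(np.sum(episode_ids == e)) for e in ids)
    out = np.zeros((len(ids), t_max))
    mask = np.zeros((len(ids), t_max), dtype=bool)
    for i, e in enumerate(ids):
        r = rewards[episode_ids == e]
        out[i, : len(r)] = r
        mask[i, : len(r)] = True
    return out, mask

episode_ids = np.array([0, 0, 0, 1, 1])
rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
padded, mask = group_by_episode(episode_ids, rewards)
print(padded.shape)  # (2, 3)
```

This assumes transitions within an episode are already in timestep order; when a timestep column is available, sorting by it first would make the grouping robust.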

validate()

Validate shapes, types, and ranges.