Dataset Format and Validation
This page is the single source of truth for how to build datasets that CausalRL accepts. Every estimator and diagnostic assumes these contracts are respected.
Bandit data: LoggedBanditDataset
Use LoggedBanditDataset for contextual bandits (one action per context).
Required fields and shapes:
- contexts: shape (n, d) or (n,).
- actions: shape (n,), integer action indices.
- rewards: shape (n,).
- behavior_action_probs: optional, shape (n,), values in [0, 1].
- action_space_n: positive integer.
Minimal example:
import numpy as np
from crl.data.datasets import LoggedBanditDataset
contexts = np.random.normal(size=(100, 5))
actions = np.random.randint(0, 3, size=100)
rewards = np.random.normal(size=100)
behavior_probs = np.full(100, 1.0 / 3.0)
dataset = LoggedBanditDataset(
    contexts=contexts,
    actions=actions,
    rewards=rewards,
    behavior_action_probs=behavior_probs,
    action_space_n=3,
)
Convenience constructors:
dataset = LoggedBanditDataset.from_numpy(
    contexts=contexts,
    actions=actions,
    rewards=rewards,
    behavior_action_probs=behavior_probs,
)
Validation highlights:
- actions must be integers in [0, action_space_n).
- behavior_action_probs must be in [0, 1] and finite.
- All arrays must have the same length.
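The checks listed above can be approximated before constructing the dataset. The following is a minimal NumPy sketch, not CausalRL's own validator: check_bandit_arrays is a hypothetical helper name, and the messages it returns are illustrative only.

```python
import numpy as np


def check_bandit_arrays(contexts, actions, rewards, action_space_n,
                        behavior_action_probs=None):
    """Hypothetical pre-check: return a list of contract violations (empty if none)."""
    problems = []
    contexts, actions, rewards = map(np.asarray, (contexts, actions, rewards))

    # All arrays must have the same length along the first axis.
    lengths = {
        "contexts": len(contexts),
        "actions": len(actions),
        "rewards": len(rewards),
    }
    if behavior_action_probs is not None:
        behavior_action_probs = np.asarray(behavior_action_probs, dtype=float)
        lengths["behavior_action_probs"] = len(behavior_action_probs)
    if len(set(lengths.values())) != 1:
        problems.append(f"length mismatch: {lengths}")

    # Actions must be integer indices in [0, action_space_n).
    if not np.issubdtype(actions.dtype, np.integer):
        problems.append(f"actions dtype {actions.dtype} is not an integer type")
    elif actions.size and (actions.min() < 0 or actions.max() >= action_space_n):
        problems.append("actions fall outside [0, action_space_n)")

    # Propensities must be finite and lie in [0, 1].
    if behavior_action_probs is not None:
        if not np.all(np.isfinite(behavior_action_probs)):
            problems.append("behavior_action_probs contains NaN or inf")
        elif np.any((behavior_action_probs < 0) | (behavior_action_probs > 1)):
            problems.append("behavior_action_probs has values outside [0, 1]")

    return problems
```

An empty list means the arrays pass these three checks; the library may enforce further conditions that this sketch does not cover.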
Trajectory data: TrajectoryDataset
Use TrajectoryDataset for finite-horizon episodes.
Required fields and shapes:
- observations: shape (n, t, d) or (n, t).
- actions: shape (n, t), integer action indices.
- rewards: shape (n, t).
- next_observations: same shape as observations.
- mask: shape (n, t), boolean. Valid steps must form a prefix of each trajectory.
- behavior_action_probs: optional, shape (n, t), values in [0, 1] wherever mask is True.
- discount: float in [0, 1].
- action_space_n: positive integer.
- state_space_n: optional integer (only for discrete states).
Minimal example:
import numpy as np
from crl.data.datasets import TrajectoryDataset
num_traj = 10
horizon = 5
obs = np.random.normal(size=(num_traj, horizon, 3))
next_obs = np.random.normal(size=(num_traj, horizon, 3))
actions = np.random.randint(0, 4, size=(num_traj, horizon))
rewards = np.random.normal(size=(num_traj, horizon))
mask = np.ones((num_traj, horizon), dtype=bool)
behavior_probs = np.full((num_traj, horizon), 1.0 / 4.0)
dataset = TrajectoryDataset(
    observations=obs,
    actions=actions,
    rewards=rewards,
    next_observations=next_obs,
    behavior_action_probs=behavior_probs,
    mask=mask,
    discount=0.99,
    action_space_n=4,
)
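The example above uses full-length episodes, so the mask is all True. When episodes end early, a prefix mask can be derived from per-episode lengths. This is a minimal sketch in plain NumPy; prefix_mask is a hypothetical helper, not part of CausalRL.

```python
import numpy as np


def prefix_mask(lengths, horizon):
    """Build an (n, horizon) boolean mask whose first lengths[i] steps are valid."""
    lengths = np.asarray(lengths)
    if np.any(lengths < 0) or np.any(lengths > horizon):
        raise ValueError("each length must lie in [0, horizon]")
    # Step index j is valid for trajectory i exactly when j < lengths[i].
    return np.arange(horizon)[None, :] < lengths[:, None]


# Three episodes padded to a horizon of 4 steps.
mask = prefix_mask([3, 1, 4], horizon=4)
```

Because every row is built as "step index below length", the valid steps always form a prefix, which is the shape of mask the contract asks for.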
Convenience constructors:
# df: a pandas DataFrame of logged steps (assumed to exist; not constructed here)
dataset = TrajectoryDataset.from_dataframe(
    df,
    observation_columns=["obs"],
    next_observation_columns=["next_obs"],
    behavior_prob_column="propensity",
    discount=0.99,
    action_space_n=4,
)
Validation highlights:
- mask must be contiguous: valid steps form a prefix for each trajectory.
- actions, rewards, mask, and behavior_action_probs (if present) must share shape (n, t).
- actions must be integer indices in [0, action_space_n) on valid steps.
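The prefix requirement on mask can be tested row by row before building the dataset. The sketch below is one way to do it in NumPy; is_prefix_mask is a hypothetical name, and the library's own check may differ.

```python
import numpy as np


def is_prefix_mask(mask):
    """Return one boolean per trajectory: True when valid steps form a prefix."""
    mask = np.asarray(mask, dtype=bool)
    # A row violates the prefix rule if any invalid step is followed by a valid one.
    restarts = ~mask[:, :-1] & mask[:, 1:]
    return ~np.any(restarts, axis=1)


rows_ok = is_prefix_mask([
    [1, 1, 0],  # prefix of length 2
    [1, 0, 1],  # gap in the middle
    [0, 0, 0],  # empty trajectory, still a (zero-length) prefix
    [0, 1, 1],  # starts late
])
```

Rows flagged False are the ones to repair, typically by shifting the valid steps to the front or dropping the trajectory.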
Transition data: TransitionDataset
Use TransitionDataset for flat (s, a, r, s', done) logs.
Required fields and shapes:
- states: shape (n, d) or (n,).
- actions: shape (n,) (or (n, d) for continuous actions).
- rewards: shape (n,).
- next_states: same shape as states.
- dones: shape (n,).
- behavior_action_probs: optional, shape (n,).
- discount: float in [0, 1].
- action_space_n: optional; required for discrete actions.
- episode_ids and timesteps: optional; required for to_trajectory().
Minimal example:
import numpy as np
from crl.data.transition import TransitionDataset
states = np.random.normal(size=(100, 4))
next_states = np.random.normal(size=(100, 4))
actions = np.random.randint(0, 2, size=100)
rewards = np.random.normal(size=100)
dones = np.zeros(100, dtype=bool)
behavior_probs = np.full(100, 0.5)
dataset = TransitionDataset(
    states=states,
    actions=actions,
    rewards=rewards,
    next_states=next_states,
    dones=dones,
    behavior_action_probs=behavior_probs,
    discount=0.99,
    action_space_n=2,
)
To convert transitions to trajectories, call to_trajectory() on the dataset. This requires episode_ids and timesteps to be present, and the actions to be discrete.
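To see why episode_ids and timesteps are needed, the grouping step behind such a conversion can be sketched in plain NumPy. This is not the library's implementation: pad_by_episode is a hypothetical helper, and it assumes timesteps are zero-based integers with no repeats inside an episode.

```python
import numpy as np


def pad_by_episode(values, episode_ids, timesteps, fill=0.0):
    """Scatter flat per-step values into a padded (n, t, ...) array plus a mask."""
    values = np.asarray(values)
    episode_ids = np.asarray(episode_ids)
    timesteps = np.asarray(timesteps)

    # Map arbitrary episode ids to row indices 0..n-1 (ordered by sorted id).
    unique_ids, rows = np.unique(episode_ids, return_inverse=True)
    horizon = int(timesteps.max()) + 1

    padded = np.full((len(unique_ids), horizon) + values.shape[1:], fill,
                     dtype=values.dtype)
    mask = np.zeros((len(unique_ids), horizon), dtype=bool)

    # Place each transition at (its episode row, its timestep).
    padded[rows, timesteps] = values
    mask[rows, timesteps] = True
    return padded, mask


rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
episode_ids = np.array([7, 7, 7, 9, 9])
timesteps = np.array([0, 1, 2, 0, 1])
padded_rewards, mask = pad_by_episode(rewards, episode_ids, timesteps)
```

Without episode_ids there is no way to pick the row, and without timesteps there is no way to pick the column, which is why both fields are required.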
Common validation errors and fixes
- Actions out of range: check action_space_n and action indexing.
- Mask not contiguous: ensure each trajectory uses a prefix of valid steps.
- Missing propensities: IS/PDIS/DR require them; FQE can run without them, but diagnostics will be limited.
- Shape mismatch: verify (n, t) or (n, t, d) consistency across arrays.
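Missing propensities are easiest to avoid at logging time. The sketch below assumes behavior_action_probs holds the probability the behavior policy assigned to the action that was actually logged, and that the behavior policy is a softmax over hypothetical logits; both are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 5, 3

# Hypothetical behavior policy: a softmax over per-context logits.
logits = rng.normal(size=(n, k))
shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
policy_probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

# Sample one action per context from the behavior policy.
actions = np.array([rng.choice(k, p=p) for p in policy_probs])

# Keep only the probability of the action that was taken: shape (n,).
behavior_action_probs = policy_probs[np.arange(n), actions]
```

Storing this one number per step alongside the action is enough to fill behavior_action_probs later.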
Researcher checklist
- Confirm the dataset contract for your estimator.
- Log behavior propensities whenever possible.
- Inspect diagnostics before trusting estimates.
- Use synthetic benchmarks to sanity check workflows.
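As an illustration of the last item, a synthetic bandit with a known answer makes it easy to confirm that logged actions, rewards, and propensities line up. The sketch uses a hand-written importance sampling (IS) estimate in plain NumPy rather than CausalRL's estimators; the reward means and the target policy are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 50_000, 3

# Made-up ground truth: expected reward of each action.
true_means = np.array([0.1, 0.5, 0.9])

# Uniform behavior policy, so every logged propensity is 1 / k.
actions = rng.integers(0, k, size=n)
behavior_probs = np.full(n, 1.0 / k)
rewards = true_means[actions] + rng.normal(scale=0.1, size=n)

# Context-free target policy whose value is known in closed form.
target_probs = np.array([0.2, 0.3, 0.5])
true_value = float(target_probs @ true_means)

# IS estimate: reweight each logged reward by target / behavior probability.
weights = target_probs[actions] / behavior_probs
is_estimate = float(np.mean(weights * rewards))

print(f"true value: {true_value:.3f}, IS estimate: {is_estimate:.3f}")
```

If the two numbers disagree by much more than sampling noise, the usual culprits are misaligned arrays or propensities logged for the wrong action.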