rl8.trainers package

Submodules

rl8.trainers.config module

Configuration for the high-level training interfaces.

class rl8.trainers.config.TrainConfig(env_cls: EnvFactory, env_config: None | dict[str, Any] = None, model_cls: None | RecurrentModelFactory | ModelFactory = None, model_config: None | dict[str, Any] = None, distribution_cls: None | type[rl8.distributions.Distribution] = None, horizon: None | int = None, horizons_per_env_reset: None | int = None, num_envs: None | int = None, seq_len: None | int = None, seqs_per_state_reset: None | int = None, optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None, optimizer_config: None | dict[str, Any] = None, accumulate_grads: None | bool = None, enable_amp: None | bool = None, entropy_coeff: None | float = None, gae_lambda: None | float = None, gamma: None | float = None, sgd_minibatch_size: None | int = None, num_sgd_iters: None | int = None, shuffle_minibatches: None | bool = None, clip_param: None | float = None, vf_clip_param: None | float = None, dual_clip_param: None | float = None, vf_coeff: None | float = None, target_kl_div: None | float = None, max_grad_norm: None | float = None, normalize_advantages: None | bool = None, normalize_rewards: None | bool = None, device: None | str | device | Literal['auto'] = None, recurrent: bool = False)[source]

Bases: object

A helper for instantiating a trainer based on a config.

It’s common to run training experiments based on some config. This class reduces the need to write a custom trainer config parser. It doesn’t support all trainer/algorithm options/fields, but it covers the majority of use cases.

It’s intended to be instantiated from a JSON or YAML file and then used to build a trainer directly afterwards.

Examples

Assume there’s some YAML config at ./config.yaml with the following contents:

env_cls: rl8.env.DiscreteDummyEnv
horizon: 8
gamma: 1

The following will instantiate a TrainConfig from the file, instantiate a trainer, and then train indefinitely.

>>> from rl8 import TrainConfig
>>> TrainConfig.from_file("./config.yaml").build().run()
env_cls: EnvFactory
env_config: None | dict[str, Any] = None
model_cls: None | RecurrentModelFactory | ModelFactory = None
model_config: None | dict[str, Any] = None
distribution_cls: None | type[rl8.distributions.Distribution] = None
horizon: None | int = None
horizons_per_env_reset: None | int = None
num_envs: None | int = None
seq_len: None | int = None
seqs_per_state_reset: None | int = None
optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None
optimizer_config: None | dict[str, Any] = None
accumulate_grads: None | bool = None
enable_amp: None | bool = None
entropy_coeff: None | float = None
gae_lambda: None | float = None
gamma: None | float = None
sgd_minibatch_size: None | int = None
num_sgd_iters: None | int = None
shuffle_minibatches: None | bool = None
clip_param: None | float = None
vf_clip_param: None | float = None
dual_clip_param: None | float = None
vf_coeff: None | float = None
target_kl_div: None | float = None
max_grad_norm: None | float = None
normalize_advantages: None | bool = None
normalize_rewards: None | bool = None
device: None | str | device | Literal['auto'] = None
recurrent: bool = False
build() → Trainer | RecurrentTrainer[source]

Instantiate a trainer from the train config.

Null fields are removed from the train config before being unpacked into the trainer’s constructor, so the trainer’s own defaults are used for those fields. The trainer type (i.e., recurrent or feedforward) is specified by the recurrent attribute. A rough sketch of this behavior follows the example below.

Returns:

A trainer based on the train config values.

Examples

>>> from rl8 import DiscreteDummyEnv, TrainConfig
>>> trainer = TrainConfig(DiscreteDummyEnv).build()
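
A minimal sketch of this behavior, assuming only the dataclass fields listed above (illustrative only; not rl8’s actual implementation):

import dataclasses

from rl8 import TrainConfig
from rl8.trainers import RecurrentTrainer, Trainer

def build_sketch(config: TrainConfig):
    # Drop fields left as None so the trainer's own defaults apply.
    kwargs = {}
    for field in dataclasses.fields(config):
        if field.name in ("env_cls", "recurrent"):
            continue
        value = getattr(config, field.name)
        if value is not None:
            kwargs[field.name] = value
    # The recurrent flag decides which trainer type to construct.
    trainer_cls = RecurrentTrainer if config.recurrent else Trainer
    return trainer_cls(config.env_cls, **kwargs)
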
classmethod from_file(path: str | Path) → TrainConfig[source]

Instantiate a TrainConfig from a JSON or YAML file.

The JSON or YAML file should have fields with the same type as the dataclass fields except for:

  • “env_cls”

  • “model_cls”

  • “distribution_cls”

  • “optimizer_cls”

These fields should be fully qualified paths to their definitions. For example, to use an environment class MyEnv from the envs submodule of a custom package my_package, set "env_cls" to "my_package.envs.MyEnv".

Definitions specified in these fields will be dynamically imported from their respective packages and modules. A current limitation is these field specifications must point to an installed package and can’t be from relative file locations (e.g., something like "..my_package.envs.MyEnv" will not work).
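
The dynamic import can be sketched with the standard library as follows (a general illustration of the mechanism, not necessarily rl8’s exact code):

import importlib

def import_from_path(path: str):
    # Split "my_package.envs.MyEnv" into module and attribute names.
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)  # must be installed/importable
    return getattr(module, attr_name)

# For example, import_from_path("rl8.env.DiscreteDummyEnv") returns the
# environment class used in the YAML example above.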

Parameters:

path – Pathlike to the JSON or YAML file to read.

Returns:

A train config based on the given file.

Module contents

Definitions related to PPO trainers (abstractions over algorithms and interfaces between those algorithms and other tools).

class rl8.trainers.GenericTrainerBase(algorithm: _Algorithm, /)[source]

Bases: Generic[_Algorithm]

The base trainer interface.

All trainers (the higher-level training interfaces) inherit from this class and are bound to a particular algorithm (i.e., one trainer per algorithm).

algorithm: _Algorithm

Underlying PPO algorithm, including the environment, model, action distribution, and hyperparameters.

state: TrainerState

Trainer state used for tracking a handful of running totals necessary for logging metrics, determining when a policy can be evaluated, etc.

eval(*, env_config: None | dict[str, Any] = None, deterministic: bool = True) → EvalCollectStats[source]

Run a single evaluation step, collecting environment transitions for several horizons with potentially different environment configs.

Parameters:
  • env_config – Environment config override. Useful for evaluating a policy’s generalizability by setting the environment config to something different from the environment config during training.

  • deterministic – Whether to sample from the policy deterministically. This is usually False during learning and True during evaluation.

Returns:

Eval stats from the collection buffer.

Raises:
  • RuntimeError – If this method is called outside of the underlying algorithm’s horizons_per_env_reset interval.

  • ValueError – If an eval environment config is provided but the environment isn’t expected to use that eval environment config.
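
As an illustrative sketch (not a verified doctest), evaluating a feedforward policy on the DiscreteDummyEnv from the earlier examples might look like the following; whether eval() can be called at a given point depends on the algorithm’s horizons_per_env_reset interval, as noted above:

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import Trainer
>>> trainer = Trainer(DiscreteDummyEnv)
>>> train_stats = trainer.step()  # collect transitions and update the policy
>>> eval_stats = trainer.eval(deterministic=True)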

run(*, env_config: None | dict[str, Any] = None, eval_env_config: None | dict[str, Any] = None, steps_per_eval: None | int = None, stop_conditions: None | list[rl8.conditions.Condition] = None) → TrainStats[source]

Run the trainer and underlying algorithm until at least one of the stop_conditions is satisfied.

This method runs indefinitely unless at least one stop condition is provided.

Parameters:
  • env_config – Environment config override. Useful for scheduling domain randomization.

  • eval_env_config – Environment config override used during evaluations. Defaults to env_config if not given. Useful for evaluating a policy’s generalizability.

  • steps_per_eval – Number of Trainer.step() calls before calling Trainer.eval().

  • stop_conditions – Conditions evaluated each iteration that determine whether to stop training. Only one condition needs to evaluate as True for training to stop. Training will continue indefinitely unless a stop condition returns True.

Returns:

The most recent train stats when the training is stopped due to a stop condition being satisfied.

Raises:

ValueError – 1) If an eval environment config is provided but the environment isn’t expected to use that eval environment config, or 2) if steps_per_eval is not a factor of the algorithm’s horizons_per_env_reset.
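
The following sketch shows a run() call with per-step evaluation. It reuses the DiscreteDummyEnv from the earlier examples and omits stop_conditions, so, as noted above, it would train indefinitely:

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import Trainer
>>> trainer = Trainer(DiscreteDummyEnv)
>>> train_stats = trainer.run(steps_per_eval=1)  # 1 trivially satisfies the factor requirement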

step(*, env_config: None | dict[str, Any] = None) → TrainStats[source]

Run a single training step, collecting environment transitions and updating the policy with those transitions.

Parameters:

env_config – Environment config override. Useful for scheduling domain randomization.

Returns:

Train stats from the policy update.
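
As a sketch of scheduled domain randomization with step(), using the hypothetical my_package.envs.MyEnv from the from_file example (the environment and its "difficulty" config key are placeholders, not part of rl8):

>>> from my_package.envs import MyEnv  # hypothetical package and environment
>>> from rl8.trainers import Trainer
>>> trainer = Trainer(MyEnv)
>>> for i in range(100):
...     train_stats = trainer.step(env_config={"difficulty": i / 100})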

class rl8.trainers.RecurrentTrainer(env_cls: EnvFactory, /, **algorithm_config: Unpack)[source]

Bases: GenericTrainerBase[RecurrentAlgorithm]

Higher-level training interface that interoperates with other tools for tracking and saving experiments (e.g., MLflow).

This is the preferred training interface when training recurrent policies in most cases.

Parameters:
  • env_cls – Highly parallelized environment for sampling experiences. Instantiated with env_config. Will be stepped for horizon each RecurrentAlgorithm.collect() call.

  • **algorithm_config – See RecurrentAlgorithm.
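
As an illustrative sketch, a recurrent trainer can be constructed with a few of the options listed on TrainConfig above (the specific values are arbitrary):

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import RecurrentTrainer
>>> trainer = RecurrentTrainer(DiscreteDummyEnv, horizon=8, seq_len=4, gamma=1.0)
>>> train_stats = trainer.step()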

class rl8.trainers.TrainConfig(env_cls: EnvFactory, env_config: None | dict[str, Any] = None, model_cls: None | RecurrentModelFactory | ModelFactory = None, model_config: None | dict[str, Any] = None, distribution_cls: None | type[rl8.distributions.Distribution] = None, horizon: None | int = None, horizons_per_env_reset: None | int = None, num_envs: None | int = None, seq_len: None | int = None, seqs_per_state_reset: None | int = None, optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None, optimizer_config: None | dict[str, Any] = None, accumulate_grads: None | bool = None, enable_amp: None | bool = None, entropy_coeff: None | float = None, gae_lambda: None | float = None, gamma: None | float = None, sgd_minibatch_size: None | int = None, num_sgd_iters: None | int = None, shuffle_minibatches: None | bool = None, clip_param: None | float = None, vf_clip_param: None | float = None, dual_clip_param: None | float = None, vf_coeff: None | float = None, target_kl_div: None | float = None, max_grad_norm: None | float = None, normalize_advantages: None | bool = None, normalize_rewards: None | bool = None, device: None | str | device | Literal['auto'] = None, recurrent: bool = False)[source]

Bases: object

A helper for instantiating a trainer based on a config.

It’s common to run training experiments based on some config. This class reduces the need to write a custom trainer config parser. It doesn’t support all trainer/algorithm options/fields, but it covers the majority of use cases.

It’s intended to be instantiated from a JSON or YAML file and then used to build a trainer directly afterwards.

Examples

Assume there’s some YAML config at ./config.yaml with the following contents:

env_cls: rl8.env.DiscreteDummyEnv
horizon: 8
gamma: 1

The following will instantiate a TrainConfig from the file, instantiate a trainer, and then train indefinitely.

>>> from rl8 import TrainConfig
>>> TrainConfig.from_file("./config.yaml").build().run()
env_cls: EnvFactory
env_config: None | dict[str, Any] = None
model_cls: None | RecurrentModelFactory | ModelFactory = None
model_config: None | dict[str, Any] = None
distribution_cls: None | type[rl8.distributions.Distribution] = None
horizon: None | int = None
horizons_per_env_reset: None | int = None
num_envs: None | int = None
seq_len: None | int = None
seqs_per_state_reset: None | int = None
optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None
optimizer_config: None | dict[str, Any] = None
accumulate_grads: None | bool = None
enable_amp: None | bool = None
entropy_coeff: None | float = None
gae_lambda: None | float = None
gamma: None | float = None
sgd_minibatch_size: None | int = None
num_sgd_iters: None | int = None
shuffle_minibatches: None | bool = None
clip_param: None | float = None
vf_clip_param: None | float = None
dual_clip_param: None | float = None
vf_coeff: None | float = None
target_kl_div: None | float = None
max_grad_norm: None | float = None
normalize_advantages: None | bool = None
normalize_rewards: None | bool = None
device: None | str | device | Literal['auto'] = None
recurrent: bool = False
build() → Trainer | RecurrentTrainer[source]

Instantiate a trainer from the train config.

Null fields are removed from the train config before being unpacked into the trainer’s constructor, so the trainer’s own defaults are used for those fields. The trainer type (i.e., recurrent or feedforward) is specified by the recurrent attribute.

Returns:

A trainer based on the train config values.

Examples

>>> from rl8 import DiscreteDummyEnv, TrainConfig
>>> trainer = TrainConfig(DiscreteDummyEnv).build()
classmethod from_file(path: str | Path) → TrainConfig[source]

Instantiate a TrainConfig from a JSON or YAML file.

The JSON or YAML file should have fields with the same type as the dataclass fields except for:

  • “env_cls”

  • “model_cls”

  • “distribution_cls”

  • “optimizer_cls”

These fields should be fully qualified paths to their definitions. For example, to use an environment class MyEnv from the envs submodule of a custom package my_package, set "env_cls" to "my_package.envs.MyEnv".

Definitions specified in these fields will be dynamically imported from their respective packages and modules. A current limitation is these field specifications must point to an installed package and can’t be from relative file locations (e.g., something like "..my_package.envs.MyEnv" will not work).

Parameters:

path – Pathlike to the JSON or YAML file to read.

Returns:

A train config based on the given file.

class rl8.trainers.Trainer(env_cls: EnvFactory, /, **algorithm_config: Unpack)[source]

Bases: GenericTrainerBase[Algorithm]

Higher-level training interface that interoperates with other tools for tracking and saving experiments (e.g., MLflow).

This is the preferred training interface when training feedforward (i.e., non-recurrent) policies in most cases.

Parameters:
  • env_cls – Highly parallelized environment for sampling experiences. Instantiated with env_config. Will be stepped for horizon each Algorithm.collect() call.

  • **algorithm_config – See Algorithm.
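
As a final illustrative sketch, the direct-construction equivalent of the TrainConfig YAML example near the top of this page (again training indefinitely without stop conditions):

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import Trainer
>>> Trainer(DiscreteDummyEnv, horizon=8, gamma=1).run()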