rl8.trainers package
Submodules
rl8.trainers.config module
Configuration for the high-level training interfaces.
- class rl8.trainers.config.TrainConfig(env_cls: EnvFactory, env_config: None | dict[str, Any] = None, model_cls: None | RecurrentModelFactory | ModelFactory = None, model_config: None | dict[str, Any] = None, distribution_cls: None | type[rl8.distributions.Distribution] = None, horizon: None | int = None, horizons_per_env_reset: None | int = None, num_envs: None | int = None, seq_len: None | int = None, seqs_per_state_reset: None | int = None, optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None, optimizer_config: None | dict[str, Any] = None, accumulate_grads: None | bool = None, enable_amp: None | bool = None, entropy_coeff: None | float = None, gae_lambda: None | float = None, gamma: None | float = None, sgd_minibatch_size: None | int = None, num_sgd_iters: None | int = None, shuffle_minibatches: None | bool = None, clip_param: None | float = None, vf_clip_param: None | float = None, dual_clip_param: None | float = None, vf_coeff: None | float = None, target_kl_div: None | float = None, max_grad_norm: None | float = None, normalize_advantages: None | bool = None, normalize_rewards: None | bool = None, device: None | str | device | Literal['auto'] = None, recurrent: bool = False)[source]
Bases:
object
A helper for instantiating a trainer based on a config.
It’s common to run training experiments based on some config. This class reduces the need to write a custom trainer config parser. It doesn’t support all trainer/algorithm options/fields, but it supports enough to cover the majority of use cases.
This class is intended to be instantiated from a JSON or YAML file and then used to build a trainer directly afterwards.
Examples
Assume there’s some YAML config at ./config.yaml with the following contents:

env_cls: rl8.env.DiscreteDummyEnv
horizon: 8
gamma: 1
The following will instantiate a TrainConfig from the file, instantiate a trainer, and then train indefinitely.

>>> from rl8 import TrainConfig
>>> TrainConfig.from_file("./config.yaml").build().run()
- env_cls: EnvFactory
- model_cls: None | RecurrentModelFactory | ModelFactory = None
- distribution_cls: None | type[rl8.distributions.Distribution] = None
- optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None
- build() Trainer | RecurrentTrainer [source]
Instantiate a trainer from the train config.
Null fields are removed from the train config before being unpacked into the trainer’s constructor (so the trainer’s default values are used for those fields). The trainer type (i.e., recurrent or feedforward) is specified by the recurrent attribute.
- Returns:
A trainer based on the train config values.
Examples
>>> from rl8 import DiscreteDummyEnv, TrainConfig
>>> trainer = TrainConfig(DiscreteDummyEnv).build()
- classmethod from_file(path: str | Path) TrainConfig [source]
Instantiate a TrainConfig from a JSON or YAML file.

The JSON or YAML file should have fields with the same type as the dataclass fields except for:

- “env_cls”
- “model_cls”
- “distribution_cls”
- “optimizer_cls”

These fields should be fully qualified paths to their definitions. As an example, if one were to use a custom package my_package with submodule envs and environment class MyEnv, they would set "env_cls" to "my_package.envs.MyEnv".

Definitions specified in these fields will be dynamically imported from their respective packages and modules. A current limitation is that these field specifications must point to an installed package and can’t be relative file locations (e.g., something like "..my_package.envs.MyEnv" will not work).
- Parameters:
path – Pathlike to the JSON or YAML file to read.
- Returns:
A train config based on the given file.
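For example, a config pointing at custom, fully qualified definitions might look like the following, where the my_package paths (including MyModel) are hypothetical placeholders mirroring the example above:

env_cls: my_package.envs.MyEnv
model_cls: my_package.models.MyModel
horizon: 32
gamma: 0.99

Building a trainer from such a file follows the same pattern as before:

>>> from rl8 import TrainConfig
>>> trainer = TrainConfig.from_file("./config.yaml").build()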
Module contents
Definitions related to PPO trainers (abstractions over algorithms and interfaces between those algorithms and other tools).
- class rl8.trainers.GenericTrainerBase(algorithm: _Algorithm, /)[source]
Bases:
Generic[_Algorithm]

The base trainer interface.
All trainers (the higher-level training interfaces) inherit from this class and are bound to a particular algorithm (i.e., one trainer per algorithm).
- algorithm: _Algorithm
Underlying PPO algorithm, including the environment, model, action distribution, and hyperparameters.
- state: TrainerState
Trainer state used for tracking a handful of running totals necessary for logging metrics, determining when a policy can be evaluated, etc.
- eval(*, env_config: None | dict[str, Any] = None, deterministic: bool = True) EvalCollectStats [source]
Run a single evaluation step, collecting environment transitions for several horizons with potentially different environment configs.
- Parameters:
env_config – Environment config override. Useful for evaluating a policy’s generalizability by setting the environment config to something different from the environment config during training.
deterministic – Whether to sample from the policy deterministically. This is usually False during learning and True during evaluation.
- Returns:
Eval stats from the collection buffer.
- Raises:
RuntimeError – If this method is called outside of the underlying algorithm’s horizons_per_env_reset interval.
ValueError – If an eval environment config is provided but the environment isn’t expected to use that eval environment config.
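A minimal sketch of a single evaluation, assuming Trainer can be imported from rl8.trainers and that the algorithm’s default hyperparameters leave this call aligned with the environment reset interval; no env_config override is passed since DiscreteDummyEnv isn’t expected to consume one:

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import Trainer
>>> trainer = Trainer(DiscreteDummyEnv)
>>> train_stats = trainer.step()
>>> eval_stats = trainer.eval(deterministic=True)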
- run(*, env_config: None | dict[str, Any] = None, eval_env_config: None | dict[str, Any] = None, steps_per_eval: None | int = None, stop_conditions: None | list[rl8.conditions.Condition] = None) TrainStats [source]
Run the trainer and underlying algorithm until at least one of the stop_conditions is satisfied.

This method runs indefinitely unless at least one stop condition is provided.
- Parameters:
env_config – Environment config override. Useful for scheduling domain randomization.
eval_env_config – Environment config override during evaluations. Defaults to the config provided by env_config if not provided. Useful for evaluating a policy’s generalizability.
steps_per_eval – Number of Trainer.step() calls before calling Trainer.eval().
stop_conditions – Conditions evaluated each iteration that determine whether to stop training. Only one condition needs to evaluate as True for training to stop. Training will continue indefinitely unless a stop condition returns True.
- Returns:
The most recent train stats when the training is stopped due to a stop condition being satisfied.
- Raises:
1) If an eval environment config is provided but the environment isn’t expected to use that eval environment config, and 2) if steps_per_eval is not a factor of the algorithm’s horizons_per_env_reset.
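A minimal sketch of a bounded run; HitsUpperBound and its arguments are illustrative stand-ins for whatever rl8.conditions actually provides, not verified API:

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.conditions import HitsUpperBound  # hypothetical condition class
>>> from rl8.trainers import Trainer
>>> trainer = Trainer(DiscreteDummyEnv)
>>> train_stats = trainer.run(
...     steps_per_eval=10,
...     stop_conditions=[HitsUpperBound("returns/mean", 100.0)],  # illustrative key and bound
... )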
- step(*, env_config: None | dict[str, Any] = None) TrainStats [source]
Run a single training step, collecting environment transitions and updating the policy with those transitions.
- Parameters:
env_config – Environment config override. Useful for scheduling domain randomization.
- Returns:
Train stats from the policy update.
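A minimal sketch of a custom training loop that schedules domain randomization through the per-step environment config override; MyEnv and the difficulty key are hypothetical stand-ins for a user-defined environment and its config:

>>> from rl8.trainers import Trainer
>>> from my_package.envs import MyEnv  # hypothetical environment
>>> trainer = Trainer(MyEnv)
>>> for i in range(100):
...     # Gradually increase the randomization as training progresses.
...     train_stats = trainer.step(env_config={"difficulty": i / 100})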
- class rl8.trainers.RecurrentTrainer(env_cls: EnvFactory, /, **algorithm_config: Unpack)[source]
Bases:
GenericTrainerBase[RecurrentAlgorithm]

Higher-level training interface that interoperates with other tools for tracking and saving experiments (e.g., MLflow).
This is the preferred training interface when training recurrent policies in most cases.
- Parameters:
env_cls – Highly parallelized environment for sampling experiences. Instantiated with env_config. Will be stepped for horizon each RecurrentAlgorithm.collect() call.
**algorithm_config – See RecurrentAlgorithm.
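A minimal sketch of constructing a recurrent trainer; seq_len is shown as an illustrative recurrent hyperparameter forwarded to RecurrentAlgorithm (see RecurrentAlgorithm for the full set of supported options):

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import RecurrentTrainer
>>> trainer = RecurrentTrainer(DiscreteDummyEnv, seq_len=4)
>>> train_stats = trainer.step()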
- class rl8.trainers.TrainConfig(env_cls: EnvFactory, env_config: None | dict[str, Any] = None, model_cls: None | RecurrentModelFactory | ModelFactory = None, model_config: None | dict[str, Any] = None, distribution_cls: None | type[rl8.distributions.Distribution] = None, horizon: None | int = None, horizons_per_env_reset: None | int = None, num_envs: None | int = None, seq_len: None | int = None, seqs_per_state_reset: None | int = None, optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None, optimizer_config: None | dict[str, Any] = None, accumulate_grads: None | bool = None, enable_amp: None | bool = None, entropy_coeff: None | float = None, gae_lambda: None | float = None, gamma: None | float = None, sgd_minibatch_size: None | int = None, num_sgd_iters: None | int = None, shuffle_minibatches: None | bool = None, clip_param: None | float = None, vf_clip_param: None | float = None, dual_clip_param: None | float = None, vf_coeff: None | float = None, target_kl_div: None | float = None, max_grad_norm: None | float = None, normalize_advantages: None | bool = None, normalize_rewards: None | bool = None, device: None | str | device | Literal['auto'] = None, recurrent: bool = False)[source]
Bases:
object
A helper for instantiating a trainer based on a config.
It’s common to run training experiments based on some config. This class reduces the need to write a custom trainer config parser. It doesn’t support all trainer/algorithm options/fields, but it supports enough to cover the majority of use cases.
This class is intended to be instantiated from a JSON or YAML file and then used to build a trainer directly afterwards.
Examples
Assume there’s some YAML config at ./config.yaml with the following contents:

env_cls: rl8.env.DiscreteDummyEnv
horizon: 8
gamma: 1
The following will instantiate a TrainConfig from the file, instantiate a trainer, and then train indefinitely.

>>> from rl8 import TrainConfig
>>> TrainConfig.from_file("./config.yaml").build().run()
- env_cls: EnvFactory
- model_cls: None | RecurrentModelFactory | ModelFactory = None
- distribution_cls: None | type[rl8.distributions.Distribution] = None
- optimizer_cls: None | type[torch.optim.optimizer.Optimizer] = None
- build() Trainer | RecurrentTrainer [source]
Instantiate a trainer from the train config.
Null fields are removed from the train config before being unpacked into the trainer’s constructor (so the trainer’s default values are used for those fields). The trainer type (i.e., recurrent or feedforward) is specified by the recurrent attribute.
- Returns:
A trainer based on the train config values.
Examples
>>> from rl8 import DiscreteDummyEnv, TrainConfig
>>> trainer = TrainConfig(DiscreteDummyEnv).build()
- classmethod from_file(path: str | Path) TrainConfig [source]
Instantiate a TrainConfig from a JSON or YAML file.

The JSON or YAML file should have fields with the same type as the dataclass fields except for:

- “env_cls”
- “model_cls”
- “distribution_cls”
- “optimizer_cls”

These fields should be fully qualified paths to their definitions. As an example, if one were to use a custom package my_package with submodule envs and environment class MyEnv, they would set "env_cls" to "my_package.envs.MyEnv".

Definitions specified in these fields will be dynamically imported from their respective packages and modules. A current limitation is that these field specifications must point to an installed package and can’t be relative file locations (e.g., something like "..my_package.envs.MyEnv" will not work).
- Parameters:
path – Pathlike to the JSON or YAML file to read.
- Returns:
A train config based on the given file.
- class rl8.trainers.Trainer(env_cls: EnvFactory, /, **algorithm_config: Unpack)[source]
Bases:
GenericTrainerBase[Algorithm]

Higher-level training interface that interoperates with other tools for tracking and saving experiments (e.g., MLflow).
This is the preferred training interface when training feedforward (i.e., non-recurrent) policies in most cases.
- Parameters:
env_cls – Highly parallelized environment for sampling experiences. Instantiated with env_config. Will be stepped for horizon each Algorithm.collect() call.
**algorithm_config – See Algorithm.
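A minimal sketch of constructing a feedforward trainer with a couple of algorithm hyperparameters; horizon and gamma mirror fields that TrainConfig forwards to this constructor:

>>> from rl8 import DiscreteDummyEnv
>>> from rl8.trainers import Trainer
>>> trainer = Trainer(DiscreteDummyEnv, horizon=8, gamma=1)
>>> train_stats = trainer.step()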