Source code for rl8.algorithms._feedforward

from dataclasses import dataclass
from typing import Any, Literal

import torch
import torch.amp as amp
import torch.nn as nn
import torch.optim as optim
from torch.amp.grad_scaler import GradScaler
from torchrl.data import Composite, Unbounded

from .._utils import Batcher, StatTracker, assert_nd_spec, profile_ms
from ..data import (
    AlgorithmHparams,
    AlgorithmState,
    CollectStats,
    DataKeys,
    Device,
    StepStats,
)
from ..distributions import Distribution
from ..env import EnvFactory
from ..models import Model, ModelFactory
from ..nn import generalized_advantage_estimate, ppo_losses
from ..policies import Policy
from ..schedulers import EntropyScheduler, LRScheduler, ScheduleKind
from ._base import GenericAlgorithmBase


@dataclass
class AlgorithmConfig:
    """Algorithm config for building a feedforward PPO algorithm."""

    #: Model instance to use. Mutually exclusive with ``model_cls``.
    model: None | Model = None

    #: Optional custom policy model definition. A model class
    #: is provided for you based on the environment instance's specs
    #: if you don't provide one. Defaults to a simple feedforward
    #: neural network.
    model_cls: None | ModelFactory = None

    #: Optional policy model config unpacked into the model
    #: during instantiation.
    model_config: None | dict[str, Any] = None

    #: Custom policy action distribution class.
    #: If not provided, an action distribution class is inferred from the
    #: environment specs. Defaults to a categorical distribution for discrete
    #: actions and a normal distribution for continuous actions. Complex
    #: actions are not supported by default distributions.
    distribution_cls: None | type[Distribution] = None

    #: Number of environment transitions to collect during
    #: :meth:`Algorithm.collect`. The environment resets according to
    #: ``horizons_per_env_reset``. Buffer size is ``[B, T]`` where ``T`` is
    #: the horizon.
    horizon: int = 32

    #: Number of times :meth:`Algorithm.collect` can be called before
    #: resetting :attr:`Algorithm.env`. Increase this for cross-horizon
    #: learning. The default of ``1`` resets after every horizon.
    horizons_per_env_reset: int = 1

    #: Number of parallelized environment instances. Determines buffer size
    #: ``[B, T]`` where ``B`` is ``num_envs``.
    num_envs: int = 8192

    #: Custom optimizer class. Defaults to a simple, low-tuning optimizer.
    optimizer_cls: type[optim.Optimizer] = optim.Adam

    #: Configuration passed to the optimizer during instantiation.
    optimizer_config: None | dict[str, Any] = None

    #: Whether to accumulate gradients across minibatches before stepping the
    #: optimizer. Increases effective batch size while minimizing memory usage.
    accumulate_grads: bool = False

    #: Whether to enable Automatic Mixed Precision (AMP) for faster and more
    #: memory-efficient training.
    enable_amp: bool = False

    #: Optional schedule controlling the optimizer's learning rate over
    #: environment transitions. Keeps the learning rate constant if not
    #: provided.
    lr_schedule: None | list[tuple[int, float]] = None

    #: Learning rate scheduler type if ``lr_schedule`` is provided. Options:
    #: ``"step"`` (jump and hold) or ``"interp"`` (interpolate between
    #: values).
    lr_schedule_kind: ScheduleKind = "step"

    #: Entropy coefficient weight in the total loss. Ignored if
    #: ``entropy_coeff_schedule`` is provided.
    entropy_coeff: float = 0.0

    #: Optional schedule overriding ``entropy_coeff`` based on the number of
    #: environment transitions.
    entropy_coeff_schedule: None | list[tuple[int, float]] = None

    #: Entropy scheduler type. Options: ``"step"`` (jump and hold) or
    #: ``"interp"`` (interpolate between values).
    entropy_coeff_schedule_kind: ScheduleKind = "step"

    #: Generalized Advantage Estimation (GAE) λ parameter controlling the
    #: variance-bias tradeoff when estimating the state value function from
    #: collected environment transitions. A higher value allows a
    #: higher-variance, lower-bias estimate, while a lower value allows a
    #: lower-variance, higher-bias estimate.
    gae_lambda: float = 0.95

    #: Reward discount factor, often used in the Bellman operator, controlling
    #: the variance-bias tradeoff in collected rewards. Note, this does not
    #: control the bias/variance of the state value estimation and only
    #: controls the weight future rewards have on the total discounted return.
    gamma: float = 0.95

    #: PPO hyperparameter for minibatch size during the policy update.
    #: Larger minibatches reduce update variance and accelerate CUDA
    #: computations. If ``None``, the entire buffer is treated as one batch.
    sgd_minibatch_size: None | int = None

    #: PPO hyperparameter for the number of SGD iterations over the collected
    #: buffer.
    num_sgd_iters: int = 4

    #: Whether to shuffle minibatches within :meth:`Algorithm.step`.
    #: Recommended, but not necessary if the minibatch size is large enough
    #: (e.g., the buffer is the batch).
    shuffle_minibatches: bool = True

    #: PPO hyperparameter indicating the max distance the policy can update
    #: away from previously collected policy sample data with respect to the
    #: likelihoods of taking actions conditioned on observations. This is the
    #: main innovation of PPO.
    clip_param: float = 0.2

    #: PPO hyperparameter similar to ``clip_param`` but for the value function
    #: estimate. A measure of the max distance the model's value function is
    #: allowed to update away from previous value function samples.
    vf_clip_param: float = 5.0

    #: PPO hyperparameter that clips like ``clip_param`` but when advantage
    #: estimations are negative. Helps prevent instability for continuous
    #: action spaces when policies are making large updates. Leave ``None``
    #: for this clip to not apply. Otherwise, typical values are around ``5``.
    dual_clip_param: None | float = None

    #: Value function loss component weight. Only needs to be tuned when the
    #: policy and value function share parameters.
    vf_coeff: float = 1.0

    #: Target maximum KL divergence when updating the policy. If the
    #: approximate KL divergence is greater than this value, then policy
    #: updates stop early for that algorithm step. If this is left ``None``,
    #: then early stopping doesn't occur. A higher value means the policy is
    #: allowed to diverge more from the previous policy during updates.
    target_kl_div: None | float = None

    #: Max gradient norm allowed when updating the policy's model within
    #: :meth:`Algorithm.step`.
    max_grad_norm: float = 5.0

    #: Whether to normalize advantages computed for GAE using the batch's
    #: mean and standard deviation. This has been shown to generally improve
    #: convergence speed and performance and should usually be ``True``.
    normalize_advantages: bool = True

    #: Whether to normalize rewards using reversed discounted returns as
    #: from https://arxiv.org/pdf/2005.12729.pdf. Reward normalization,
    #: although not exactly correct and optimal, typically improves
    #: convergence speed and performance and should usually be ``True``.
    normalize_rewards: bool = True

    #: Device :attr:`Algorithm.env`, :attr:`Algorithm.buffer`, and
    #: :attr:`Algorithm.policy` all reside on.
    device: Device | Literal["auto"] = "auto"
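
    # Explanatory note (not from the original source): ``gae_lambda`` and
    # ``gamma`` above parameterize Generalized Advantage Estimation, which
    # follows the standard recursion
    #
    #     delta[t] = r[t] + gamma * V(s[t + 1]) - V(s[t])
    #     A[t] = delta[t] + gamma * gae_lambda * A[t + 1]
    #
    # so ``gae_lambda=1`` recovers high-variance, low-bias Monte Carlo-style
    # estimates and ``gae_lambda=0`` recovers the low-variance, higher-bias
    # one-step TD error.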

    def build(self, env_cls: EnvFactory) -> "Algorithm":
        """Build and validate an :class:`Algorithm` from a config."""
        algo = Algorithm(env_cls, config=self)
        algo.validate()
        return algo
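
# Usage sketch (not part of the original module): customize a few config
# fields and build an algorithm for the dummy environment used in the
# ``Algorithm`` docstring below. Field values here are illustrative only.
#
#     >>> from rl8 import AlgorithmConfig
#     >>> from rl8.env import DiscreteDummyEnv
#     >>> config = AlgorithmConfig(num_envs=64, horizon=16, entropy_coeff=0.01)
#     >>> algo = config.build(DiscreteDummyEnv)  # doctest: +SKIP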

class Algorithm(GenericAlgorithmBase[AlgorithmHparams, AlgorithmState, Policy]):
    """An optimized feedforward `PPO`_ algorithm with common tricks for
    stabilizing and accelerating learning.

    Args:
        env_cls: Highly parallelized environment for sampling experiences.
            Will be stepped for ``horizon`` transitions on each
            :meth:`Algorithm.collect` call.
        config: Algorithm config for building a feedforward PPO algorithm.
            See :class:`AlgorithmConfig` for all parameters.

    Examples:
        Instantiate an algorithm for a dummy environment and update the
        underlying policy once.

        >>> from rl8 import AlgorithmConfig
        >>> from rl8.env import DiscreteDummyEnv
        >>> algo = AlgorithmConfig().build(DiscreteDummyEnv)
        >>> algo.collect()  # doctest: +SKIP
        >>> algo.step()  # doctest: +SKIP

    .. _`PPO`: https://arxiv.org/pdf/1707.06347.pdf

    """

    def __init__(
        self, env_cls: EnvFactory, /, config: None | AlgorithmConfig = None
    ) -> None:
        config = config or AlgorithmConfig()
        device = str(
            ("cuda" if torch.cuda.is_available() else "cpu")
            if config.device == "auto"
            else config.device
        )
        max_num_envs = (
            env_cls.max_num_envs
            if hasattr(env_cls, "max_num_envs")
            else config.num_envs
        )
        num_envs = min(config.num_envs, max_num_envs)
        max_horizon = (
            env_cls.max_horizon if hasattr(env_cls, "max_horizon") else 1_000_000
        )
        horizon = min(config.horizon, max_horizon)
        self.env = env_cls(num_envs, horizon, device=device)
        assert_nd_spec(self.env.observation_spec)
        assert_nd_spec(self.env.action_spec)
        self.policy = Policy(
            self.env.observation_spec,
            self.env.action_spec,
            model=config.model,
            model_cls=config.model_cls,
            model_config=config.model_config,
            distribution_cls=config.distribution_cls,
            device=device,
        )
        self.buffer_spec = Composite(
            {
                DataKeys.OBS: self.env.observation_spec,
                DataKeys.REWARDS: Unbounded(1, device=device),
                DataKeys.ACTIONS: self.env.action_spec,
                DataKeys.LOGP: Unbounded(1, device=device),
                DataKeys.VALUES: Unbounded(1, device=device),
                DataKeys.ADVANTAGES: Unbounded(1, device=device),
                DataKeys.RETURNS: Unbounded(1, device=device),
            },
        )
        if config.normalize_rewards:
            self.buffer_spec.set(
                DataKeys.REVERSED_DISCOUNTED_RETURNS,
                Unbounded(1, device=device),
            )
        self.buffer_spec = self.buffer_spec.to(device)
        self.buffer = self.buffer_spec.zero([num_envs, horizon + 1])
        optimizer_config = config.optimizer_config or {"lr": 1e-3}
        optimizer = config.optimizer_cls(
            self.policy.model.parameters(), **optimizer_config
        )
        self.lr_scheduler = LRScheduler(
            optimizer,
            schedule=config.lr_schedule,
            kind=config.lr_schedule_kind,
        )
        self.entropy_scheduler = EntropyScheduler(
            config.entropy_coeff,
            schedule=config.entropy_coeff_schedule,
            kind=config.entropy_coeff_schedule_kind,
        )
        sgd_minibatch_size = (
            config.sgd_minibatch_size
            if config.sgd_minibatch_size
            else num_envs * horizon
        )
        self.hparams = AlgorithmHparams(
            accumulate_grads=config.accumulate_grads,
            clip_param=config.clip_param,
            device=device,
            dual_clip_param=config.dual_clip_param,
            enable_amp=config.enable_amp,
            gae_lambda=config.gae_lambda,
            gamma=config.gamma,
            horizon=horizon,
            horizons_per_env_reset=config.horizons_per_env_reset,
            max_grad_norm=config.max_grad_norm,
            normalize_advantages=config.normalize_advantages,
            normalize_rewards=config.normalize_rewards,
            num_envs=num_envs,
            num_sgd_iters=config.num_sgd_iters,
            sgd_minibatch_size=sgd_minibatch_size,
            shuffle_minibatches=config.shuffle_minibatches,
            target_kl_div=config.target_kl_div,
            vf_clip_param=config.vf_clip_param,
            vf_coeff=config.vf_coeff,
        ).validate()
        self.state = AlgorithmState()
        self.optimizer = optimizer
        self.grad_scaler = GradScaler(device=device, enabled=config.enable_amp)
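
    # Explanatory note (not from the original source): the buffer is shaped
    # ``[num_envs, horizon + 1]``. The extra trailing time index stores the
    # final observation and its value estimate so returns can be bootstrapped
    # during GAE; :meth:`Algorithm.step` drops that trailing element before
    # running policy updates.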

    def collect(
        self,
        *,
        env_config: None | dict[str, Any] = None,
        deterministic: bool = False,
    ) -> CollectStats:
        """Collect environment transitions and policy samples in a buffer.

        This is one of the main :class:`Algorithm` methods. This is usually
        called immediately prior to :meth:`Algorithm.step` to collect
        experiences used for learning.

        The environment is reset immediately prior to collecting transitions
        according to ``horizons_per_env_reset``. If the environment isn't
        reset, then the last observation is used as the initial observation.

        This method sets the ``buffered`` flag so :meth:`Algorithm.step`
        isn't called with dummy data.

        Args:
            env_config: Optional config to pass to the environment's reset
                method. This isn't used if the environment isn't scheduled
                to be reset according to ``horizons_per_env_reset``.
            deterministic: Whether to sample from the policy
                deterministically. This is usually ``False`` during learning
                and ``True`` during evaluation.

        Returns:
            Summary statistics related to the collected experiences and
            policy samples.

        """
        with profile_ms() as collect_timer:
            # Gather initial observation.
            env_was_reset = False
            if self.state.horizons and self.hparams.horizons_per_env_reset < 0:
                self.buffer[DataKeys.OBS][:, 0, ...] = self.buffer[DataKeys.OBS][
                    :, -1, ...
                ]
                if self.hparams.normalize_rewards:
                    self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][
                        :, 0, ...
                    ] = self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][:, -1, ...]
            elif not (self.state.horizons % self.hparams.horizons_per_env_reset):
                self.buffer[DataKeys.OBS][:, 0, ...] = self.env.reset(
                    config=env_config
                )
                env_was_reset = True
                if self.hparams.normalize_rewards:
                    self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][:, 0, ...] = 0.0
            else:
                self.buffer[DataKeys.OBS][:, 0, ...] = self.buffer[DataKeys.OBS][
                    :, -1, ...
                ]
                if self.hparams.normalize_rewards:
                    self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][
                        :, 0, ...
                    ] = self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][:, -1, ...]

            for t in range(self.hparams.horizon):
                # Sample the policy and step the environment.
                in_batch = self.buffer[:, : (t + 1), ...]
                sample_batch = self.policy.sample(
                    in_batch,
                    kind="last",
                    deterministic=deterministic,
                    inplace=False,
                    requires_grad=False,
                    return_actions=True,
                    return_logp=True,
                    return_values=True,
                    return_views=False,
                )
                out_batch = self.env.step(sample_batch[DataKeys.ACTIONS])

                # Getting reversed discounted returns for normalizing reward
                # scale during GAE. This isn't exactly correct according to
                # theory but works in practice.
                if self.hparams.normalize_rewards:
                    self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][
                        :, t + 1, ...
                    ] = (
                        self.hparams.gamma
                        * self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][:, t, ...]
                        + out_batch[DataKeys.REWARDS]
                    )

                # Update the buffer using sampled policy data and environment
                # transition data.
                self.buffer[DataKeys.ACTIONS][:, t, ...] = sample_batch[
                    DataKeys.ACTIONS
                ]
                self.buffer[DataKeys.LOGP][:, t, ...] = sample_batch[DataKeys.LOGP]
                self.buffer[DataKeys.VALUES][:, t, ...] = sample_batch[DataKeys.VALUES]
                self.buffer[DataKeys.REWARDS][:, t, ...] = out_batch[DataKeys.REWARDS]
                self.buffer[DataKeys.OBS][:, t + 1, ...] = out_batch[DataKeys.OBS]

            # Sample features and value function at last observation.
            in_batch = self.buffer[:, :, ...]
            sample_batch = self.policy.sample(
                in_batch,
                kind="last",
                deterministic=deterministic,
                inplace=False,
                requires_grad=False,
                return_actions=False,
                return_logp=False,
                return_values=True,
                return_views=False,
            )
            self.buffer[DataKeys.VALUES][:, -1, ...] = sample_batch[DataKeys.VALUES]

            # Aggregate some metrics.
            rewards = self.buffer[DataKeys.REWARDS][:, :-1, ...]
            returns = torch.sum(rewards, dim=1)
            returns_std, returns_mean = torch.std_mean(returns)
            rewards_std, rewards_mean = torch.std_mean(rewards)
            collect_stats: CollectStats = {
                "returns/min": float(torch.min(returns)),
                "returns/max": float(torch.max(returns)),
                "returns/mean": float(returns_mean),
                "returns/std": float(returns_std),
                "rewards/min": float(torch.min(rewards)),
                "rewards/max": float(torch.max(rewards)),
                "rewards/mean": float(rewards_mean),
                "rewards/std": float(rewards_std),
            }
            self.state.horizons += 1
            self.state.buffered = True
            self.state.reward_scale = (
                float(
                    torch.std(
                        self.buffer[DataKeys.REVERSED_DISCOUNTED_RETURNS][:, 1:, ...]
                    )
                )
                if self.hparams.normalize_rewards
                else 1.0
            )
            collect_stats["env/resets"] = self.hparams.num_envs * int(env_was_reset)
            collect_stats["env/steps"] = self.hparams.num_envs * self.hparams.horizon
            collect_stats["profiling/collect_ms"] = collect_timer()
        return collect_stats
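
    # Usage sketch (not part of the original module): gather a horizon of
    # experiences before updating the policy. The ``env_config`` contents are
    # hypothetical and depend entirely on the environment being used.
    #
    #     >>> stats = algo.collect(env_config={"some_option": 1})  # doctest: +SKIP
    #     >>> stats["returns/mean"]  # doctest: +SKIP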

    def step(self) -> StepStats:
        """Take a step with the algorithm, using collected environment
        experiences to update the policy.

        Returns:
            Data associated with the step (losses, loss coefficients, etc.).

        """
        if not self.state.buffered:
            raise RuntimeError(
                f"{self.__class__.__name__} is not buffered. "
                "Call `collect` once prior to `step`."
            )

        with profile_ms() as step_timer:
            # Generalized Advantage Estimation (GAE) and returns bootstrapping.
            self.buffer = generalized_advantage_estimate(
                self.buffer,
                gae_lambda=self.hparams.gae_lambda,
                gamma=self.hparams.gamma,
                inplace=True,
                normalize_advantages=self.hparams.normalize_advantages,
                return_returns=True,
                reward_scale=self.state.reward_scale,
            )

            # Batchify the buffer. Save the last sample for adding it back to
            # the buffer. Remove the last sample afterwards since it contains
            # dummy data.
            final_obs = self.buffer[DataKeys.OBS][:, -1, ...]
            self.buffer = self.buffer[:, :-1, ...]
            views = self.policy.model.apply_view_requirements(self.buffer, kind="all")

            # Free buffer elements that aren't used for the rest of the step.
            del self.buffer[DataKeys.OBS]
            del self.buffer[DataKeys.REWARDS]
            del self.buffer[DataKeys.VALUES]
            self.buffer = self.buffer.reshape(-1)
            self.buffer[DataKeys.VIEWS] = views

            # Main PPO loop.
            grad_accumulation_steps = (
                self.hparams.num_minibatches if self.hparams.accumulate_grads else 1
            )
            stat_tracker = StatTracker(
                [
                    "coefficients/entropy",
                    "coefficients/vf",
                    "losses/entropy",
                    "losses/policy",
                    "losses/vf",
                    "losses/total",
                    "monitors/kl_div",
                ],
                sum_keys=[
                    "losses/entropy",
                    "losses/policy",
                    "losses/vf",
                    "losses/total",
                    "monitors/kl_div",
                ],
            )
            batcher = Batcher(
                self.buffer,
                batch_size=self.hparams.sgd_minibatch_size,
                shuffle=self.hparams.shuffle_minibatches,
            )
            stop_early = False
            for _ in range(self.hparams.num_sgd_iters):
                for i, buffer_batch in enumerate(batcher):
                    step_this_batch = (i + 1) % grad_accumulation_steps == 0
                    with amp.autocast(
                        self.hparams.device_type,
                        enabled=self.hparams.enable_amp,
                    ):
                        sample_batch = self.policy.sample(
                            buffer_batch,
                            kind="all",
                            deterministic=False,
                            inplace=False,
                            requires_grad=True,
                            return_actions=False,
                            return_logp=False,
                            return_values=True,
                            return_views=False,
                        )

                        # Get action distributions and their log probability
                        # ratios.
                        curr_action_dist = self.policy.distribution_cls(
                            sample_batch[DataKeys.FEATURES], self.policy.model
                        )
                        losses = ppo_losses(
                            buffer_batch,
                            sample_batch,
                            curr_action_dist,
                            clip_param=self.hparams.clip_param,
                            dual_clip_param=self.hparams.dual_clip_param,
                            entropy_coeff=self.entropy_scheduler.coeff,
                            vf_clip_param=self.hparams.vf_clip_param,
                            vf_coeff=self.hparams.vf_coeff,
                        )
                        losses = losses.apply(lambda x: x / grad_accumulation_steps)

                        # Calculate approximate KL divergence for early
                        # stopping and debugging. Early stopping is per-batch
                        # and can't be done with gradient accumulation (hence
                        # the approximate KL isn't scaled by the number of
                        # gradient accumulation steps before comparing against
                        # the target KL).
                        with torch.no_grad():
                            logp_ratio = (
                                curr_action_dist.logp(buffer_batch[DataKeys.ACTIONS])
                                - buffer_batch[DataKeys.LOGP]
                            )
                            approximate_kl_div = float(
                                torch.mean((torch.exp(logp_ratio) - 1) - logp_ratio)
                            )

                    # Update step data.
                    stat_tracker.update(
                        {
                            "coefficients/entropy": self.entropy_scheduler.coeff,
                            "coefficients/vf": self.hparams.vf_coeff,
                            "losses/entropy": float(losses["entropy"].detach()),
                            "losses/policy": float(losses["policy"].detach()),
                            "losses/vf": float(losses["vf"].detach()),
                            "losses/total": float(losses["total"].detach()),
                            "monitors/kl_div": approximate_kl_div
                            / grad_accumulation_steps,
                        },
                        reduce=step_this_batch,
                    )

                    # Early stopping using approximate KL divergence.
                    if (
                        self.hparams.target_kl_div is not None
                        and approximate_kl_div > 1.5 * self.hparams.target_kl_div
                    ):
                        stop_early = True
                        break

                    # Optimize.
                    self.grad_scaler.scale(losses["total"]).backward()
                    if step_this_batch:
                        self.grad_scaler.unscale_(self.optimizer)
                        nn.utils.clip_grad_norm_(
                            self.policy.model.parameters(), self.hparams.max_grad_norm
                        )
                        self.grad_scaler.step(self.optimizer)
                        self.grad_scaler.update()
                        self.optimizer.zero_grad()
                if stop_early:
                    break

            # Update schedulers.
            self.lr_scheduler.step(self.hparams.num_envs * self.state.horizons)
            self.entropy_scheduler.step(self.hparams.num_envs * self.state.horizons)

            # Reset the buffer and buffered flag.
            self.buffer = self.buffer_spec.zero(
                [
                    self.hparams.num_envs,
                    self.hparams.horizon + 1,
                ]
            )
            self.buffer[DataKeys.OBS][:, -1, ...] = final_obs
            self.state.buffered = False

            # Update algo stats.
            step_stats = stat_tracker.items()
            step_stats["profiling/step_ms"] = step_timer()
        return step_stats  # type: ignore[return-value]
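
    # Explanatory note (not from the original source): the ``approximate_kl_div``
    # monitor above uses the estimator
    #
    #     KL(pi_old || pi_new) ~= E[(r - 1) - log r],  r = exp(logp_new - logp_old),
    #
    # which is unbiased and lower-variance than the naive ``-E[log r]`` when
    # actions are sampled from the old policy.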

    def validate(self) -> None:
        """Do some validation on all the tensor/tensordict shapes within the
        algorithm.

        Helpful when the algorithm is throwing an error on mismatched
        tensor/tensordict sizes. Call this at least once before running the
        algorithm for peace of mind.

        """
        # Check initial observation.
        obs = self.env.reset()
        self.env.observation_spec.assert_is_in(obs)
        try:
            self.buffer[DataKeys.OBS][:, 0, ...] = obs
        except RuntimeError as e:
            raise AssertionError(
                f"The observation from {self.env.reset.__qualname__} doesn't match the"
                " observation spec shape."
            ) from e

        # Sample the policy and check all outputs.
        in_batch = self.buffer[:, :1, ...]
        sample_batch = self.policy.sample(
            in_batch,
            kind="last",
            deterministic=False,
            inplace=False,
            requires_grad=False,
            return_actions=True,
            return_logp=True,
            return_values=True,
        )
        actions = sample_batch[DataKeys.ACTIONS]
        assert actions.ndim >= 2, (
            "Actions must be at least 2D and have shape ``[N, ...]`` (where ``N`` is"
            " the number of independent elements or environment instances, and ``...``"
            " is any number of additional dimensions)."
        )
        self.env.action_spec.assert_is_in(actions)
        try:
            self.buffer[DataKeys.ACTIONS][:, 0, ...] = sample_batch[DataKeys.ACTIONS]
        except RuntimeError as e:
            raise AssertionError(
                "The action sampled from the policy doesn't match the action spec."
            ) from e
        assert sample_batch[DataKeys.LOGP].shape == torch.Size(
            [self.hparams.num_envs, 1]
        ), (
            "Action log probabilities must be 2D and have shape ``[N, 1]`` (where ``N``"
            " is the number of independent elements or environment instances)."
        )
        assert sample_batch[DataKeys.VALUES].shape == torch.Size(
            [self.hparams.num_envs, 1]
        ), (
            "Expected value estimates must be 2D and have shape ``[N, 1]`` (where ``N``"
            " is the number of independent elements or environment instances)."
        )

        # Step the environment and check everything once more.
        out_batch = self.env.step(actions)
        obs = out_batch[DataKeys.OBS]
        self.env.observation_spec.assert_is_in(obs)
        try:
            self.buffer[DataKeys.OBS][:, 1, ...] = obs
        except RuntimeError as e:
            raise AssertionError(
                f"The observation from {self.env.step.__qualname__} doesn't match the"
                " observation spec shape."
            ) from e
        assert out_batch[DataKeys.REWARDS].shape == torch.Size(
            [self.hparams.num_envs, 1]
        ), (
            "Rewards must be 2D and have shape ``[N, 1]`` (where ``N`` is the number of"
            " independent elements or environment instances)."
        )
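
# Usage sketch (not part of the original module): :meth:`AlgorithmConfig.build`
# already runs :meth:`Algorithm.validate`, so a typical training loop just
# alternates collecting and stepping. The loop count is illustrative only.
#
#     >>> from rl8 import AlgorithmConfig
#     >>> from rl8.env import DiscreteDummyEnv
#     >>> algo = AlgorithmConfig().build(DiscreteDummyEnv)
#     >>> for _ in range(4):  # doctest: +SKIP
#     ...     algo.collect()
#     ...     algo.step()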