from abc import ABCMeta, abstractmethod
from dataclasses import asdict
from typing import Any, Generic, TypeVar
import torch.optim as optim
from tensordict import TensorDict
from torch.cuda.amp.grad_scaler import GradScaler
from torchrl.data import CompositeSpec
from .._utils import memory_stats
from ..data import (
AlgorithmHparams,
AlgorithmState,
CollectStats,
MemoryStats,
StepStats,
)
from ..env import Env
from ..policies import GenericPolicyBase
from ..schedulers import EntropyScheduler, LRScheduler
_AlgorithmHparams = TypeVar("_AlgorithmHparams", bound=AlgorithmHparams)
_AlgorithmState = TypeVar("_AlgorithmState", bound=AlgorithmState)
_Policy = TypeVar("_Policy", bound=GenericPolicyBase[Any])
class GenericAlgorithmBase(
Generic[_AlgorithmHparams, _AlgorithmState, _Policy], metaclass=ABCMeta
):
"""The base class for PPO algorithm flavors."""
#: Environment experience buffer used for aggregating environment
#: transition data and policy sample data. The same buffer object
#: is reused across :meth:`GenericAlgorithmBase.collect` calls. Buffer
#: dimensions are determined by the ``num_envs`` and ``horizon`` args.
buffer: TensorDict
#: Tensor spec defining the environment experience buffer components
#: and dimensions. Used for instantiating :attr:`GenericAlgorithmBase.buffer`
#: at :class:`GenericAlgorithmBase` instantiation and each
#: :meth:`GenericAlgorithmBase.step` call.
buffer_spec: CompositeSpec
#: Entropy scheduler for updating the ``entropy_coeff`` after each
#: :meth:`GenericAlgorithmBase.step` call based on the number of environment
#: transitions collected and learned on. By default, the entropy scheduler
#: leaves the entropy coefficient unchanged; it only updates the coefficient
#: if an ``entropy_coeff_schedule`` is provided.
entropy_scheduler: EntropyScheduler
#: Environment used for experience collection within the
#: :meth:`GenericAlgorithmBase.collect` method. It's ultimately up to the
#: environment to make learning efficient by parallelizing simulations.
env: Env
#: Used for enabling Automatic Mixed Precision (AMP). Handles gradient
#: scaling for the optimizer. Not all optimizers and hyperparameters are
#: compatible with gradient scaling.
grad_scaler: GradScaler
#: PPO hyperparameters that are constant throughout training
#: and can drastically affect training performance.
hparams: _AlgorithmHparams
#: Learning rate scheduler for updating the ``optimizer`` learning rate after
#: each ``step`` call based on the number of environment transitions
#: collected and learned on. By default, the learning rate scheduler leaves
#: the ``optimizer`` learning rate unchanged; it only alters the learning
#: rate if a ``learning_rate_schedule`` is provided.
lr_scheduler: LRScheduler
#: Underlying optimizer for updating the policy's model parameters.
#: Instantiated from an ``optimizer_cls`` and ``optimizer_config``.
#: Defaults to the Adam optimizer with generally well-performing parameters.
optimizer: optim.Optimizer
#: Policy constructed from the ``model_cls``, ``model_config``, and
#: ``distribution_cls`` kwargs. A default policy is constructed according to
#: the environment's observation and action specs if these policy args
#: aren't provided. The policy is what does all the action sampling
#: within :meth:`GenericAlgorithmBase.collect` and is what is updated within
#: :meth:`GenericAlgorithmBase.step`.
policy: _Policy
#: Algorithm state for determining when to reset the environment, when
#: the policy can be updated, etc.
state: _AlgorithmState
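
# A minimal sketch (illustrative only, not part of this API) of how these
# attributes typically interact during one training iteration on a concrete
# subclass instance ``algo``:
#
#     algo.collect()  # ``env`` and ``policy`` fill ``buffer``
#     algo.step()     # ``optimizer``/``grad_scaler`` update ``policy``;
#                     # ``lr_scheduler``/``entropy_scheduler`` then advance
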
@abstractmethod
def collect(
self,
*,
env_config: None | dict[str, Any] = None,
deterministic: bool = False,
) -> CollectStats:
"""Collect environment transitions and policy samples in a buffer.
This is one of the main :class:`GenericAlgorithmBase` methods. This is
usually called immediately prior to :meth:`GenericAlgorithmBase.step`
to collect experiences used for learning.
The environment is reset immediately prior to collecting
transitions according to ``horizons_per_env_reset``. If
the environment isn't reset, then the last observation is used as
the initial observation.
This method sets the ``buffered`` flag to enable calling
of :meth:`GenericAlgorithmBase.step` so it isn't called with dummy data.
Args:
env_config: Optional config to pass to the environment's reset
method. This isn't used if the environment isn't scheduled
to be reset according to ``horizons_per_env_reset``.
deterministic: Whether to sample from the policy deterministically.
This is usally ``False`` during learning and ``True`` during
evaluation.
Returns:
Summary statistics related to the collected experiences and
policy samples.
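
Example:
    A minimal sketch of typical usage, assuming ``algo`` is an instance
    of a concrete subclass (names here are illustrative)::

        # Learning: stochastic sampling feeds the subsequent ``step``.
        algo.collect()
        algo.step()

        # Evaluation: sample deterministically from the policy.
        algo.collect(deterministic=True)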
"""
@property
def horizons_per_env_reset(self) -> int:
"""Number of times :meth:`GenericAlgorithmBase.collect` can be
called before resetting :attr:`GenericAlgorithmBase.env`.
"""
return self.hparams.horizons_per_env_reset
def memory_stats(self) -> MemoryStats:
"""Return current algorithm memory usage."""
return memory_stats(self.hparams.device_type)
@property
def params(self) -> dict[str, Any]:
"""Return algorithm parameters."""
return {
"env_cls": self.env.__class__.__name__,
"model_cls": self.policy.model.__class__.__name__,
"distribution_cls": self.policy.distribution_cls.__name__,
"optimizer_cls": self.optimizer.__class__.__name__,
"entropy_coeff": self.entropy_scheduler.coeff,
**asdict(self.hparams),
}
@abstractmethod
def step(self) -> StepStats:
"""Take a step with the algorithm, using collected environment
experiences to update the policy.
Returns:
Data associated with the step (losses, loss coefficients, etc.).
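
Example:
    A minimal sketch, assuming ``algo`` is an instance of a concrete
    subclass; :meth:`GenericAlgorithmBase.collect` is called first so
    this method isn't run on dummy data::

        algo.collect()
        step_stats = algo.step()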
"""