from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
import torch
from tensordict import TensorDict
from torchrl.data import (
CompositeSpec,
DiscreteTensorSpec,
TensorSpec,
UnboundedContinuousTensorSpec,
)
from ._utils import assert_1d_spec
# Type variables used to bind the generic parameters of
# ``TorchDistributionWrapper`` subclasses below.

#: Spec describing the environment's action space.
_ActionSpec = TypeVar("_ActionSpec", bound=TensorSpec)

#: Spec describing the feature set produced by a policy model's forward pass.
_FeatureSpec = TypeVar("_FeatureSpec", bound=TensorSpec)

#: Underlying PyTorch distribution type wrapped by a policy distribution.
_TorchDistribution = TypeVar(
    "_TorchDistribution", bound=torch.distributions.Distribution
)
class Distribution(ABC):
    """Policy component that defines a probability distribution over a
    feature set from a model.

    This definition is largely inspired by RLlib's `action distribution`_.

    Most commonly, the feature set is a single vector of logits or log
    probabilities used for defining and sampling from the probability
    distribution. Custom probability distributions, however, are not
    constrained to just a single vector.

    Args:
        features: Features from ``model``'s forward pass.
        model: Model for parameterizing the probability distribution.

    .. _`action distribution`: https://github.com/ray-project/ray/blob/master/rllib/models/action_dist.py

    """

    #: Features from :attr:`Distribution.model` forward pass. Simple action
    #: distributions expect one field and corresponding tensor in the
    #: tensordict, but custom action distributions can return any kind of
    #: tensordict from :attr:`Distribution.model`.
    features: TensorDict

    #: Model from the parent policy also passed to the action distribution.
    #: This is necessary in case the model has components that're only
    #: used for sampling or probability distribution characteristics
    #: computations.
    model: Any

    def __init__(self, features: TensorDict, model: Any, /) -> None:
        super().__init__()
        self.features = features
        self.model = model

    @staticmethod
    def default_dist_cls(action_spec: TensorSpec, /) -> type["Distribution"]:
        """Return a default distribution given an action spec.

        Args:
            action_spec: Spec defining required environment inputs.

        Returns:
            A distribution for simple, supported action specs.

        Raises:
            TypeError: If ``action_spec`` has no default distribution
                mapping.

        """
        assert_1d_spec(action_spec)
        # Only the two simplest spec kinds have a sensible default;
        # anything else requires an explicit distribution choice.
        if isinstance(action_spec, DiscreteTensorSpec):
            return Categorical
        if isinstance(action_spec, UnboundedContinuousTensorSpec):
            return Normal
        raise TypeError(
            f"Action spec {action_spec} has no default distribution support."
        )

    @abstractmethod
    def deterministic_sample(self) -> torch.Tensor | TensorDict:
        """Draw a deterministic sample from the probability distribution."""

    @abstractmethod
    def entropy(self) -> torch.Tensor:
        """Compute the probability distribution's entropy (a measurement
        of randomness).
        """

    @abstractmethod
    def logp(self, samples: torch.Tensor | TensorDict) -> torch.Tensor:
        """Compute the log probability of sampling `samples` from the
        probability distribution.
        """

    @abstractmethod
    def sample(self) -> torch.Tensor | TensorDict:
        """Draw a stochastic sample from the probability distribution."""
class TorchDistributionWrapper(
    Distribution, Generic[_FeatureSpec, _TorchDistribution, _ActionSpec]
):
    """Wrapper class for PyTorch distributions.

    Delegates sampling and likelihood computations to an underlying
    :mod:`torch.distributions` instance stored on :attr:`dist`.

    This is inspired by `RLlib`_.

    .. _`RLlib`: https://github.com/ray-project/ray/blob/master/rllib/models/torch/torch_action_dist.py

    """

    #: Underlying PyTorch distribution.
    dist: _TorchDistribution

    def deterministic_sample(self) -> torch.Tensor:
        """Return the distribution's mode as the deterministic sample."""
        return self.dist.mode

    def entropy(self) -> torch.Tensor:
        """Sum per-dimension entropies into a single trailing column."""
        per_dim_entropy = self.dist.entropy()
        return per_dim_entropy.sum(-1, keepdim=True)

    def logp(self, samples: torch.Tensor) -> torch.Tensor:
        """Sum per-dimension log probabilities of ``samples`` into a
        single trailing column.
        """
        per_dim_logp = self.dist.log_prob(samples)
        return per_dim_logp.sum(-1, keepdim=True)

    def sample(self) -> torch.Tensor:
        """Draw a stochastic sample from the wrapped distribution."""
        return self.dist.sample()
class Categorical(
    TorchDistributionWrapper[
        CompositeSpec, torch.distributions.Categorical, DiscreteTensorSpec
    ]
):
    """Wrapper around the PyTorch categorical (i.e., discrete) distribution.

    Expects ``features`` to contain a ``"logits"`` entry used to
    parameterize the underlying distribution.
    """

    def __init__(self, features: TensorDict, model: Any, /) -> None:
        super().__init__(features, model)
        logits = features["logits"]
        self.dist = torch.distributions.Categorical(logits=logits)
class Normal(
    TorchDistributionWrapper[
        CompositeSpec, torch.distributions.Normal, UnboundedContinuousTensorSpec
    ]
):
    """Wrapper around the PyTorch normal (i.e., gaussian) distribution.

    Expects ``features`` to contain ``"mean"`` and ``"log_std"`` entries;
    the log standard deviation is exponentiated into the distribution's
    scale.
    """

    # NOTE(review): unlike the sibling ``Categorical``, this signature is
    # not positional-only — confirm whether the missing ``/`` is intentional.
    def __init__(self, features: TensorDict, model: Any) -> None:
        super().__init__(features, model)
        loc = features["mean"]
        scale = torch.exp(features["log_std"])
        self.dist = torch.distributions.Normal(loc=loc, scale=scale)
class SquashedNormal(Normal):
    """Squashed normal distribution such that samples are always within
    [-1, 1], achieved by applying ``tanh`` to samples from the underlying
    normal distribution.
    """

    def deterministic_sample(self) -> torch.Tensor:
        """Squash the underlying normal's deterministic sample into [-1, 1]."""
        unsquashed = super().deterministic_sample()
        return unsquashed.tanh()

    def entropy(self) -> torch.Tensor:
        """Entropy has no closed form under the ``tanh`` squash."""
        raise NotImplementedError(
            f"Entropy isn't defined for {self.__class__.__name__}. Set the"
            " entropy coefficient to `0` to avoid this error during training."
        )

    def logp(self, samples: torch.Tensor) -> torch.Tensor:
        """Compute the log probability of ``samples`` under the squashed
        distribution via the change-of-variables formula for ``tanh``.
        """
        tiny = torch.finfo(samples.dtype).eps
        # Keep samples strictly inside (-1, 1) so the inverse below stays finite.
        bounded = samples.clamp(min=-1 + tiny, max=1 - tiny)
        # atanh(x) = 0.5 * (log1p(x) - log1p(-x)) recovers the pre-squash values.
        unsquashed = 0.5 * (bounded.log1p() - (-bounded).log1p())
        base_logp = torch.clamp(
            self.dist.log_prob(unsquashed), min=-100, max=100
        ).sum(-1, keepdim=True)
        # Jacobian correction term for the tanh transform.
        correction = torch.sum(torch.log(1 - samples**2 + tiny), dim=-1, keepdim=True)
        return base_logp - correction

    def sample(self) -> torch.Tensor:
        """Squash a stochastic sample from the underlying normal into [-1, 1]."""
        unsquashed = super().sample()
        return unsquashed.tanh()