fishyrl.distributions module

Utility distributions for reinforcement learning agents.

class fishyrl.distributions.TwoHot(logits: Tensor, bins: int = None, low: float = -20.0, high: float = 20.0, pre_func: callable = <function symlog>, post_func: callable = <function symexp>, event_dims: int = 1, eps: float = 1e-05)

Bases: object

Two-hot distribution as described in DreamerV3.

Takes a tensor of per-bin logits and computes log probabilities as if the queried value were a linear interpolation between its two neighboring bins.
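The interpolation between two bins can be illustrated with a small, dependency-free sketch. It uses plain Python floats rather than tensors, and it assumes the bins are evenly spaced between low and high in the pre_func space; the helper name two_hot and the spacing are assumptions made for illustration, not the library's implementation.

```python
import math


def symlog(x: float) -> float:
    """Scalar symlog: sign(x) * ln(|x| + 1)."""
    return math.copysign(math.log1p(abs(x)), x)


def two_hot(value, num_bins=5, low=-20.0, high=20.0, pre_func=symlog):
    """Hypothetical helper: spread `value` over its two neighboring bins.

    Assumes evenly spaced bins between `low` and `high` (an assumption).
    Returns the bin positions and the two-hot weights.
    """
    step = (high - low) / (num_bins - 1)
    bins = [low + i * step for i in range(num_bins)]
    # Transform the value, then clamp it into the supported range.
    x = min(max(pre_func(value), low), high)
    # Index of the bin at or just below x (capped so lower + 1 is valid).
    lower = min(int((x - low) / step), num_bins - 2)
    # Fraction of the way from the lower bin to the upper bin.
    upper_weight = (x - bins[lower]) / step
    weights = [0.0] * num_bins
    weights[lower] = 1.0 - upper_weight
    weights[lower + 1] = upper_weight
    return bins, weights


# With an identity pre_func, 2.5 sits a quarter of the way from bin 0.0 to bin 10.0.
bins, weights = two_hot(2.5, pre_func=lambda v: v)
print(bins)     # [-20.0, -10.0, 0.0, 10.0, 20.0]
print(weights)  # [0.0, 0.0, 0.75, 0.25, 0.0]
```

The weights always sum to one, and the weighted sum of the bin positions recovers the (transformed, clamped) value.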

__init__(logits: Tensor, bins: int = None, low: float = -20.0, high: float = 20.0, pre_func: callable = <function symlog>, post_func: callable = <function symexp>, event_dims: int = 1, eps: float = 1e-05) → None

Create a TwoHot distribution from the input logits.

Parameters:
  • logits (torch.Tensor) – The input logits to create the distribution from, of shape (…, bins).

  • bins (int) – The number of bins for the distribution. Defaults to the size of the final dimension of logits.

  • low (float) – The lower bound of the distribution. (Default: -20.0)

  • high (float) – The upper bound of the distribution. (Default: 20.0)

  • pre_func (callable) – A function to apply to the input value before computing log probabilities. (Default: symlog)

  • post_func (callable) – A function to apply to the mean of the distribution before returning. (Default: symexp)

  • eps (float) – A small value to add when computing entropy to avoid numerical issues. (Default: 1e-5)

entropy() → Tensor

Compute the entropy of the distribution.

Returns:

The entropy of the distribution, of the same shape as the mean without extra event dimensions.

Return type:

torch.Tensor
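As a rough illustration of why a small eps helps, the sketch below computes the entropy of a categorical distribution over bins using plain Python floats. Adding eps inside the logarithm is an assumption about where the constant is applied; the function name categorical_entropy is hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_entropy(logits, eps=1e-5):
    """Hypothetical sketch: -sum(p * log(p + eps)) over the bins.

    Placing eps inside the log is an assumption; it keeps the result
    finite when some probabilities underflow to exactly zero.
    """
    probs = softmax(logits)
    return -sum(p * math.log(p + eps) for p in probs)


# A uniform distribution over 4 bins has entropy close to ln(4).
print(categorical_entropy([0.0, 0.0, 0.0, 0.0]))
# A near one-hot distribution stays finite even though two probabilities are 0.
print(categorical_entropy([1000.0, 0.0, 0.0]))
```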

log_prob(value: Tensor) → Tensor

Compute the log probability of the given value under the distribution.

Parameters:

value (torch.Tensor) – The value to compute the log probability for, can be of any shape.

Returns:

The log probability of the value under the distribution, of the same shape as value.

Return type:

torch.Tensor
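One common way to score a value under a two-hot distribution is the cross-entropy between the two-hot target weights and the log-softmax of the logits. The sketch below shows that computation with plain Python floats; it assumes the target weights have already been computed, and the name two_hot_log_prob is hypothetical rather than part of this module.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def two_hot_log_prob(logits, target_weights):
    """Hypothetical sketch: sum of target weights times log probabilities.

    `target_weights` is a two-hot vector: at most two adjacent entries
    are non-zero and they sum to one.
    """
    log_p = log_softmax(logits)
    return sum(w * lp for w, lp in zip(target_weights, log_p))


# Target lies a quarter of the way between the third and fourth bins.
target = [0.0, 0.0, 0.75, 0.25, 0.0]
print(two_hot_log_prob([0.0, 0.0, 0.0, 0.0, 0.0], target))  # uniform logits
print(two_hot_log_prob([0.0, 0.0, 5.0, 3.0, 0.0], target))  # mass near the target
```

Logits that concentrate mass on the two bins surrounding the target give a higher (less negative) log probability than uniform logits.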

rsample() → Tensor

Draw a sample by randomly selecting two bins and interpolating between them. This method is still a work in progress and may be numerically unstable.

Returns:

A sample from the distribution, matching the shape of mean.

Return type:

torch.Tensor

property mean: Tensor

Compute the mean of the distribution as a weighted sum of bin values.

Returns:

The mean of the distribution.

Return type:

torch.Tensor
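The weighted sum of bin values can be sketched as follows, using plain Python floats. The sketch assumes evenly spaced bins between low and high and applies post_func to the expected bin value, as the post_func parameter description suggests; the name two_hot_mean and the bin spacing are assumptions.

```python
import math


def symexp(x: float) -> float:
    """Scalar symexp: sign(x) * (exp(|x|) - 1)."""
    return math.copysign(math.expm1(abs(x)), x)


def two_hot_mean(logits, low=-20.0, high=20.0, post_func=symexp):
    """Hypothetical sketch: softmax-weighted sum of bins, then post_func."""
    n = len(logits)
    step = (high - low) / (n - 1)
    bins = [low + i * step for i in range(n)]  # assumed evenly spaced
    # Stable softmax over the logits.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    # Expected bin position in the transformed space.
    expected = sum(b * e / total for b, e in zip(bins, exps))
    # Map back to the original value scale.
    return post_func(expected)


# Symmetric logits put the expected bin position at 0, and symexp(0) is 0.
print(two_hot_mean([0.0, 0.0, 0.0, 0.0, 0.0]))
```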

fishyrl.distributions.identity(x: Tensor) → Tensor

Identity function, returns the input tensor unchanged.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The same input tensor.

Return type:

torch.Tensor

fishyrl.distributions.symexp(x: Tensor) → Tensor

Apply the symmetric exponential transformation to the input tensor.

Parameters:

x (torch.Tensor) – The input tensor to transform.

Returns:

The transformed tensor.

Return type:

torch.Tensor
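For reference, the symmetric exponential used in DreamerV3 is sign(x) · (exp(|x|) − 1), the inverse of the symmetric logarithm. The scalar sketch below uses the standard library only; the library function operates on tensors, so this is an illustration of the formula rather than the module's code.

```python
import math


def symexp(x: float) -> float:
    """Scalar symexp: sign(x) * (exp(|x|) - 1)."""
    # expm1 keeps precision for small |x|; copysign restores the sign.
    return math.copysign(math.expm1(abs(x)), x)


for x in (-2.0, 0.0, 0.5, 2.0):
    print(x, symexp(x))
```

The function is odd (symexp(-x) equals -symexp(x)) and maps 0 to 0.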

fishyrl.distributions.symlog(x: Tensor) → Tensor

Apply the symmetric logarithm transformation to the input tensor.

Parameters:

x (torch.Tensor) – The input tensor to transform.

Returns:

The transformed tensor.

Return type:

torch.Tensor
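For reference, the symmetric logarithm used in DreamerV3 is sign(x) · ln(|x| + 1), which compresses large magnitudes while staying close to the identity near zero. The scalar sketch below uses the standard library only; the library function operates on tensors, so this is an illustration of the formula rather than the module's code.

```python
import math


def symlog(x: float) -> float:
    """Scalar symlog: sign(x) * ln(|x| + 1)."""
    # log1p keeps precision for small |x|; copysign restores the sign.
    return math.copysign(math.log1p(abs(x)), x)


# Large magnitudes are compressed; values near zero change very little.
for x in (-1000.0, -1.0, 0.0, 0.01, 1000.0):
    print(x, symlog(x))
```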

fishyrl.distributions.uniform_mix(logits: Tensor, ratio: float = 0.01) → tuple[Tensor, Tensor]

Mix the input logits with a uniform distribution on the final dimension.

Parameters:
  • logits (torch.Tensor) – The input logits to mix, of shape (…, num_classes).

  • ratio (float) – The fraction of the uniform distribution to mix into the distribution defined by the input logits. (Default: 0.01)

Returns:

The mixed logits, of shape (…, num_classes).

Return type:

torch.Tensor
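A typical way to mix logits with a uniform distribution is to convert them to probabilities, blend with 1 / num_classes, and take the logarithm again. The sketch below shows that approach with plain Python floats. It is an assumption about how the mixing works, it returns only the mixed logits, and the name uniform_mix_sketch is hypothetical.

```python
import math


def uniform_mix_sketch(logits, ratio=0.01):
    """Hypothetical sketch: blend softmax(logits) with a uniform distribution.

    Returns log probabilities of the mixture, which are valid logits.
    """
    n = len(logits)
    # Stable softmax over the logits.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    # (1 - ratio) of the original distribution plus ratio of the uniform one.
    mixed = [(1.0 - ratio) * e / total + ratio / n for e in exps]
    return [math.log(p) for p in mixed]


# Even a near one-hot input keeps at least ratio / num_classes on every class.
mixed_logits = uniform_mix_sketch([1000.0, 0.0, 0.0], ratio=0.01)
print([math.exp(l) for l in mixed_logits])
```

Mixing in a small uniform component guarantees every class a non-zero probability, which keeps log probabilities finite.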