fishyrl.distributions module
Utility distributions for reinforcement learning agents.
- class fishyrl.distributions.TwoHot(logits: Tensor, bins: int = None, low: float = -20.0, high: float = 20.0, pre_func: callable = <function symlog>, post_func: callable = <function symexp>, event_dims: int = 1, eps: float = 1e-05)
Bases: object
Two-hot distribution as described in Dreamer-V3.
Takes a binned input tensor and computes log probabilities as if the queried value is a linear interpolation between two bins.
- __init__(logits: Tensor, bins: int = None, low: float = -20.0, high: float = 20.0, pre_func: callable = <function symlog>, post_func: callable = <function symexp>, event_dims: int = 1, eps: float = 1e-05) → None
Create a TwoHot distribution from the input logits.
- Parameters:
logits (torch.Tensor) – The input logits to create the distribution from, of shape (…, bins).
bins (int) – The number of bins for the distribution. Defaults to the final dimension of logits.
low (float) – The lower bound of the distribution. (Default: -20.0)
high (float) – The upper bound of the distribution. (Default: 20.0)
pre_func (callable) – A function to apply to the input value before computing log probabilities. (Default: symlog)
post_func (callable) – A function to apply to the mean of the distribution before returning. (Default: symexp)
eps (float) – A small value to add when computing entropy to avoid numerical issues. (Default: 1e-05)
- entropy() → Tensor
Compute the entropy of the distribution.
- Returns:
The entropy of the distribution, of the same shape as the mean without extra event dimensions.
- Return type:
torch.Tensor
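As a rough illustration of how eps can enter an entropy computation, the plain-Python sketch below computes the entropy of a categorical distribution over bins. Where exactly TwoHot.entropy adds eps is not stated here, so placing it inside the logarithm is an assumption, and softmax and categorical_entropy are hypothetical helpers that are not part of fishyrl.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_entropy(logits, eps=1e-5):
    """Entropy of softmax(logits). Adding eps inside the log is an assumption."""
    probs = softmax(logits)
    # eps keeps log() finite when a probability underflows to exactly zero.
    return -sum(p * math.log(p + eps) for p in probs)


# Uniform logits give an entropy close to log(num_bins).
uniform_entropy = categorical_entropy([0.0, 0.0, 0.0, 0.0])
# One dominant logit gives an entropy close to zero, without a log(0) error.
peaked_entropy = categorical_entropy([1000.0, 0.0, 0.0, 0.0])
```

Without eps, the second call would fail on math.log(0.0), which is the kind of numerical issue the parameter is meant to avoid.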
- log_prob(value: Tensor) → Tensor
Compute the log probability of the given value under the distribution.
- Parameters:
value (torch.Tensor) – The value to compute the log probability for, can be of any shape.
- Returns:
The log probability of the value under the distribution, of the same shape as value.
- Return type:
torch.Tensor
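The interpolation behind log_prob can be sketched for a single scalar value in plain Python. This is a minimal illustration that assumes evenly spaced bins between low and high, symlog as pre_func, and a log probability formed by weighting the log-softmax of the logits with the two-hot weights. All helper names are hypothetical, and the library itself operates on torch tensors.

```python
import math


def symlog(x):
    """sign(x) * ln(|x| + 1) for a scalar."""
    return math.copysign(math.log1p(abs(x)), x)


def linspace(low, high, n):
    """n evenly spaced bin centers from low to high (assumed bin layout)."""
    step = (high - low) / (n - 1)
    return [low + i * step for i in range(n)]


def two_hot_weights(y, bins):
    """Weights that are non-zero only on the two bins surrounding y."""
    y = min(max(y, bins[0]), bins[-1])  # clamp into the supported range
    # Index of the last bin at or below y, never the final bin.
    k = max(i for i, b in enumerate(bins[:-1]) if b <= y)
    upper = (y - bins[k]) / (bins[k + 1] - bins[k])
    weights = [0.0] * len(bins)
    weights[k] = 1.0 - upper
    weights[k + 1] = upper
    return weights


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - lse for l in logits]


def two_hot_log_prob(value, logits, low=-20.0, high=20.0):
    """Two-hot weighted sum of log-probabilities for one scalar value."""
    bins = linspace(low, high, len(logits))
    weights = two_hot_weights(symlog(value), bins)
    return sum(w * lp for w, lp in zip(weights, log_softmax(logits)))


bins = linspace(-20.0, 20.0, 41)            # bin centers -20, -19, ..., 20
weights = two_hot_weights(2.25, bins)       # mass split between bins 2 and 3
uniform_lp = two_hot_log_prob(3.0, [0.0] * 41)
```

With these bins, a transformed value of 2.25 puts weight 0.75 on the bin at 2 and 0.25 on the bin at 3, so the weighted sum of bin values recovers 2.25.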
- rsample() → Tensor
Draw a sample by randomly sampling two bins and interpolating between them. This method is still a work in progress and may be numerically unstable.
- Returns:
A sample from the distribution, matching the shape of mean.
- Return type:
torch.Tensor
- property mean: Tensor
Compute the mean of the distribution as a weighted sum of bin values.
- Returns:
The mean of the distribution.
- Return type:
torch.Tensor
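The weighted sum described for mean can be sketched for a single distribution in plain Python. The sketch assumes evenly spaced bins between low and high, softmax probabilities over the logits, and symexp as post_func. The helper names are hypothetical and only illustrate the computation, not the library's tensor implementation.

```python
import math


def symexp(x):
    """sign(x) * (exp(|x|) - 1) for a scalar."""
    return math.copysign(math.expm1(abs(x)), x)


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def two_hot_mean(logits, low=-20.0, high=20.0):
    """Probability-weighted sum of bin values, mapped back through symexp."""
    n = len(logits)
    step = (high - low) / (n - 1)
    bins = [low + i * step for i in range(n)]  # assumed bin layout
    expected = sum(p * b for p, b in zip(softmax(logits), bins))
    return symexp(expected)


# Uniform logits over symmetric bins average out to roughly zero.
centered_mean = two_hot_mean([0.0] * 41)

# Nearly all mass on the bin at 2.0 gives symexp(2.0) = e**2 - 1.
peaked_logits = [0.0] * 41
peaked_logits[22] = 1000.0
peaked_mean = two_hot_mean(peaked_logits)
```

Because the bins live in the transformed space, applying post_func after the weighted sum returns the mean on the original scale of the values.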
- fishyrl.distributions.identity(x: Tensor) → Tensor
Identity function, returns the input tensor unchanged.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The same input tensor.
- Return type:
torch.Tensor
- fishyrl.distributions.symexp(x: Tensor) → Tensor
Apply the symmetric exponential transformation to the input tensor.
- Parameters:
x (torch.Tensor) – The input tensor to transform.
- Returns:
The transformed tensor.
- Return type:
torch.Tensor
- fishyrl.distributions.symlog(x: Tensor) → Tensor
Apply the symmetric logarithm transformation to the input tensor.
- Parameters:
x (torch.Tensor) – The input tensor to transform.
- Returns:
The transformed tensor.
- Return type:
torch.Tensor
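For reference, Dreamer-V3 defines symlog(x) = sign(x) · ln(|x| + 1) and symexp(x) = sign(x) · (exp(|x|) − 1), which are inverses of each other. The scalar sketch below follows those definitions in plain Python. The names symlog_scalar and symexp_scalar are hypothetical, and it is an assumption that the library functions apply the same formulas elementwise to tensors.

```python
import math


def symlog_scalar(x):
    """sign(x) * ln(|x| + 1): compresses large magnitudes, keeps the sign."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp_scalar(x):
    """sign(x) * (exp(|x|) - 1): the inverse of symlog_scalar."""
    return math.copysign(math.expm1(abs(x)), x)


# Round trip through both transforms recovers the original values.
values = [-1000.0, -1.5, 0.0, 0.25, 1000.0]
round_trip = [symexp_scalar(symlog_scalar(v)) for v in values]
```

The transform is roughly linear near zero and logarithmic for large magnitudes, which is why it pairs well with a fixed bin range such as the default low and high of TwoHot.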
- fishyrl.distributions.uniform_mix(logits: Tensor, ratio: float = 0.01) → tuple[Tensor, Tensor]
Mix the input logits with a uniform distribution on the final dimension.
- Parameters:
logits (torch.Tensor) – The input logits to mix, of shape (…, num_classes).
ratio (float) – The fraction of the uniform distribution to mix into the distribution given by the input logits. (Default: 0.01)
- Returns:
The mixed logits, of shape (…, num_classes).
- Return type:
torch.Tensor
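One common way to mix logits with a uniform distribution is to blend in probability space and convert back to logits, as sketched below for a single list of logits in plain Python. Whether uniform_mix does exactly this is an assumption, and uniform_mix_sketch is a hypothetical name. The signature above lists a tuple return whose second element is not described here, so the sketch returns only the mixed logits.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def uniform_mix_sketch(logits, ratio=0.01):
    """Blend softmax(logits) with a uniform distribution, return log-probs."""
    probs = softmax(logits)
    num_classes = len(probs)
    # Every class keeps at least ratio / num_classes probability mass.
    mixed = [(1.0 - ratio) * p + ratio / num_classes for p in probs]
    return [math.log(p) for p in mixed]


# A nearly one-hot distribution still assigns mass to every class after mixing.
mixed_logits = uniform_mix_sketch([1000.0, 0.0, 0.0, 0.0], ratio=0.01)
mixed_probs = [math.exp(l) for l in mixed_logits]
```

Mixing this way bounds every log probability from below, which keeps downstream log-probability and entropy terms finite even for very peaked logits.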