fishyrl.losses module

Utility loss functions for reinforcement learning agents.

fishyrl.losses.mse_loss(prior: Tensor, posterior: Tensor, dims: int = 1) → Tensor

Compute the mean squared error between the prior and posterior distributions over the final dims dimensions (only the last dimension by default).

Parameters:
  • prior (torch.Tensor) – The prior distribution.

  • posterior (torch.Tensor) – The posterior distribution.

  • dims (int) – The number of final dimensions to compute the loss over. (Default: 1)

Returns:

The mean squared error loss.

Return type:

torch.Tensor
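
The effect of dims on the output shape can be illustrated with a minimal sketch. The function mse_loss_sketch below is a hypothetical stand-in, not the fishyrl implementation: it assumes the squared error is averaged over the trailing dims dimensions, whereas the library may reduce differently (for example, by summing).

```python
import torch


def mse_loss_sketch(prior: torch.Tensor, posterior: torch.Tensor, dims: int = 1) -> torch.Tensor:
    """Hypothetical stand-in for illustration; not the fishyrl implementation."""
    # Assumption: the squared error is averaged over the last `dims` dimensions.
    reduce_dims = tuple(range(-dims, 0))  # dims=1 -> (-1,), dims=2 -> (-2, -1)
    return ((prior - posterior) ** 2).mean(dim=reduce_dims)


# Example tensors with shape (batch, time, features); every element differs by 1.
prior = torch.zeros(2, 3, 4)
posterior = torch.ones(2, 3, 4)

loss_last = mse_loss_sketch(prior, posterior)              # reduces the feature dimension
loss_last_two = mse_loss_sketch(prior, posterior, dims=2)  # reduces time and features

print(loss_last.shape)      # torch.Size([2, 3])
print(loss_last_two.shape)  # torch.Size([2])
```

Under this assumption, each additional trailing dimension included in dims removes one dimension from the result, so the leading (batch-like) dimensions are preserved.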