fishyrl.models module#

Utility models and layers for construction of reinforcement learning agents.

class fishyrl.models.Actor(input_dim: int, actions: list[Action], hidden_dim: int = 512)#

Bases: Module

Actor network.

Pulls action heads from a latent representation and samples.

__init__(input_dim: int, actions: list[Action], hidden_dim: int = 512) None#

Initialize the actor network.

Parameters:
  • input_dim (int) – The dimension of the input vector.

  • actions (list[fishyrl.actions.Action]) – A list of action definitions, can be continuous or discrete.

  • hidden_dim (int) – The dimension of the hidden layers. (Default: 512)

forward(x: Tensor) tuple[Tensor, Tensor]#

Perform a forward pass through the actor network.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).

Returns:

A tuple containing the sampled actions and their distributions.

Return type:

tuple[torch.Tensor, torch.Tensor]

class fishyrl.models.CNNDecoder(output_channels: int, latent_dim: int)#

Bases: Module

CNN decoder for reconstructing image observations from latent representations.

Uses transposed convolutions to upsample from 4x4 to 64x64, with interleaved layer normalization and SiLU activations.

__init__(output_channels: int, latent_dim: int) None#

Initialize the CNN decoder.

Parameters:
  • output_channels (int) – The number of output channels in the reconstructed image.

  • latent_dim (int) – The size of the latent representation.

forward(x: Tensor) Tensor#

Perform a forward pass through the CNN decoder.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, latent_dim).

Returns:

The output tensor of shape (batch_dim, channels, height, width).

Return type:

torch.Tensor

class fishyrl.models.CNNEncoder(input_channels: int, image_dim: tuple[int, int] = (64, 64))#

Bases: Module

CNN encoder for processing image observations.

Processes a (generally) 64x64 image to 4x4 using [2, 4, 8, 16] channels and interleaved layer normalization and SiLU activations.

__init__(input_channels: int, image_dim: tuple[int, int] = (64, 64)) None#

Initialize the CNN encoder.

Parameters:
  • input_channels (int) – The number of input channels in the image.

  • image_dim (tuple[int, int]) – The size of the input image. (Default: (64, 64))

forward(x: Tensor) Tensor#

Perform a forward pass through the CNN encoder.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, channels, height, width).

Returns:

The output tensor of shape (batch_dim, output_dim, height, width).

Return type:

torch.Tensor

property output_dim: int#

Output dimension of the CNN encoder.

Type:

int

class fishyrl.models.ChannelNorm(*args: list[Any], **kwargs: dict[str, Any])#

Bases: LayerNorm

Layer normalization across the channel dimension.

__init__(*args: list[Any], **kwargs: dict[str, Any]) None#

Initialize the ChannelNorm layer.

Parameters:
  • args (list[Any]) – Positional arguments for the LayerNorm.

  • kwargs (dict[str, Any]) – Keyword arguments for the LayerNorm.

forward(x: Tensor) Tensor#

Apply layer normalization across the channel dimension.

Swaps the channel dimension to the end, applies layer normalization, then swaps it back.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, channels, height, width).

Returns:

The normalized tensor of the same shape.

Return type:

torch.Tensor

class fishyrl.models.LayerNormGRU(input_dim: int, hidden_dim: int)#

Bases: Module

GRU layer with internal layer normalization.

__init__(input_dim: int, hidden_dim: int) None#

Initialize the LayerNormGRU.

Parameters:
  • input_dim (int) – The size of the input tensor.

  • hidden_dim (int) – The size of the hidden state.

forward(x: Tensor, h: Tensor | None = None) Tensor#

Perform a forward pass through the LayerNormGRU.

Applies layer normalization to the hidden state after the GRU computation.

Parameters:
  • x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).

  • h (torch.Tensor | None) – The initial hidden state of shape (batch_dim, hidden_dim), or None to use zeros. (Default: None)

Returns:

The final hidden state of shape (batch_dim, hidden_dim).

Return type:

torch.Tensor

class fishyrl.models.MLP(input_dim: int, output_dim: int | None = None, hidden_dims: list[int] = [])#

Bases: Module

MLP for processing vector inputs.

__init__(input_dim: int, output_dim: int | None = None, hidden_dims: list[int] = []) None#

Initialize the MLP.

Parameters:
  • input_dim (int) – The dimension of the input vector.

  • output_dim (int | None) – The dimension of the output vector. If provided, a final output layer will be added with no normalization or activation. (Default: None)

  • hidden_dims (list[int]) – The dimensions of the hidden layers. (Default: [])

forward(x: Tensor) Tensor#

Perform a forward pass through the MLP.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).

Returns:

The output tensor of shape (batch_dim, output_dim) if output_dim was provided, otherwise (batch_dim, hidden_dims[-1]).

Return type:

torch.Tensor

class fishyrl.models.MLPDecoder(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512)#

Bases: Module

MLP decoder for reconstructing vector observations from latent representations.

__init__(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512) None#

Initialize the MLP decoder.

Parameters:
  • input_dim (int) – The dimension of the input tensor.

  • output_dim (int) – The dimension of the output vector observations.

  • num_blocks (int) – The number of blocks in the MLP. (Default: 4)

  • hidden_dim (int) – The dimension of the hidden layers. (Default: 512)

forward(x: Tensor) Tensor#

Perform a forward pass through the MLP decoder.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).

Returns:

The output tensor of shape (batch_dim, output_dim).

Return type:

torch.Tensor

class fishyrl.models.MLPEncoder(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512, use_symlog: bool = True)#

Bases: Module

MLP encoder for processing vector observations.

__init__(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512, use_symlog: bool = True) None#

Initialize the MLP encoder.

Parameters:
  • input_dim (int) – The dimension of the input vector observation.

  • output_dim (int) – The dimension of the output vector.

  • num_blocks (int) – The number of blocks in the MLP. (Default: 4)

  • hidden_dim (int) – The dimension of the hidden layers. (Default: 512)

  • use_symlog (bool) – Whether to apply symlog transformation to the input. (Default: True)

forward(x: Tensor) Tensor#

Perform a forward pass through the MLP encoder.

Parameters:

x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).

Returns:

The output tensor of shape (batch_dim, output_dim).

Return type:

torch.Tensor

class fishyrl.models.RSSM(recurrent_model: Module, representation_model: Module, transition_model: Module, bins: int = 32, learnable_initial_state: bool = True)#

Bases: Module

Recurrent state-space model for modeling environment dynamics.

The RSSM consists of a recurrent model for tracking the hidden state, a representation model for inferring the stochastic state from the hidden state and observation, and a transition model for predicting the stochastic state from the hidden state.

__init__(recurrent_model: Module, representation_model: Module, transition_model: Module, bins: int = 32, learnable_initial_state: bool = True) None#

Initialize the RSSM.

Parameters:
  • recurrent_model (nn.Module) – The recurrent model for keeping track of the environment dynamics.

  • representation_model (nn.Module) – The model for inferring the stochastic state from the hidden state and observation.

  • transition_model (nn.Module) – The model for predicting the stochastic state from the hidden state.

  • bins (int) – The number of bins for the stochastic state. (Default: 32)

  • learnable_initial_state (bool) – Whether the initial hidden state of the recurrent model is a trainable parameter. (Default: True)

forward(action: Tensor | None = None, posterior: Tensor | None = None, hidden_state: Tensor | None = None, embedded_obs: Tensor | None = None, initialize: Tensor | None = None, batch_dim: int | None = None) dict[str, Tensor]#

Perform one step of the RSSM.

Will compute the hidden state using action, posterior, and previous hidden state. If not available, will instead use the initial hidden state. If embedded_obs is not provided, imagines the next hidden state.

Parameters:
  • action (torch.Tensor | None) – The action taken, of shape (batch_dim, action_dim).

  • posterior (torch.Tensor | None) – The posterior stochastic state from the previous step, of shape (batch_dim, latent_dim).

  • hidden_state (torch.Tensor | None) – The initial hidden state of the recurrent model, of shape (batch_dim, hidden_dim).

  • embedded_obs (torch.Tensor | None) – The embedded observation, of shape (batch_dim, obs_dim), or None if imagining.

  • initialize (torch.Tensor | None) – Boolean mask of shape (batch_dim,) indicating which entries of the hidden state should be reset to the initial hidden state. (Default: None)

  • batch_dim (int | None) – The batch dimension, inferred if not provided. (Default: None)

Returns:

A dictionary containing the prior and posterior logits and samples, and the updated hidden state.

Return type:

dict[str, torch.Tensor]

infer_stochastic(hidden_state: Tensor, embedded_obs: Tensor | None = None) tuple[Tensor, Tensor]#

Infer the stochastic state.

Infer the stochastic state from the hidden state (using the transition model) or from the hidden state and the embedded observation (using the representation model).

Parameters:
  • hidden_state (torch.Tensor) – The hidden state from the recurrent model, of shape (batch_dim, hidden_dim).

  • embedded_obs (torch.Tensor | None) – The embedded observation, of shape (batch_dim, obs_dim). (Default: None)

Returns:

The logits and sampled stochastic state.

Return type:

tuple[torch.Tensor, torch.Tensor]

property initial_hidden_state: Tensor#

Trainable initial hidden state of the recurrent model.

Type:

torch.Tensor

class fishyrl.models.RecurrentModel(input_dim: int, hidden_dim: int)#

Bases: Module

Recurrent model for processing sequences of latent representations.

Uses a GRU to process sequences of latent representations, with layer normalization and SiLU activations on the output.

__init__(input_dim: int, hidden_dim: int) None#

Initialize the recurrent model.

Parameters:
  • input_dim (int) – The size of the input tensor.

  • hidden_dim (int) – The size of the hidden state in the GRU.

forward(x: Tensor, h: Tensor | None = None) Tensor#

Perform a forward pass through the recurrent model.

Parameters:
  • x (torch.Tensor) – The input tensor of shape (batch_dim, seq_len, latent_dim).

  • h (torch.Tensor | None) – The initial hidden state of shape (1, batch_dim, hidden_dim), or None to use zeros. (Default: None)

Returns:

The final hidden state of shape (batch_dim, hidden_dim).

Return type:

torch.Tensor

property hidden_dim: int#

Hidden size of the recurrent model.

Type:

int