fishyrl.models module#
Utility models and layers for construction of reinforcement learning agents.
- class fishyrl.models.Actor(input_dim: int, actions: list[Action], hidden_dim: int = 512)#
Bases: Module
Actor network.
Pulls action heads from a latent representation and samples.
- __init__(input_dim: int, actions: list[Action], hidden_dim: int = 512) None#
Initialize the actor network.
- Parameters:
input_dim (int) – The dimension of the input vector.
actions (list[fishyrl.actions.Action]) – A list of action definitions, can be continuous or discrete.
hidden_dim (int) – The dimension of the hidden layers. (Default: 512)
- forward(x: Tensor) tuple[Tensor, Tensor]#
Perform a forward pass through the actor network.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).
- Returns:
A tuple containing the sampled actions and their distributions.
- Return type:
tuple[torch.Tensor, torch.Tensor]
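The pattern described above — a trunk producing a latent, heads producing distributions that are then sampled — can be sketched in plain PyTorch. This is an illustrative toy, not the fishyrl implementation: the single Categorical head, the layer sizes, and the SiLU activation are assumptions.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

class TinyActor(nn.Module):
    """Minimal actor: shared trunk, one discrete head, samples an action."""

    def __init__(self, input_dim: int, n_actions: int, hidden_dim: int = 32):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.SiLU())
        self.head = nn.Linear(hidden_dim, n_actions)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Categorical]:
        logits = self.head(self.trunk(x))  # (batch_dim, n_actions)
        dist = Categorical(logits=logits)
        return dist.sample(), dist         # sampled actions + their distribution

actor = TinyActor(input_dim=8, n_actions=4)
actions, dist = actor(torch.randn(5, 8))   # actions: (5,), dist over 4 actions
```

Returning the distribution alongside the sample lets the caller compute log-probabilities and entropies for the policy loss without rebuilding it.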
- class fishyrl.models.CNNDecoder(output_channels: int, latent_dim: int)#
Bases: Module
CNN decoder for reconstructing image observations from latent representations.
Uses transposed convolutions to upsample from 4x4 to 64x64, with interleaved layer normalization and SiLU activations.
- __init__(output_channels: int, latent_dim: int) None#
Initialize the CNN decoder.
- Parameters:
output_channels (int) – The number of output channels in the reconstructed image.
latent_dim (int) – The size of the latent representation.
- forward(x: Tensor) Tensor#
Perform a forward pass through the CNN decoder.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, latent_dim).
- Returns:
The output tensor of shape (batch_dim, channels, height, width).
- Return type:
torch.Tensor
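The 4x4-to-64x64 upsampling described above can be sketched with four stride-2 transposed convolutions. This is a sketch of the pattern only: the channel widths, the initial linear projection, and the `base` parameter are assumptions, and the real module also interleaves layer normalization.

```python
import torch
import torch.nn as nn

class SketchDecoder(nn.Module):
    """Project a latent vector to a 4x4 feature map, then upsample to 64x64
    with four stride-2 transposed convolutions (4 -> 8 -> 16 -> 32 -> 64)."""

    def __init__(self, latent_dim: int, out_channels: int, base: int = 2):
        super().__init__()
        self.base = base
        self.project = nn.Linear(latent_dim, base * 8 * 4 * 4)
        chans = [base * 8, base * 4, base * 2, base]
        layers: list[nn.Module] = []
        for cin, cout in zip(chans, chans[1:] + [out_channels]):
            layers += [nn.ConvTranspose2d(cin, cout, 4, stride=2, padding=1),
                       nn.SiLU()]
        layers.pop()  # no activation on the final image layer
        self.upsample = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.project(z).view(z.shape[0], self.base * 8, 4, 4)  # (N, C, 4, 4)
        return self.upsample(x)                                    # (N, out, 64, 64)

dec = SketchDecoder(latent_dim=32, out_channels=3)
img = dec(torch.randn(2, 32))
```

With kernel size 4, stride 2, and padding 1, each transposed convolution exactly doubles the spatial resolution.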
- class fishyrl.models.CNNEncoder(input_channels: int, image_dim: tuple[int, int] = (64, 64))#
Bases: Module
CNN encoder for processing image observations.
Downsamples a (generally) 64x64 image to 4x4 using [2, 4, 8, 16] channels, with interleaved layer normalization and SiLU activations.
- __init__(input_channels: int, image_dim: tuple[int, int] = (64, 64)) None#
Initialize the CNN encoder.
- Parameters:
input_channels (int) – The number of input channels in the image.
image_dim (tuple[int, int]) – The size of the input image. (Default: (64, 64))
- forward(x: Tensor) Tensor#
Perform a forward pass through the CNN encoder.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, channels, height, width).
- Returns:
The output tensor of shape (batch_dim, output_dim, height, width).
- Return type:
torch.Tensor
- property output_dim: int#
Output dimension of the CNN encoder.
- Type:
int
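The 64x64-to-4x4 downsampling path can be sketched with four stride-2 convolutions. This reads the [2, 4, 8, 16] above as the channel widths; the kernel sizes are assumptions, and the interleaved layer normalization is omitted for brevity.

```python
import torch
import torch.nn as nn

def make_encoder(in_channels: int, base: int = 2) -> nn.Sequential:
    """Four stride-2 convolutions: 64x64 -> 32 -> 16 -> 8 -> 4x4."""
    widths = [base * m for m in (1, 2, 4, 8)]  # 2, 4, 8, 16 for base=2
    layers: list[nn.Module] = []
    prev = in_channels
    for w in widths:
        layers += [nn.Conv2d(prev, w, kernel_size=3, stride=2, padding=1),
                   nn.SiLU()]
        prev = w
    return nn.Sequential(*layers)

enc = make_encoder(in_channels=3)
feat = enc(torch.randn(1, 3, 64, 64))  # (1, 16, 4, 4)
```

Each stride-2 convolution halves the spatial resolution, so four of them take 64 down to 4.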
- class fishyrl.models.ChannelNorm(*args: list[Any], **kwargs: dict[str, Any])#
Bases: LayerNorm
Layer normalization across the channel dimension.
- __init__(*args: list[Any], **kwargs: dict[str, Any]) None#
Initialize the ChannelNorm layer.
- Parameters:
args (list[Any]) – Positional arguments for the LayerNorm.
kwargs (dict[str, Any]) – Keyword arguments for the LayerNorm.
- forward(x: Tensor) Tensor#
Apply layer normalization across the channel dimension.
Swaps the channel dimension to the end, applies layer normalization, then swaps it back.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, channels, height, width).
- Returns:
The normalized tensor of the same shape.
- Return type:
torch.Tensor
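The swap, normalize, swap-back trick described above can be shown functionally (without ChannelNorm's learnable affine parameters, which the real layer inherits from LayerNorm):

```python
import torch
import torch.nn.functional as F

def channel_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize an NCHW tensor over its channel dimension."""
    x = x.movedim(1, -1)                        # (N, H, W, C): channels last
    x = F.layer_norm(x, x.shape[-1:], eps=eps)  # normalize over C
    return x.movedim(-1, 1)                     # back to (N, C, H, W)

y = channel_norm(torch.randn(2, 8, 4, 4))
```

After the transform, each spatial position's channel vector has zero mean and unit variance, independently of every other position.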
- class fishyrl.models.LayerNormGRU(input_dim: int, hidden_dim: int)#
Bases: Module
GRU layer with internal layer normalization.
- __init__(input_dim: int, hidden_dim: int) None#
Initialize the LayerNormGRU.
- Parameters:
input_dim (int) – The size of the input tensor.
hidden_dim (int) – The size of the hidden state.
- forward(x: Tensor, h: Tensor | None = None) Tensor#
Perform a forward pass through the LayerNormGRU.
Applies layer normalization to the hidden state after the GRU computation.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).
h (torch.Tensor | None) – The initial hidden state of shape (batch_dim, hidden_dim), or None to use zeros. (Default: None)
- Returns:
The final hidden state of shape (batch_dim, hidden_dim).
- Return type:
torch.Tensor
- class fishyrl.models.MLP(input_dim: int, output_dim: int | None = None, hidden_dims: list[int] = [])#
Bases: Module
MLP for processing vector inputs.
- __init__(input_dim: int, output_dim: int | None = None, hidden_dims: list[int] = []) None#
Initialize the MLP.
- Parameters:
input_dim (int) – The dimension of the input vector.
output_dim (int | None) – The dimension of the output vector. If provided, a final output layer with no normalization or activation is added.
hidden_dims (list[int]) – The dimensions of the hidden layers.
- forward(x: Tensor) Tensor#
Perform a forward pass through the MLP.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).
- Returns:
The output tensor of shape (batch_dim, output_dim) if output_dim was provided, otherwise (batch_dim, hidden_dims[-1]).
- Return type:
torch.Tensor
- class fishyrl.models.MLPDecoder(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512)#
Bases: Module
MLP decoder for reconstructing vector observations from latent representations.
- __init__(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512) None#
Initialize the MLP decoder.
- Parameters:
input_dim (int) – The dimension of the input tensor.
output_dim (int) – The dimension of the output vector observations.
num_blocks (int) – The number of blocks in the MLP. (Default: 4)
hidden_dim (int) – The dimension of the hidden layers. (Default: 512)
- forward(x: Tensor) Tensor#
Perform a forward pass through the MLP decoder.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, hidden_dim).
- Returns:
The output tensor of shape (batch_dim, output_dim).
- Return type:
torch.Tensor
- class fishyrl.models.MLPEncoder(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512, use_symlog: bool = True)#
Bases: Module
MLP encoder for processing vector observations.
- __init__(input_dim: int, output_dim: int, num_blocks: int = 4, hidden_dim: int = 512, use_symlog: bool = True) None#
Initialize the MLP encoder.
- Parameters:
input_dim (int) – The dimension of the input vector observation.
output_dim (int) – The dimension of the output vector.
num_blocks (int) – The number of blocks in the MLP. (Default: 4)
hidden_dim (int) – The dimension of the hidden layers. (Default: 512)
use_symlog (bool) – Whether to apply a symlog transformation to the input. (Default: True)
- forward(x: Tensor) Tensor#
Perform a forward pass through the MLP encoder.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, input_dim).
- Returns:
The output tensor of shape (batch_dim, hidden_dim).
- Return type:
torch.Tensor
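The symlog transform referenced by use_symlog compresses large input magnitudes while staying near-identity around zero, which keeps observations of very different scales in a comparable range. A minimal scalar sketch:

```python
import math

def symlog(x: float) -> float:
    """Symmetric log: sign(x) * log(1 + |x|)."""
    return math.copysign(math.log1p(abs(x)), x)

def symexp(x: float) -> float:
    """Inverse of symlog: sign(x) * (exp(|x|) - 1)."""
    return math.copysign(math.expm1(abs(x)), x)
```

symlog is odd (symlog(-x) = -symlog(x)), passes through the origin, and symexp inverts it exactly up to floating-point error.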
- class fishyrl.models.RSSM(recurrent_model: Module, representation_model: Module, transition_model: Module, bins: int = 32, learnable_initial_state: bool = True)#
Bases: Module
Recurrent state-space model for modeling environment dynamics.
The RSSM consists of a recurrent model for tracking the hidden state, a representation model for inferring the stochastic state from the hidden state and observation, and a transition model for predicting the stochastic state from the hidden state.
- __init__(recurrent_model: Module, representation_model: Module, transition_model: Module, bins: int = 32, learnable_initial_state: bool = True) None#
Initialize the RSSM.
- Parameters:
recurrent_model (nn.Module) – The recurrent model for keeping track of the environment dynamics.
representation_model (nn.Module) – The model for inferring the stochastic state from the hidden state and observation.
transition_model (nn.Module) – The model for predicting the stochastic state from the hidden state.
bins (int) – The number of bins for the stochastic state. (Default: 32)
learnable_initial_state (bool) – Whether the initial hidden state is a trainable parameter. (Default: True)
- forward(action: Tensor | None = None, posterior: Tensor | None = None, hidden_state: Tensor | None = None, embedded_obs: Tensor | None = None, initialize: Tensor | None = None, batch_dim: int | None = None) dict[str, Tensor]#
Perform one step of the RSSM.
Will compute the hidden state using the action, posterior, and previous hidden state; if these are not available, will instead use the initial hidden state. If embedded_obs is not provided, imagines the next hidden state.
- Parameters:
action (torch.Tensor | None) – The action taken, of shape (batch_dim, action_dim).
posterior (torch.Tensor | None) – The posterior stochastic state from the previous step, of shape (batch_dim, latent_dim).
hidden_state (torch.Tensor | None) – The initial hidden state of the recurrent model, of shape (batch_dim, hidden_dim).
embedded_obs (torch.Tensor | None) – The embedded observation, of shape (batch_dim, obs_dim), or None if imagining.
initialize (torch.Tensor | None) – Boolean tensor indicating which hidden states to initialize, of shape (batch_dim,). (Default: None)
batch_dim (int | None) – The batch dimension, inferred if not provided. (Default: None)
- Returns:
A dictionary containing the prior and posterior logits and samples, and the updated hidden state.
- Return type:
dict[str, torch.Tensor]
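The imagination path of one such step — recurrent update from the previous stochastic state and action, then a prior prediction of the next stochastic state — can be sketched in plain PyTorch. The module sizes, GRUCell recurrence, and one-hot sampling are illustrative assumptions, not the fishyrl signatures.

```python
import torch
import torch.nn as nn
from torch.distributions import OneHotCategorical

hidden_dim, latent_dim, action_dim = 16, 8, 4
recurrent = nn.GRUCell(latent_dim + action_dim, hidden_dim)  # tracks dynamics
transition = nn.Linear(hidden_dim, latent_dim)               # prior: p(z | h)

h = torch.zeros(2, hidden_dim)  # batch of 2 hidden states
z = torch.zeros(2, latent_dim)  # previous stochastic state
a = torch.zeros(2, action_dim)  # previous action

# One imagination step: no observation, so the prior stands in for the posterior.
h = recurrent(torch.cat([z, a], dim=-1), h)  # update deterministic state
prior_logits = transition(h)                 # predict stochastic state
z = OneHotCategorical(logits=prior_logits).sample()
```

With an observation available, a representation model conditioned on both h and the embedded observation would produce posterior logits instead, and the prior would only be used for the KL regularizer.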
- infer_stochastic(hidden_state: Tensor, embedded_obs: Tensor | None = None) tuple[Tensor, Tensor]#
Infer the stochastic state.
Infer the stochastic state from the hidden state (using the transition model) or from the hidden state and the embedded observation (using the representation model).
- Parameters:
hidden_state (torch.Tensor) – The hidden state from the recurrent model, of shape (batch_dim, hidden_dim).
embedded_obs (torch.Tensor | None) – The embedded observation, of shape (batch_dim, obs_dim). (Default: None)
- Returns:
The logits and sampled stochastic state.
- Return type:
tuple[torch.Tensor, torch.Tensor]
Trainable initial hidden state of the recurrent model.
- Type:
torch.Tensor
- class fishyrl.models.RecurrentModel(input_dim: int, hidden_dim: int)#
Bases: Module
Recurrent model for processing sequences of latent representations.
Uses a GRU to process sequences of latent representations, with layer normalization and SiLU activations on the output.
- __init__(input_dim: int, hidden_dim: int) None#
Initialize the recurrent model.
- Parameters:
input_dim (int) – The size of the input tensor.
hidden_dim (int) – The size of the hidden state in the GRU.
- forward(x: Tensor, h: Tensor | None = None) Tensor#
Perform a forward pass through the recurrent model.
- Parameters:
x (torch.Tensor) – The input tensor of shape (batch_dim, seq_len, latent_dim).
h (torch.Tensor | None) – The initial hidden state of shape (1, batch_dim, hidden_dim), or None to use zeros. (Default: None)
- Returns:
The final hidden state of shape (batch_dim, hidden_dim).
- Return type:
torch.Tensor
Hidden size of the recurrent model.
- Type:
int