fishyrl.utilities module

Utilities for state management and common operations.

class fishyrl.utilities.CaseInsensitiveEnumMeta(cls, bases, classdict, *, boundary=None, _simple=False, **kwds)

Bases: EnumType

Enum metaclass for case-insensitive lookup.
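A minimal sketch of how such a metaclass can work, assuming it targets name lookups of the form `Color["red"]` (the exact matching rule used by fishyrl is not documented here). It subclasses `EnumMeta` (an alias of `EnumType` since Python 3.11), tries the exact name first, and only then falls back to a case-insensitive scan of the members. `Color` is a hypothetical example enum.

```python
from enum import Enum, EnumMeta


class CaseInsensitiveEnumMeta(EnumMeta):
    """Sketch: resolve Enum["name"] lookups regardless of letter case."""

    def __getitem__(cls, name):
        try:
            # Fast path: exact member name.
            return super().__getitem__(name)
        except KeyError:
            # Fallback (assumed behavior): compare member names case-insensitively.
            for member_name, member in cls.__members__.items():
                if member_name.lower() == str(name).lower():
                    return member
            raise  # re-raise the original KeyError if nothing matched


class Color(Enum, metaclass=CaseInsensitiveEnumMeta):
    RED = 1
    GREEN = 2


print(Color["red"], Color["Green"])  # Color.RED Color.GREEN
```

Keeping the exact lookup as the fast path means ordinary `Color["RED"]` calls cost nothing extra; the linear scan only runs on a miss.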

class fishyrl.utilities.Container(**modules: Any)

Bases: object

Container for multiple submodules and utilities. Unlike ContainerModule, it is a plain object without torch integration (it does not subclass torch.nn.Module).

__init__(**modules: Any) → None

Initialize the Container.

Parameters:

modules (dict[str, Any]) – The submodules and utilities to contain.

load_state_dict(state_dict: dict[str, Any]) → None

Load the state of the container from a dictionary.

Parameters:

state_dict (dict[str, Any]) – The state dictionary.

state_dict() → dict[str, Any]

Return the state of the container as a dictionary.

Returns:

A dictionary containing the state of the container.

Return type:

dict[str, Any]
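The following is a minimal sketch, not fishyrl's implementation. It assumes that each keyword argument becomes an attribute, and that `state_dict()` gathers the state of every contained object that exposes its own `state_dict()`, keyed by the name it was registered under. `Counter` is a hypothetical stateful utility used only for illustration.

```python
from typing import Any


class Container:
    """Sketch: plain-object container that delegates state handling."""

    def __init__(self, **modules: Any) -> None:
        # Remember registration order, then expose each object as an attribute.
        self._names = list(modules)
        for name, module in modules.items():
            setattr(self, name, module)

    def state_dict(self) -> dict[str, Any]:
        # Only objects that implement state_dict() contribute (assumption).
        return {
            name: getattr(self, name).state_dict()
            for name in self._names
            if hasattr(getattr(self, name), "state_dict")
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Hand each entry back to the object registered under the same name.
        for name, state in state_dict.items():
            getattr(self, name).load_state_dict(state)


class Counter:
    """Hypothetical stateful utility used only for this example."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.count = state["count"]


c = Container(steps=Counter(), name="run-1")
c.steps.count = 7
snapshot = c.state_dict()  # {'steps': {'count': 7}}
```

Plain values such as `name` are stored but skipped when saving, because they have no `state_dict()` of their own.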

class fishyrl.utilities.ContainerModule(**modules: Module)

Bases: Module

A torch Module that holds multiple submodules.

__init__(**modules: Module) → None

Initialize the ContainerModule.

Parameters:

modules (dict[str, nn.Module]) – The submodules to contain.
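A minimal sketch, assuming ContainerModule simply registers each keyword argument as a named child module; the actual fishyrl code may differ. Registering through `nn.Module.add_module` is what makes the children's parameters and buffers show up in `parameters()`, `state_dict()` and device moves, much as passing a dict to `torch.nn.ModuleDict` would. The `actor`/`critic` names are hypothetical.

```python
from torch import nn


class ContainerModule(nn.Module):
    """Sketch: register keyword-argument submodules on an nn.Module."""

    def __init__(self, **modules: nn.Module) -> None:
        super().__init__()
        for name, module in modules.items():
            # add_module makes the child's parameters and buffers visible to
            # parameters(), state_dict(), .to(device), and so on.
            self.add_module(name, module)


agent = ContainerModule(actor=nn.Linear(4, 2), critic=nn.Linear(4, 1))
print(sorted(agent.state_dict()))
# ['actor.bias', 'actor.weight', 'critic.bias', 'critic.weight']
```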

class fishyrl.utilities.DotDict(*args: list, **kwargs: dict[str, Any])

Bases: dict

Allow accessing dictionary keys as attributes.

Attribute-access code from https://stackoverflow.com/a/23689767 and SheepRL.

__init__(*args: list, **kwargs: dict[str, Any]) → None

Initialize the DotDict.

Parameters:
  • args (list) – Positional arguments to initialize the dictionary.

  • kwargs (dict[str, Any]) – Keyword arguments to initialize the dictionary.

_crawl(lookup: dict | list) → None

Recursively convert nested dictionaries to DotDict.

Crawls through dictionaries and lists.

Parameters:

lookup (dict | list) – The dictionary or list to crawl through.
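A minimal sketch of the two behaviors described above: attribute access, and recursive conversion of nested dictionaries (including those inside lists) by `_crawl`. Unlike a `__getattr__ = dict.get` recipe, which returns None for missing keys, this sketch raises `AttributeError`, so `hasattr` and `getattr(obj, name, default)` behave normally; whether fishyrl's DotDict does the same is an assumption.

```python
from __future__ import annotations

from typing import Any


class DotDict(dict):
    """Sketch: dict whose keys are also readable and writable as attributes."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._crawl(self)

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails.
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        try:
            del self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def _crawl(self, lookup: dict | list) -> None:
        # Replace nested plain dicts with DotDicts, descending into lists too.
        if isinstance(lookup, dict):
            items = list(lookup.items())
        else:
            items = list(enumerate(lookup))
        for key, value in items:
            if isinstance(value, dict) and not isinstance(value, DotDict):
                lookup[key] = DotDict(value)  # its own __init__ crawls deeper
            elif isinstance(value, list):
                self._crawl(value)


cfg = DotDict({"model": {"hidden": 256}, "envs": [{"id": "CartPole-v1"}]})
print(cfg.model.hidden, cfg.envs[0].id)  # 256 CartPole-v1
```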

class fishyrl.utilities.MovingMinMaxScaler(beta: float = 0.99, frac_low: float = 0.05, frac_high: float = 0.95, eps: float = 1e-08)

Bases: Module

Moving percentile-based min-max scaler for normalizing inputs.

__init__(beta: float = 0.99, frac_low: float = 0.05, frac_high: float = 0.95, eps: float = 1e-08) → None

Initialize the MovingMinMaxScaler.

Parameters:
  • beta (float) – The decay rate for the moving low and high estimates. (Default: 0.99)

  • frac_low (float) – The lower quantile, as a fraction in [0, 1], used for the low estimate. (Default: 0.05)

  • frac_high (float) – The upper quantile, as a fraction in [0, 1], used for the high estimate. (Default: 0.95)

  • eps (float) – Minimum value for the computed high - low range. (Default: 1e-8)

forward(x: Tensor) → tuple[Tensor, Tensor]

Update and return the low and range estimates.

Parameters:

x (torch.Tensor) – The input tensor to use while updating the estimates.

Returns:

A tuple containing the low estimate and the range estimate.

Return type:

tuple[torch.Tensor, torch.Tensor]
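A minimal sketch of a percentile-based moving scaler, in the spirit of the 5th/95th-percentile return normalization used in DreamerV3; that resemblance is an assumption, since the fishyrl source is not shown here. Each call computes the `frac_low` and `frac_high` quantiles of the batch with `torch.quantile`, blends them into exponential moving averages with decay `beta`, and returns the low estimate together with the range `high - low`, clamped from below at `eps`. The buffer names `low` and `high` are hypothetical.

```python
import torch
from torch import nn


class MovingMinMaxScaler(nn.Module):
    """Sketch: EMA of batch percentiles, returned as (low, range)."""

    def __init__(self, beta: float = 0.99, frac_low: float = 0.05,
                 frac_high: float = 0.95, eps: float = 1e-8) -> None:
        super().__init__()
        self.beta = beta
        self.frac_low = frac_low
        self.frac_high = frac_high
        self.eps = eps
        # Buffers follow .to(device) and are saved in state_dict().
        self.register_buffer("low", torch.zeros(()))
        self.register_buffer("high", torch.zeros(()))

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x = x.detach().flatten().float()
        batch_low = torch.quantile(x, self.frac_low)
        batch_high = torch.quantile(x, self.frac_high)
        # Exponential moving average: new = beta * old + (1 - beta) * batch.
        self.low.mul_(self.beta).add_((1 - self.beta) * batch_low)
        self.high.mul_(self.beta).add_((1 - self.beta) * batch_high)
        value_range = torch.clamp(self.high - self.low, min=self.eps)
        return self.low.clone(), value_range
```

A caller could then rescale with `(x - low) / value_range`; using percentiles instead of the true min and max keeps a few outliers from stretching the range.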

class fishyrl.utilities.Ratio(ratio: float = 1.0)

Bases: Module

Module for computing how many gradient update steps to perform, given a target ratio of gradient update steps to environment steps.

__init__(ratio: float = 1.0) → None

Initialize the Ratio.

Parameters:

ratio (float) – The ratio of gradient update steps to environment steps. (Default: 1.0)

load_state_dict(state_dict: dict[str, Any]) → None

Load the state of the module from a dictionary.

Parameters:

state_dict (dict[str, Any]) – The state dictionary to load from.

state_dict() → dict[str, Any]

Return the state of the module as a dictionary.

Returns:

A dictionary containing the state of the module.

Return type:

dict[str, Any]
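The documented API shows only the constructor and the state methods, so the calling convention below (`ratio(step)` returning the number of updates now due) is a hypothetical interface, similar in spirit to the update-ratio helpers found in DreamerV3-style training loops. The sketch carries a fractional remainder between calls, so that over time the number of gradient steps tracks `ratio` times the number of environment steps. It is plain Python rather than an `nn.Module` to stay dependency-free.

```python
from typing import Any, Optional


class Ratio:
    """Sketch: how many gradient updates are due after reaching `step`."""

    def __init__(self, ratio: float = 1.0) -> None:
        if ratio < 0:
            raise ValueError("ratio must be non-negative")
        self._ratio = ratio
        self._prev: Optional[float] = None  # env step already covered by updates

    def __call__(self, step: int) -> int:
        if self._ratio == 0:
            return 0
        if self._prev is None:
            # First call: start the clock and allow one update.
            self._prev = step
            return 1
        repeats = int((step - self._prev) * self._ratio)
        # Advance only by the env steps those updates cover; keep the remainder.
        self._prev += repeats / self._ratio
        return repeats

    def state_dict(self) -> dict[str, Any]:
        return {"ratio": self._ratio, "prev": self._prev}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._ratio = state_dict["ratio"]
        self._prev = state_dict["prev"]


r = Ratio(0.5)
print([r(s) for s in (0, 10, 13, 14)])  # [1, 5, 1, 1]
```

Saving `prev` alongside `ratio` is what lets a resumed run continue the same update schedule instead of restarting it.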

fishyrl.utilities._flatten_dict(base: DotDict | dict, exceptions: list[str] = [], exclusions: list[str] = [], sep: str = '_', _key: str = '', _result: dict = {}) → dict

Recursively flatten a nested DotDict into a single-level DotDict with concatenated keys.

Parameters:
  • base (DotDict | dict) – The nested dictionary to flatten.

  • exceptions (list[str]) – A list of keys to exclude from flattening; see the documentation for the optional_flatten_cfg decorator for details. (Default: [])

  • exclusions (list[str]) – A list of keys to drop entirely, so that they are not passed to the decorated function. (Default: [])

  • sep (str) – The separator to use when concatenating keys. (Default: '_')

  • _key (str) – The current key prefix during recursion. (Default: '')

  • _result (dict) – The current result dictionary during recursion. (Default: {})

Returns:

A flattened dictionary with concatenated keys.

Return type:

dict
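A minimal sketch of the flattening rule, assuming that `exceptions` and `exclusions` are matched against the already-joined key (for example `model_default`). The hypothetical `flatten_dict` builds a fresh dict on every call rather than threading a `_result` accumulator through the recursion, which sidesteps Python's shared-mutable-default behavior.

```python
from typing import Any, Iterable


def flatten_dict(base: dict, exceptions: Iterable[str] = (),
                 exclusions: Iterable[str] = (), sep: str = "_",
                 _key: str = "") -> dict[str, Any]:
    """Sketch: flatten nested dicts into one level with joined keys."""
    exceptions, exclusions = set(exceptions), set(exclusions)
    result: dict[str, Any] = {}
    for key, value in base.items():
        full_key = f"{_key}{sep}{key}" if _key else str(key)
        if full_key in exclusions:
            continue  # excluded keys are dropped entirely
        if isinstance(value, dict) and full_key not in exceptions:
            result.update(flatten_dict(value, exceptions, exclusions, sep, full_key))
        else:
            # Leaves, and subtrees listed in `exceptions`, are kept as-is.
            result[full_key] = value
    return result


cfg = {"model": {"default": {"embedded": 64, "blocks": 2}, "lr": 3e-4}, "seed": 0}
print(flatten_dict(cfg))
# {'model_default_embedded': 64, 'model_default_blocks': 2, 'model_lr': 0.0003, 'seed': 0}
```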

fishyrl.utilities._merge_dotdicts(base: DotDict, new: DotDict, list_behavior: str = 'replace') → None

Merge new into base in place, with priority given to the values in new.

Parameters:
  • base (DotDict) – The base DotDict to merge into.

  • new (DotDict) – The new DotDict to merge from.

  • list_behavior (str) – The behavior for merging lists. Can be either ‘replace’ or ‘merge’. (Default: 'replace')
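A minimal sketch of an in-place recursive merge over plain dicts; the name `merge_dicts` is hypothetical. Nested dictionaries are merged key by key, and any other value from `new` overwrites the one in `base`. How fishyrl's 'merge' list behavior combines two lists is not documented here, so the sketch assumes concatenation (`old + new`).

```python
from typing import Any


def merge_dicts(base: dict[str, Any], new: dict[str, Any],
                list_behavior: str = "replace") -> None:
    """Sketch: merge `new` into `base` in place; `new` wins on conflicts."""
    if list_behavior not in ("replace", "merge"):
        raise ValueError(f"unknown list_behavior: {list_behavior!r}")
    for key, value in new.items():
        old = base.get(key)
        if isinstance(old, dict) and isinstance(value, dict):
            merge_dicts(old, value, list_behavior)  # recurse into sub-configs
        elif list_behavior == "merge" and isinstance(old, list) and isinstance(value, list):
            base[key] = old + value  # assumed meaning of 'merge' for lists
        else:
            base[key] = value  # scalars, new keys, and lists under 'replace'


base = {"algo": {"lr": 1e-3, "layers": [64]}, "seed": 0}
merge_dicts(base, {"algo": {"lr": 3e-4, "layers": [128]}}, list_behavior="merge")
print(base)  # {'algo': {'lr': 0.0003, 'layers': [64, 128]}, 'seed': 0}
```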

fishyrl.utilities.init_weights(m)

fishyrl.utilities.load_config(*paths: list[str], list_behavior: str = 'replace') → DotDict

Load and merge YAML configuration files into a single DotDict, with priority given to earlier files.

Parameters:
  • paths (list[str]) – The paths to the YAML configuration files to load.

  • list_behavior (str) – The behavior for merging lists. Can be either ‘replace’ or ‘merge’. (Default: 'replace')

Returns:

A DotDict containing the merged configuration.

Return type:

DotDict
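A minimal sketch of the documented priority rule, using PyYAML's `yaml.safe_load` and plain dicts in place of DotDict. Because a key-by-key merge lets the value merged last win, processing the files in reverse order gives the earliest path the final say. This ordering trick is one assumed way to achieve the documented priority, not a description of fishyrl's code, and list handling is reduced to plain replacement for brevity.

```python
from typing import Any

import yaml  # PyYAML


def _merge(base: dict[str, Any], new: dict[str, Any]) -> None:
    # Recursive in-place merge; values from `new` win.
    for key, value in new.items():
        if isinstance(base.get(key), dict) and isinstance(value, dict):
            _merge(base[key], value)
        else:
            base[key] = value


def load_config(*paths: str) -> dict[str, Any]:
    """Sketch: merge YAML files so that earlier paths take priority."""
    merged: dict[str, Any] = {}
    # Merge the last file first, so each earlier file overrides what came before.
    for path in reversed(paths):
        with open(path, encoding="utf-8") as f:
            _merge(merged, yaml.safe_load(f) or {})
    return merged
```

With this rule, a call such as `load_config("experiment.yaml", "defaults.yaml")` (hypothetical file names) lets the experiment file override only the keys it sets, while everything else falls back to the defaults.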

fishyrl.utilities.optional_flatten_cfg(func: callable = None, exceptions: list[str] = [], exclusions: list[str] = []) → callable

Decorator to optionally flatten a config DotDict before passing it to the function.

Parameters:
  • func (Callable) – The function to decorate.

  • exceptions (list[str]) – A list of keys to exclude from flattening. For example, the key 'model_default' tells the decorator not to flatten cfg.model.default.*; instead, that subtree is passed to the function as a single DotDict argument named 'model_default', rather than as 'model_default_embedded', 'model_default_blocks', etc. (Default: [])

  • exclusions (list[str]) – A list of keys to omit from the arguments passed to the function. (Default: [])
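A minimal sketch of how such a decorator can be built, under several assumptions: flattening happens only when the function is called with a single dict argument (otherwise the arguments pass through unchanged), keys are joined with `_`, and `exceptions`/`exclusions` are matched against joined keys. In the sketch, the `func=None` default lets the same object be used both bare (`@optional_flatten_cfg`) and with arguments (`@optional_flatten_cfg(exceptions=[...])`). `_flatten` and `build` are hypothetical helpers.

```python
import functools
from typing import Any, Callable, Iterable, Optional


def _flatten(base: dict, exceptions: set, exclusions: set, prefix: str = "") -> dict:
    # Join nested keys with "_", keeping `exceptions` subtrees and dropping `exclusions`.
    out: dict[str, Any] = {}
    for key, value in base.items():
        full_key = f"{prefix}_{key}" if prefix else str(key)
        if full_key in exclusions:
            continue
        if isinstance(value, dict) and full_key not in exceptions:
            out.update(_flatten(value, exceptions, exclusions, full_key))
        else:
            out[full_key] = value
    return out


def optional_flatten_cfg(func: Optional[Callable] = None,
                         exceptions: Iterable[str] = (),
                         exclusions: Iterable[str] = ()) -> Callable:
    """Sketch: optionally expand a config dict into keyword arguments."""
    exceptions, exclusions = set(exceptions), set(exclusions)

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Assumed trigger: exactly one positional dict and no keywords.
            if len(args) == 1 and not kwargs and isinstance(args[0], dict):
                return fn(**_flatten(args[0], exceptions, exclusions))
            return fn(*args, **kwargs)
        return wrapper

    # Bare use passes the function directly; parameterized use returns `decorator`.
    return decorator(func) if func is not None else decorator


@optional_flatten_cfg(exceptions=["model_default"], exclusions=["seed"])
def build(model_default: dict, model_lr: float) -> tuple:
    return model_default, model_lr


cfg = {"model": {"default": {"embedded": 64}, "lr": 3e-4}, "seed": 0}
print(build(cfg))  # ({'embedded': 64}, 0.0003)
```

One consequence of this trigger rule is that a decorated function whose only argument is genuinely a dict would have it expanded; a real implementation may rely on a different signal.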

fishyrl.utilities.uniform_init_weights(given_scale)