fishyrl.utilities module
Utilities for state management and common operations.
- class fishyrl.utilities.CaseInsensitiveEnumMeta(cls, bases, classdict, *, boundary=None, _simple=False, **kwds)
Bases: EnumType
Enum metaclass for case-insensitive member lookup.
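A minimal sketch of how such a metaclass can work, assuming lookup happens by member name through `Enum[...]`. It subclasses `EnumMeta`, which is available on all Python 3 versions and is an alias of `EnumType` from 3.11 on. The `Activation` enum is hypothetical, and the real fishyrl class may behave differently (for example, it may also match values).

```python
from enum import Enum, EnumMeta


class CaseInsensitiveEnumMeta(EnumMeta):
    """Sketch: resolve Enum["name"] lookups regardless of letter case."""

    def __getitem__(cls, name):
        try:
            # Exact match first, using the normal Enum lookup.
            return super().__getitem__(name)
        except KeyError:
            # Fall back to comparing lower-cased member names.
            for member_name, member in cls.__members__.items():
                if member_name.lower() == str(name).lower():
                    return member
            raise  # re-raise the original KeyError if nothing matched


class Activation(Enum, metaclass=CaseInsensitiveEnumMeta):  # hypothetical enum
    RELU = "relu"
    SILU = "silu"


print(Activation["relu"], Activation["SiLU"])
```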
- class fishyrl.utilities.Container(**modules: Any)
Bases: object
Container for holding multiple submodules and utilities, without torch integration.
- __init__(**modules: Any) → None
Initialize the Container.
- Parameters:
modules (dict[str, Any]) – The submodules and utilities to contain.
- load_state_dict(state_dict: dict[str, Any]) → None
Load the state of the container from a dictionary.
- Parameters:
state_dict (dict[str, Any]) – The state dictionary.
- state_dict() → dict[str, Any]
Return the state of the container as a dictionary.
- Returns:
A dictionary containing the state of the container.
- Return type:
dict[str, Any]
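The state methods of Container suggest a simple delegation pattern. A minimal sketch, assuming `state_dict()` collects the state of every contained object that exposes its own `state_dict()`, and `load_state_dict()` hands each entry back to the object of the same name. The `Counter` helper is hypothetical, and this is not the fishyrl source.

```python
from typing import Any


class Container:
    """Sketch: hold arbitrary objects and delegate state saving to them."""

    def __init__(self, **modules: Any) -> None:
        self._names = list(modules)
        for name, module in modules.items():
            setattr(self, name, module)  # expose each object as an attribute

    def state_dict(self) -> dict:
        # Only objects that expose state_dict() contribute to the saved state.
        return {
            name: getattr(self, name).state_dict()
            for name in self._names
            if hasattr(getattr(self, name), "state_dict")
        }

    def load_state_dict(self, state_dict: dict) -> None:
        for name, state in state_dict.items():
            getattr(self, name).load_state_dict(state)


class Counter:  # hypothetical stateful utility
    def __init__(self) -> None:
        self.n = 0

    def state_dict(self) -> dict:
        return {"n": self.n}

    def load_state_dict(self, state: dict) -> None:
        self.n = state["n"]


run = Container(steps=Counter(), name="run-1")
run.steps.n = 5
saved = run.state_dict()  # the plain string "name" has no state and is skipped
restored = Container(steps=Counter(), name="run-1")
restored.load_state_dict(saved)
```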
- class fishyrl.utilities.ContainerModule(**modules: Module)
Bases: torch.nn.Module
Module for containing multiple submodules.
- __init__(**modules: Module) → None
Initialize the ContainerModule.
- Parameters:
modules (dict[str, nn.Module]) – The submodules to contain.
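Because ContainerModule subclasses `torch.nn.Module`, registering its children lets PyTorch track their parameters and `state_dict()` entries, which the plain Container does not do. A minimal sketch, assuming the submodules are registered with `add_module` under their keyword names; the `encoder` and `actor` layers are hypothetical.

```python
import torch
from torch import nn


class ContainerModule(nn.Module):
    """Sketch: an nn.Module that registers keyword submodules by name."""

    def __init__(self, **modules: nn.Module) -> None:
        super().__init__()
        for name, module in modules.items():
            # add_module registers the child, so its parameters and buffers
            # appear in parameters() and state_dict() and follow .to(device).
            self.add_module(name, module)


agent = ContainerModule(encoder=nn.Linear(4, 8), actor=nn.Linear(8, 2))
action = agent.actor(agent.encoder(torch.zeros(1, 4)))
```

With this design, a single call such as `agent.to("cuda")` would move every registered submodule at once.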
- class fishyrl.utilities.DotDict(*args: list, **kwargs: dict[str, Any])
Bases: dict
Allow accessing dictionary keys as attributes.
Attribute access code adapted from https://stackoverflow.com/a/23689767 and SheepRL.
- __init__(*args: list, **kwargs: dict[str, Any]) → None
Initialize the DotDict.
- Parameters:
args (list) – Positional arguments to initialize the dictionary.
kwargs (dict[str, Any]) – Keyword arguments to initialize the dictionary.
- _crawl(lookup: dict | list) → None
Recursively convert nested dictionaries to DotDict.
Crawls through dictionaries and lists.
- Parameters:
lookup (dict | list) – The dictionary or list to crawl through.
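A minimal sketch of the two behaviours described for DotDict: attribute-style access, and a `_crawl` pass that converts nested dictionaries (including those inside lists) into DotDict instances. It is written from the descriptions on this page, not copied from fishyrl; raising `AttributeError` for missing keys, so that `hasattr` behaves normally, is a design choice of the sketch.

```python
from typing import Any


class DotDict(dict):
    """Sketch: a dict whose keys can also be read and written as attributes."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._crawl(self)

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def _crawl(self, lookup: Any) -> None:
        # Walk dict values or list items and convert plain dicts in place.
        items = lookup.items() if isinstance(lookup, dict) else enumerate(lookup)
        for key, value in list(items):
            if isinstance(value, dict) and not isinstance(value, DotDict):
                lookup[key] = DotDict(value)  # the new DotDict crawls itself
            elif isinstance(value, list):
                self._crawl(value)


cfg = DotDict({"model": {"layers": [{"units": 64}]}, "lr": 1e-3})
cfg.seed = 0
```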
- class fishyrl.utilities.MovingMinMaxScaler(beta: float = 0.99, frac_low: float = 0.05, frac_high: float = 0.95, eps: float = 1e-08)
Bases: torch.nn.Module
Moving percentile-based min-max scaler for normalizing inputs.
- __init__(beta: float = 0.99, frac_low: float = 0.05, frac_high: float = 0.95, eps: float = 1e-08) → None
Initialize the MovingMinMaxScaler.
- Parameters:
beta (float) – The decay rate for the moving min and max. (Default: 0.99)
eps (float) – Minimal value for the computed high-low range. (Default: 1e-8)
frac_low (float) – The lower percentile for scaling, given as a fraction (0.05 is the 5th percentile). (Default: 0.05)
frac_high (float) – The upper percentile for scaling, given as a fraction (0.95 is the 95th percentile). (Default: 0.95)
- forward(x: Tensor) → tuple[Tensor, Tensor]
Update and return the low and range estimates.
- Parameters:
x (torch.Tensor) – The input tensor to use while updating the estimates.
- Returns:
A tuple containing the low estimate and the range estimate.
- Return type:
tuple[torch.Tensor, torch.Tensor]
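A minimal sketch of a moving percentile-based scaler, assuming the estimates are exponential moving averages (decay `beta`) of the `frac_low` and `frac_high` quantiles of each batch, with the range clamped below by `eps`. This mirrors the return normalization used in DreamerV3-style agents. The buffer names `low` and `high` and their zero initialization are assumptions, not taken from fishyrl.

```python
import torch
from torch import nn


class MovingMinMaxScaler(nn.Module):
    """Sketch: EMA of low/high batch quantiles, returned as (low, range)."""

    def __init__(self, beta: float = 0.99, frac_low: float = 0.05,
                 frac_high: float = 0.95, eps: float = 1e-8) -> None:
        super().__init__()
        self.beta, self.eps = beta, eps
        self.frac_low, self.frac_high = frac_low, frac_high
        # Buffers are saved in state_dict() and follow .to(device).
        self.register_buffer("low", torch.zeros(()))
        self.register_buffer("high", torch.zeros(()))

    def forward(self, x: torch.Tensor) -> tuple:
        x = x.detach().float().flatten()
        low = torch.quantile(x, self.frac_low)
        high = torch.quantile(x, self.frac_high)
        # Exponential moving averages of the two quantiles.
        self.low = self.beta * self.low + (1 - self.beta) * low
        self.high = self.beta * self.high + (1 - self.beta) * high
        value_range = torch.clamp(self.high - self.low, min=self.eps)
        return self.low, value_range


scaler = MovingMinMaxScaler()
batch = torch.arange(101.0)
low, value_range = scaler(batch)
normalized = (batch - low) / value_range
```

Under these assumptions, the returned pair would be used as `(x - low) / value_range`. Starting the buffers at zero biases the first few estimates toward zero until the averages warm up.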
- class fishyrl.utilities.Ratio(ratio: float = 1.0)
Bases: torch.nn.Module
Module for computing the number of gradient update steps.
- __init__(ratio: float = 1.0) → None
Initialize the Ratio.
- Parameters:
ratio (float) – The ratio of gradient update steps to environment steps. (Default: 1.0)
- load_state_dict(state_dict: dict[str, Any]) → None
Load the state of the module from a dictionary.
- Parameters:
state_dict (dict[str, Any]) – The state dictionary to load from.
- state_dict() → dict[str, Any]
Return the state of the module as a dictionary.
- Returns:
A dictionary containing the state of the module.
- Return type:
dict[str, Any]
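A minimal sketch of ratio bookkeeping, assuming the object is called with the cumulative environment step count and returns how many gradient updates to run now. The fractional remainder carries forward, so the long-run average matches `ratio`. The call interface, the single update on the first call, and the state keys are all assumptions. The sketch also drops the `torch.nn.Module` base, because this logic needs no tensors.

```python
from typing import Any, Dict, Optional


class Ratio:
    """Sketch: turn environment steps into a number of gradient updates."""

    def __init__(self, ratio: float = 1.0) -> None:
        self.ratio = ratio
        self._prev: Optional[float] = None  # env step already "paid for"

    def __call__(self, step: int) -> int:
        if self.ratio == 0:
            return 0
        if self._prev is None:
            # Assumption: the first call anchors the counter and runs one update.
            self._prev = step
            return 1
        # Python's round() sends exact halves to the nearest even integer.
        repeats = round((step - self._prev) * self.ratio)
        # Advance only by the env steps these updates account for, so the
        # fractional remainder carries over to the next call.
        self._prev += repeats / self.ratio
        return repeats

    def state_dict(self) -> Dict[str, Any]:
        return {"ratio": self.ratio, "prev": self._prev}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.ratio = state_dict["ratio"]
        self._prev = state_dict["prev"]


ratio = Ratio(0.5)
updates = [ratio(step) for step in (0, 4, 5, 6)]
```

Saving `prev` alongside `ratio` lets a resumed run continue the schedule without granting extra or missing updates.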
- fishyrl.utilities._flatten_dict(base: DotDict | dict, exceptions: list[str] = [], exclusions: list[str] = [], sep: str = '_', _key: str = '', _result: dict = {}) → dict
Recursively flatten a nested DotDict into a single-level DotDict with concatenated keys.
- Parameters:
base (DotDict) – The base DotDict to flatten.
exceptions (list[str]) – A list of keys to exclude from flattening. (Default: []) See the documentation for the optional_flatten_cfg decorator for more details.
exclusions (list[str]) – A list of keys to exclude from being passed to the function. (Default: [])
sep (str) – The separator to use when concatenating keys. (Default: '_')
_key (str) – The current key prefix during recursion. (Default: '')
_result (dict) – The current result dictionary during recursion. (Default: {})
- Returns:
A flattened dictionary with concatenated keys.
- Return type:
dict
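A minimal sketch of these flattening rules, assuming `exceptions` holds flattened key prefixes that stay nested and `exclusions` holds flattened keys that are dropped. The name `flatten_dict` is hypothetical. The sketch uses `None` instead of a mutable `{}` default for `_result`, so separate top-level calls never share one result dictionary.

```python
from typing import Optional, Sequence


def flatten_dict(base: dict, exceptions: Sequence[str] = (),
                 exclusions: Sequence[str] = (), sep: str = "_",
                 _key: str = "", _result: Optional[dict] = None) -> dict:
    """Sketch: flatten nested dicts into one level with joined keys."""
    result = {} if _result is None else _result
    for key, value in base.items():
        full_key = f"{_key}{sep}{key}" if _key else key
        if full_key in exclusions:
            continue  # excluded keys are dropped entirely
        if isinstance(value, dict) and full_key not in exceptions:
            flatten_dict(value, exceptions, exclusions, sep, full_key, result)
        else:
            result[full_key] = value  # a leaf, or an exception kept nested
    return result


cfg = {"model": {"default": {"embedded": 8, "blocks": 2}, "lr": 0.1}, "seed": 0}
flat = flatten_dict(cfg)
kept = flatten_dict(cfg, exceptions=["model_default"])
```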
- fishyrl.utilities._merge_dotdicts(base: DotDict, new: DotDict, list_behavior: str = 'replace') → None
Merge two DotDict objects, with priority given to the new one.
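A minimal sketch of such a priority merge, assuming nested dictionaries are merged recursively, `'replace'` overwrites lists, and `'merge'` concatenates them. The exact list semantics are an assumption. The sketch works on plain dicts and mutates `base` in place, matching the `None` return type; `merge_dicts` is a hypothetical name.

```python
def merge_dicts(base: dict, new: dict, list_behavior: str = "replace") -> None:
    """Sketch: merge `new` into `base` in place; values from `new` win."""
    if list_behavior not in ("replace", "merge"):
        raise ValueError(f"unknown list_behavior: {list_behavior!r}")
    for key, value in new.items():
        old = base.get(key)
        if isinstance(old, dict) and isinstance(value, dict):
            merge_dicts(old, value, list_behavior)  # recurse into nested sections
        elif (list_behavior == "merge" and isinstance(old, list)
              and isinstance(value, list)):
            base[key] = old + value  # assumption: "merge" concatenates lists
        else:
            base[key] = value


cfg = {"opt": {"lr": 0.1, "betas": [0.9]}, "tags": ["base"]}
merge_dicts(cfg, {"opt": {"lr": 0.01}, "tags": ["exp"]})
```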
- fishyrl.utilities.init_weights(m)
- fishyrl.utilities.load_config(*paths: list[str], list_behavior: str = 'replace') → DotDict
Load and merge YAML configuration files into a single DotDict, with priority given to earlier files.
- Parameters:
paths (list[str]) – The paths to the YAML configuration files to load.
list_behavior (str) – The behavior for merging lists. Can be either 'replace' or 'merge'. (Default: 'replace')
- Returns:
A DotDict containing the merged configuration.
- Return type:
DotDict
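One way to give priority to earlier files is to merge in reverse order, so each earlier file overrides the ones after it. A minimal sketch under that assumption, with replace-only list handling and plain dicts instead of DotDict. The names `merge_into`, `merge_configs`, and `load_config_sketch` are hypothetical. PyYAML's `yaml.safe_load` is imported lazily, so the merge logic can be tried without PyYAML installed.

```python
import copy


def merge_into(base: dict, new: dict) -> None:
    """Recursively merge `new` into `base`; values from `new` win, lists are replaced."""
    for key, value in new.items():
        if isinstance(base.get(key), dict) and isinstance(value, dict):
            merge_into(base[key], value)
        else:
            base[key] = copy.deepcopy(value)  # copy so inputs are never mutated


def merge_configs(*configs: dict) -> dict:
    """Merge parsed configs so that EARLIER configs take priority."""
    result: dict = {}
    for cfg in reversed(configs):  # apply the last config first ...
        merge_into(result, cfg)    # ... so each earlier one overrides it
    return result


def load_config_sketch(*paths: str) -> dict:
    import yaml  # PyYAML, imported lazily so merge_configs works without it

    configs = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            configs.append(yaml.safe_load(f) or {})  # an empty file becomes {}
    return merge_configs(*configs)


merged = merge_configs({"lr": 0.01, "model": {"units": 64}},
                       {"lr": 0.1, "model": {"units": 32, "depth": 2}})
```

Under this ordering, an experiment file listed first overrides shared defaults listed after it, while keys it does not mention (here `depth`) fall through from the later files.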
- fishyrl.utilities.optional_flatten_cfg(func: callable = None, exceptions: list[str] = [], exclusions: list[str] = []) → callable
Decorator to optionally flatten a config DotDict before passing it to the function.
- Parameters:
func (Callable) – The function to decorate.
exceptions (list[str]) – A list of keys to exclude from flattening. (Default: []) For example, the key 'model_default' tells the decorator not to flatten cfg.model.default.* and instead to pass it to the function as a DotDict through the argument 'model_default', where it would otherwise have passed 'model_default_embedded', 'model_default_blocks', etc.
exclusions (list[str]) – A list of keys to exclude from being passed to the function. (Default: [])
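A minimal sketch of such a decorator, assuming that a config passed as the only positional argument is flattened and spread into keyword arguments, while any other call passes through unchanged. Accepting `func=None` lets it be used both bare and with arguments. The `_flatten` helper and the `build` and `first` functions are hypothetical.

```python
import functools
from typing import Callable, Optional, Sequence


def _flatten(d: dict, exceptions: Sequence[str], exclusions: Sequence[str],
             sep: str = "_", prefix: str = "") -> dict:
    out = {}
    for key, value in d.items():
        full = f"{prefix}{sep}{key}" if prefix else key
        if full in exclusions:
            continue  # never passed to the function
        if isinstance(value, dict) and full not in exceptions:
            out.update(_flatten(value, exceptions, exclusions, sep, full))
        else:
            out[full] = value  # leaf value, or an exception left as a nested dict
    return out


def optional_flatten_cfg(func: Optional[Callable] = None,
                         exceptions: Sequence[str] = (),
                         exclusions: Sequence[str] = ()) -> Callable:
    def decorator(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Assumption: a lone dict positional argument is the config to flatten.
            if len(args) == 1 and not kwargs and isinstance(args[0], dict):
                return f(**_flatten(args[0], exceptions, exclusions))
            return f(*args, **kwargs)  # any other call passes through unchanged
        return wrapper

    # Works both as @optional_flatten_cfg and as @optional_flatten_cfg(...).
    return decorator(func) if func is not None else decorator


@optional_flatten_cfg(exceptions=["model_default"], exclusions=["seed"])
def build(model_default: dict, model_lr: float):  # hypothetical consumer
    return model_default["blocks"], model_lr


@optional_flatten_cfg
def first(a_b: int) -> int:  # hypothetical bare-decorator use
    return a_b


cfg = {"model": {"default": {"embedded": 8, "blocks": 2}, "lr": 0.1}, "seed": 0}
```

One consequence of this assumption is that a function whose only argument is an ordinary dict would have that dict flattened too. The real decorator may detect configs differently.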
- fishyrl.utilities.uniform_init_weights(given_scale)