# Source code for fsrl.utils.exp_util

# flake8: noqa

import os
import os.path as osp
import random
import uuid
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
import yaml

from fsrl.utils.logger.logger_util import colorize


def seed_all(seed=1029, others: Optional[list] = None) -> None:
    """Fix the seeds of `random`, `numpy`, `torch` and the input `others` object.

    :param int seed: the seed value, defaults to 1029.
    :param Optional[list] others: either a single object exposing a ``seed(int)``
        method, or an iterable of objects. Every element that exposes ``seed``
        is seeded and the others are ignored. A value that is neither seedable
        nor iterable is ignored. Defaults to None.
    :return: None.
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # child processes started after this call, not the current interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # torch.use_deterministic_algorithms(True)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if others is None:
        return
    if hasattr(others, "seed"):
        others.seed(seed)
        return
    try:
        items = iter(others)
    except TypeError:
        # Neither seedable nor iterable: nothing to seed.
        return
    # Exceptions raised by an item's own ``seed`` method are intentionally not
    # suppressed, so misconfigured objects surface instead of silently
    # running unseeded.
    for item in items:
        if hasattr(item, "seed"):
            item.seed(seed)
def get_cfg_value(config, key):
    """Look up ``key`` in a (possibly nested) config dict and return it as a string.

    The top level is checked first. Otherwise, nested ``dict`` values are searched
    depth-first in iteration order, and the first match wins. List values are
    rendered by concatenating ``str`` of each element without a separator
    (``[1, 2]`` -> ``"12"``). Any other value is passed through ``str``.

    :param dict config: the configuration mapping to search.
    :param key: the key to look for.
    :return str: the stringified value, or the string ``"None"`` if ``key`` is
        not found anywhere in ``config``.
    """

    def _search(cfg):
        # Return the rendered value, or the object None when ``key`` is absent,
        # so a miss in one nested dict does not stop the search of its siblings.
        if key in cfg:
            value = cfg[key]
            if isinstance(value, list):
                return "".join(str(v) for v in value)
            return str(value)
        for sub_cfg in cfg.values():
            if isinstance(sub_cfg, dict):
                found = _search(sub_cfg)
                if found is not None:
                    return found
        return None

    result = _search(config)
    return "None" if result is None else result
def load_config_and_model(path: str, best: bool = False, map_location=None):
    """Load the configuration and trained model from a specified directory.

    Reads ``<path>/config.yaml`` and ``<path>/checkpoint/model.pt``, or
    ``<path>/checkpoint/model_best.pt`` when ``best`` is True.

    :param path: the directory path where the configuration and trained model are stored.
    :param best: whether to load the best-performing model (``model_best.pt``) or
        the most recent one (``model.pt``). Defaults to False.
    :param map_location: forwarded to ``torch.load``. For example, pass ``"cpu"`` to
        load a checkpoint saved on GPU on a CPU-only machine. Defaults to None,
        which is ``torch.load``'s own default.
    :return: a tuple ``(config, model)`` holding the parsed configuration
        dictionary and the object stored in the checkpoint file.
    :raises ValueError: if the specified directory does not exist.
    """
    if not osp.exists(path):
        raise ValueError(f"{path} doesn't exist!")

    config_file = osp.join(path, "config.yaml")
    print(f"load config from {config_file}")
    # Explicit encoding: the locale default (e.g. cp1252 on Windows) can misdecode
    # UTF-8 config files.
    with open(config_file, encoding="utf-8") as f:
        config = yaml.load(f.read(), Loader=yaml.FullLoader)

    model_file = "model_best.pt" if best else "model.pt"
    model_path = osp.join(path, "checkpoint", model_file)
    print(f"load model from {model_path}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load(model_path, map_location=map_location)
    return config, model
# naming utils
def to_string(values):
    """Recursively flatten ``values`` into an underscore-separated string.

    Non-string sequences are rendered element by element. Dicts are rendered
    value by value in sorted-key order, and the keys themselves are omitted.
    Anything else falls back to ``str``. The rendered parts are joined with ``"_"``.

    :param values: a scalar, a (possibly nested) sequence, or a dictionary.
    :return: the string representation, e.g. ``[1, {"b": 2, "a": 3}]`` -> ``"1_3_2"``.
    """
    if isinstance(values, Sequence) and not isinstance(values, str):
        parts = [to_string(item) for item in values]
    elif isinstance(values, Dict):
        parts = [to_string(values[key]) for key in sorted(values.keys())]
    else:
        return str(values)
    return "_".join(parts)
# Config keys that `auto_name` ignores by default when building an experiment name.
DEFAULT_SKIP_KEY = [
    "task",
    "reward_threshold",
    "logdir",
    "worker",
    "project",
    "group",
    "name",
    "prefix",
    "suffix",
    "save_interval",
    "render",
    "verbose",
    "save_ckpt",
    "training_num",
    "testing_num",
    "epoch",
    "device",
    "thread",
]

# Default short aliases that `auto_name` substitutes for config keys in names.
DEFAULT_KEY_ABBRE = {
    "cost_limit": "cost",
    "mstep_iter_num": "mnum",
    "estep_iter_num": "enum",
    "estep_kl": "ekl",
    "mstep_kl_mu": "kl_mu",
    "mstep_kl_std": "kl_std",
    "mstep_dual_lr": "mlr",
    "estep_dual_lr": "elr",
    "update_per_step": "update",
}
def auto_name(
    default_cfg: dict,
    current_cfg: dict,
    prefix: str = "",
    suffix: str = "",
    skip_keys: list = DEFAULT_SKIP_KEY,
    key_abbre: dict = DEFAULT_KEY_ABBRE
) -> str:
    """Automatically generate a name by comparing the current config with the default one.

    Keys of ``default_cfg`` are visited in sorted order. Each key that is not in
    ``skip_keys`` and whose value differs in ``current_cfg`` contributes
    ``<key><value>``. Here ``<key>`` is replaced by its abbreviation from
    ``key_abbre`` when one exists, and ``<value>`` is rendered with ``to_string``.
    The parts are joined with ``"_"`` after ``prefix``, and ``suffix`` is appended.
    Finally, ``-`` plus the first 4 characters of a random UUID4 is added,
    e.g. ``"cost20_update2-1a2b"``.

    :param dict default_cfg: a dictionary containing the default configuration values.
    :param dict current_cfg: a dictionary containing the current configuration values.
        It must contain every key of ``default_cfg`` that is not skipped.
    :param str prefix: (optional) a string to be added at the beginning of the generated name.
    :param str suffix: (optional) a string to be added at the end of the generated name.
    :param list skip_keys: (optional) keys that never contribute to the name. They do
        not need to be present in ``current_cfg``.
    :param dict key_abbre: (optional) a dictionary mapping keys to their abbreviations
        in the generated name.
    :return str: the generated experiment name. ``"default"`` is used as the base
        when nothing differs and both ``prefix`` and ``suffix`` are empty.
    :raises KeyError: if a non-skipped key of ``default_cfg`` is missing from ``current_cfg``.
    """
    name = prefix
    for key in sorted(default_cfg.keys()):
        # Check skip_keys before touching current_cfg, so a skipped key
        # missing from current_cfg does not raise KeyError.
        if key in skip_keys:
            continue
        value = current_cfg[key]
        if default_cfg[key] == value:
            continue
        sep = "_" if name else ""
        # Replace the key with its abbreviation if it has one in key_abbre.
        label = key_abbre.get(key, key)
        name += sep + label + to_string(value)

    if suffix:
        name = f"{name}_{suffix}" if name else suffix
    if not name:
        name = "default"
    # Random tag keeps names distinct across runs with identical configs.
    return f"{name}-{str(uuid.uuid4())[:4]}"