fsrl.trainer

class fsrl.trainer.BaseTrainer(learning_type: str, policy: ~fsrl.policy.base_policy.BasePolicy, train_collector: ~fsrl.data.fast_collector.FastCollector, test_collector: ~fsrl.data.fast_collector.FastCollector | None = None, max_epoch: int = 100, batch_size: int = 512, cost_limit: float = inf, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, update_per_step: int | float = 1, save_model_interval: int = 1, episode_per_test: int | None = None, episode_per_collect: int = 1, stop_fn: ~typing.Callable[[float, float], bool] | None = None, resume_from_log: bool = False, logger: ~fsrl.utils.logger.base_logger.BaseLogger = <fsrl.utils.logger.base_logger.BaseLogger object>, verbose: bool = True, show_progress: bool = True)

Bases: ABC

An iterator base class for the trainer procedure.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of training results for every epoch. Its usage is almost identical to that of Tianshou’s trainer, with some modifications to the parameters.

Parameters:
  • learning_type (str) – type of learning iterator; available choices are “offpolicy” and “onpolicy”. “offline” is not supported yet.

  • policy – an instance of the BasePolicy class.

  • train_collector – the collector used for training.

  • test_collector – the collector used for testing. If it’s None, then no testing will be performed.

  • max_epoch (int) – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • batch_size (int) – the batch size of the sampled data that is fed into the policy network.

  • cost_limit (float) – the constraint violation threshold.

  • step_per_epoch (int) – the number of transitions collected per epoch.

  • repeat_per_collect (int) – the number of times the policy learns from each collected batch (on-policy methods); for example, setting it to 2 means the policy learns from each given batch of data twice.

  • update_per_step (float) – the number of gradient steps per env_step (off-policy methods).

  • save_model_interval (int) – the number of epochs between two saved checkpoints.

  • episode_per_test (int) – the number of episodes for one policy evaluation.

  • episode_per_collect (int) – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects “episode_per_collect” episodes and then updates the policy network.

  • stop_fn (function) – a function with signature f(reward, cost) -> bool that receives the average undiscounted reward and cost of the test episodes and returns a boolean indicating whether the goal has been reached.

  • resume_from_log (bool) – resume env_step and other metadata from an existing tensorboard log. Defaults to False.

  • logger (BaseLogger) – a logger that records statistics during training/testing/updating. Defaults to a logger that does not log anything.

  • verbose (bool) – whether to print tabular information. Defaults to True.

  • show_progress (bool) – whether to display a progress bar during training. Defaults to True.
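As an illustration of the f(reward, cost) -> bool signature, the sketch below builds a stop_fn that ends training once the average test reward reaches a target while the average test cost stays within a budget. The names make_stop_fn, target_reward, and cost_budget are hypothetical; only the shape (two floats in, one boolean out) comes from the signature above.

```python
from typing import Callable


def make_stop_fn(target_reward: float, cost_budget: float) -> Callable[[float, float], bool]:
    """Build a hypothetical stop_fn matching the f(reward, cost) -> bool signature."""

    def stop_fn(reward: float, cost: float) -> bool:
        # Stop only when the policy is both good enough and safe enough.
        return reward >= target_reward and cost <= cost_budget

    return stop_fn


stop_fn = make_stop_fn(target_reward=200.0, cost_budget=25.0)
print(stop_fn(210.0, 10.0))  # True: reward target met, cost within budget
print(stop_fn(210.0, 40.0))  # False: cost budget exceeded
```

Such a function could then be passed as stop_fn=stop_fn when constructing a trainer; choosing cost_budget equal to the trainer's cost_limit is one natural option, not a requirement.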

static gen_doc(learning_type: str) → str

Generate the docstring for a subclass trainer.

reset() → None

Initialize or reset the instance to yield a new iterator from zero.

perf_is_better(test: bool = True) → bool

test_step() → Dict[str, Any]

Perform one testing step.

train_step() → Dict[str, Any]

Perform one training step.

abstract policy_update_fn(result: Dict[str, Any]) → None

Policy update function; each trainer subclass provides its own implementation.

Parameters:

result – collector’s return value.

run() → Dict[str, float | str]

Consume the iterator.

See the “consume” recipe in the itertools documentation: the iterator is consumed at C speed by feeding it into a zero-length deque.
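The itertools “consume” idea boils down to passing the iterator into a collections.deque with maxlen=0, which advances the iterator to the end while keeping none of its items. The sketch below applies it to a toy generator standing in for a trainer; toy_epochs and its dictionaries are hypothetical, and how run() builds its return value afterwards is not shown here.

```python
import collections
from typing import Any, Dict, Iterable, Iterator, List, Tuple


def consume(iterable: Iterable[Any]) -> None:
    """Exhaust an iterable without storing its items (itertools 'consume' recipe)."""
    # A deque with maxlen=0 keeps nothing, so each item is dropped as soon as it is produced.
    collections.deque(iterable, maxlen=0)


seen: List[int] = []


def toy_epochs(n: int) -> Iterator[Tuple[int, Dict[str, float], Dict[str, Any]]]:
    """Hypothetical stand-in for a trainer: yields (epoch, stats, info) tuples."""
    for epoch in range(1, n + 1):
        seen.append(epoch)  # side effect happens only when the iterator is advanced
        yield epoch, {"reward": float(epoch)}, {}


consume(toy_epochs(3))
print(seen)  # [1, 2, 3]
```

All the work (here, the appends to seen; in a trainer, collecting, updating, and testing) happens as a side effect of advancing the iterator, which is why consuming it is enough to run the whole training loop.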

gather_update_info() → Dict[str, Any]

A simple wrapper that gathers information from the collectors.

Returns:

A dictionary with the following keys:

  • train_collector_time – the time (s) spent collecting transitions with the training collector;

  • train_model_time – the time (s) spent training models;

  • train_speed – the training speed (env steps per second);

  • test_time – the time (s) spent testing;

  • test_speed – the testing speed (env steps per second);

  • duration – the total elapsed time (s).

class fsrl.trainer.OnpolicyTrainer(policy: ~fsrl.policy.base_policy.BasePolicy, train_collector: ~fsrl.data.fast_collector.FastCollector, test_collector: ~fsrl.data.fast_collector.FastCollector | None = None, max_epoch: int = 10000, batch_size: int = 512, cost_limit: float = inf, step_per_epoch: int = 10000, repeat_per_collect: int = 4, episode_per_collect: int = 10, save_model_interval: int = 1, episode_per_test: int | None = None, stop_fn: ~typing.Callable[[float, float], bool] | None = None, resume_from_log: bool = False, logger: ~fsrl.utils.logger.base_logger.BaseLogger = <fsrl.utils.logger.base_logger.BaseLogger object>, verbose: bool = True, show_progress: bool = True)

Bases: BaseTrainer

An iterator class for the on-policy trainer procedure.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of training results for every epoch.

In the on-policy trainer, a “step” means an environment step (a.k.a. transition).

Example usage:

trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)

  • epoch (int) – the epoch number

  • epoch_stat (dict) – a large collection of metrics of the current epoch

  • info (dict) – the result returned from gather_update_info()

You can even iterate on several trainers at the same time:

trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
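Because a trainer is an ordinary Python iterator, the pattern above can be tried without FSRL. The sketch below uses a hypothetical ToyTrainer whose __next__ returns an (epoch, stats, info) tuple and raises StopIteration after max_epoch; it only mimics the iterator protocol and is not the FSRL implementation.

```python
from typing import Any, Dict, Iterator, Tuple

EpochResult = Tuple[int, Dict[str, float], Dict[str, Any]]


class ToyTrainer:
    """Hypothetical trainer that follows the same iterator protocol as the FSRL trainers."""

    def __init__(self, max_epoch: int, reward_per_epoch: float) -> None:
        self.max_epoch = max_epoch
        self.reward_per_epoch = reward_per_epoch
        self.epoch = 0

    def __iter__(self) -> Iterator[EpochResult]:
        self.epoch = 0  # restart from zero, loosely mirroring reset()
        return self

    def __next__(self) -> EpochResult:
        if self.epoch >= self.max_epoch:
            raise StopIteration
        self.epoch += 1
        stats = {"reward": self.epoch * self.reward_per_epoch}
        info = {"duration": 0.0}
        return self.epoch, stats, info


# Iterate over a single trainer.
for epoch, epoch_stat, info in ToyTrainer(max_epoch=2, reward_per_epoch=1.0):
    print("Epoch:", epoch, epoch_stat)

# Iterate over two trainers in lockstep; zip stops at the shorter one.
pairs = list(zip(ToyTrainer(3, 1.0), ToyTrainer(5, 2.0)))
print(len(pairs))  # 3
```

Note that zip stops as soon as the shortest iterator is exhausted, so when comparing trainers in lockstep, those with a larger max_epoch are cut off early.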
Parameters:
  • policy – an instance of the BasePolicy class.

  • train_collector (Collector) – the collector used for training.

  • test_collector (Collector) – the collector used for testing. If it’s None, then no testing will be performed.

  • max_epoch (int) – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • cost_limit (float) – the constraint violation threshold.

  • step_per_epoch (int) – the number of transitions collected per epoch.

  • repeat_per_collect (int) – the number of times the policy learns from each collected batch; for example, setting it to 2 means the policy learns from each given batch of data twice.

  • episode_per_test (int) – the number of episodes for one policy evaluation.

  • save_model_interval (int) – the number of epochs between two saved checkpoints.

  • batch_size (int) – the batch size of the sampled data that is fed into the policy network.

  • step_per_collect (int) – not supported by this trainer (it is not a constructor argument); use episode_per_collect instead.

  • episode_per_collect (int) – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects “episode_per_collect” episodes and then updates the policy network.

  • resume_from_log (bool) – resume env_step/gradient_step and other metadata from an existing tensorboard log. Defaults to False.

  • stop_fn (function) – a function with signature f(reward, cost) -> bool that receives the average undiscounted reward and cost of the test episodes and returns a boolean indicating whether the goal has been reached.

  • logger (BaseLogger) – a logger that records statistics during training/testing/updating. Defaults to a logger that does not log anything.

  • verbose (bool) – whether to print the information. Defaults to True.

  • show_progress (bool) – whether to display a progress bar during training. Defaults to True.
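To make repeat_per_collect and batch_size concrete, the sketch below counts the gradient steps in one on-policy update round under a Tianshou-style assumption: on each of the repeat_per_collect passes, the freshly collected data is reshuffled and split into minibatches of at most batch_size. The helpers minibatch_indices and count_updates are hypothetical and not part of FSRL; the exact splitting used by each FSRL policy may differ.

```python
import random
from typing import Iterator, List


def minibatch_indices(n: int, batch_size: int, rng: random.Random) -> Iterator[List[int]]:
    """Shuffle indices 0..n-1 and yield them in chunks of at most batch_size."""
    idx = list(range(n))
    rng.shuffle(idx)
    for start in range(0, n, batch_size):
        yield idx[start:start + batch_size]


def count_updates(n_transitions: int, batch_size: int, repeat_per_collect: int, seed: int = 0) -> int:
    """Count gradient steps when every collected transition is reused repeat_per_collect times."""
    rng = random.Random(seed)
    updates = 0
    for _ in range(repeat_per_collect):
        for _batch in minibatch_indices(n_transitions, batch_size, rng):
            updates += 1  # one policy gradient step per minibatch (assumed)
    return updates


# 2000 transitions with batch_size=512 -> 4 minibatches per pass (the last holds 464).
print(count_updates(2000, batch_size=512, repeat_per_collect=4))  # 16
```

Under this assumption, each transition contributes to exactly repeat_per_collect gradient steps, which matches the “learn each given batch of data twice” reading when repeat_per_collect is 2.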

Note

Unlike Tianshou, this trainer does not support the step_per_collect option; data collection is controlled only by episode_per_collect.

policy_update_fn(stats_train: Dict[str, Any]) → None

Perform one on-policy update.

class fsrl.trainer.OffpolicyTrainer(policy: ~fsrl.policy.base_policy.BasePolicy, train_collector: ~fsrl.data.fast_collector.FastCollector, test_collector: ~fsrl.data.fast_collector.FastCollector | None = None, max_epoch: int = 1000, batch_size: int = 512, cost_limit: float = inf, step_per_epoch: int = 10000, update_per_step: float = 0.1, episode_per_collect: int = 1, save_model_interval: int = 1, episode_per_test: int | None = None, stop_fn: ~typing.Callable[[float, float], bool] | None = None, resume_from_log: bool = False, logger: ~fsrl.utils.logger.base_logger.BaseLogger = <fsrl.utils.logger.base_logger.BaseLogger object>, verbose: bool = True, show_progress: bool = True)

Bases: BaseTrainer

An iterator class for the off-policy trainer procedure.

Returns an iterator that yields a 3-tuple (epoch, stats, info) of training results for every epoch.

In the off-policy trainer, a “step” means an environment step (a.k.a. transition).

Example usage:

trainer = OffpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print("Epoch:", epoch)
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)

  • epoch (int) – the epoch number

  • epoch_stat (dict) – a large collection of metrics of the current epoch

  • info (dict) – the result returned from gather_update_info()

You can even iterate on several trainers at the same time:

trainer1 = OffpolicyTrainer(...)
trainer2 = OffpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
Parameters:
  • policy – an instance of the BasePolicy class.

  • train_collector (Collector) – the collector used for training.

  • test_collector (Collector) – the collector used for testing. If it’s None, then no testing will be performed.

  • max_epoch (int) – the maximum number of epochs for training. The training process might be finished before reaching max_epoch if stop_fn is set.

  • batch_size (int) – the batch size of the sampled data that is fed into the policy network.

  • cost_limit (float) – the constraint violation threshold.

  • step_per_epoch (int) – the number of transitions collected per epoch.

  • update_per_step (float) – the number of policy network updates per collected transition. After each collect, the policy is updated round(update_per_step × n) times, where n is the number of transitions just collected; e.g., if update_per_step is 0.3 and 256 transitions are collected, the policy is updated round(256 × 0.3) = round(76.8) = 77 times.

  • episode_per_collect (int) – the number of episodes the collector collects before each network update, i.e., in each epoch the trainer repeatedly collects “episode_per_collect” episodes and then updates the policy network.

  • save_model_interval (int) – the number of epochs between two saved checkpoints.

  • episode_per_test (int) – the number of episodes for one policy evaluation.

  • stop_fn (function) – a function with signature f(reward, cost) -> bool that receives the average undiscounted reward and cost of the test episodes and returns a boolean indicating whether the goal has been reached.

  • resume_from_log (bool) – resume env_step/gradient_step and other metadata from an existing tensorboard log. Defaults to False.

  • logger (BaseLogger) – a logger that records statistics during training/testing/updating. Defaults to a logger that does not log anything.

  • verbose (bool) – whether to print the information. Defaults to True.

  • show_progress (bool) – whether to display a progress bar during training. Defaults to True.
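The update_per_step arithmetic can be checked directly. The helper num_updates below is hypothetical; it simply applies round(update_per_step × transitions collected) as described for update_per_step, using Python's built-in round (which rounds exact .5 ties to the nearest even integer).

```python
def num_updates(update_per_step: float, n_collected: int) -> int:
    """Hypothetical count of off-policy gradient steps after collecting n_collected transitions."""
    # round() returns an int when called with a single argument.
    return round(update_per_step * n_collected)


print(num_updates(0.3, 256))   # 77, since 0.3 * 256 = 76.8
print(num_updates(0.1, 1000))  # 100, using the default update_per_step of 0.1
```

A value below 1 therefore means fewer gradient steps than collected transitions, while a value above 1 reuses the replay buffer more aggressively per environment step.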

Note

Unlike Tianshou, this trainer does not support the step_per_collect option; data collection is controlled only by episode_per_collect.

policy_update_fn(stats_train: Dict[str, Any]) → None

Perform off-policy updates.

fsrl.trainer.onpolicy_trainer(*args, **kwargs) → Dict[str, float | str]

Wrapper for the OnpolicyTrainer run method.

It is identical to OnpolicyTrainer(...).run().

Returns:

The result of OnpolicyTrainer(...).run().

fsrl.trainer.offpolicy_trainer(*args, **kwargs) → Dict[str, float | str]

Wrapper for the OffpolicyTrainer run method.

It is identical to OffpolicyTrainer(...).run().

Returns:

The result of OffpolicyTrainer(...).run().