from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Union
import gymnasium as gym
from tianshou.data import ReplayBuffer, VectorReplayBuffer
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from fsrl.data import FastCollector
from fsrl.policy import BasePolicy
from fsrl.trainer import OffpolicyTrainer, OnpolicyTrainer
from fsrl.utils import BaseLogger
class BaseAgent(ABC):
    """The base class for a default agent.

    An agent class should have the following parts:

    * :meth:`~fsrl.agent.BaseAgent.__init__`: initialize the agent, including the
      policy, networks, optimizers, and so on;
    * :meth:`~fsrl.agent.BaseAgent.learn`: start training given the learning
      parameters;
    * :meth:`~fsrl.agent.BaseAgent.evaluate`: evaluate the agent multiple episodes;
    * :attr:`~fsrl.agent.BaseAgent.state_dict`: the agent state dictionary that can
      be saved as checkpoints;

    Example of usage: ::

        # initialize the CVPO agent
        agent = CVPOAgent(env, other_algo_params)
        # train multiple epochs
        agent.learn(training_envs, other_training_params)
        # test after the training is finished
        agent.evaluate(testing_envs)
        # test with agent's state_dict
        agent.evaluate(testing_envs, agent.state_dict)

    All of the agent classes must inherit :class:`~fsrl.agent.BaseAgent`.
    """

    name = "BaseAgent"

    @abstractmethod
    def __init__(self, *args, **kwargs) -> None:
        # Annotation only: no policy is assigned here. Per the class docstring,
        # concrete agents are expected to create and assign it themselves.
        self.policy: BasePolicy
        self.task = None
        self.logger = BaseLogger()
        self.cost_limit = 0

    @abstractmethod
    def learn(self, *args, **kwargs) -> None:
        """Train the policy on a set of training environments."""
        raise NotImplementedError

    def evaluate(
        self,
        test_envs: Union[gym.Env, BaseVectorEnv],
        state_dict: Optional[dict] = None,
        eval_episodes: int = 10,
        render: bool = False,
        train_mode: bool = False
    ) -> Tuple[float, float, float]:
        """Evaluate the policy on a set of test environments.

        :param Union[gym.Env, BaseVectorEnv] test_envs: A single environment or a
            vectorized environment to evaluate the policy on.
        :param Optional[dict] state_dict: An optional dictionary containing the state
            params of the agent to be evaluated, defaults to None
        :param int eval_episodes: The number of episodes to evaluate, defaults to 10
        :param bool render: Whether to render the environment during evaluation,
            defaults to False
        :param bool train_mode: Whether to set the policy to training mode during
            evaluation, defaults to False
        :return Tuple: rewards, episode lengths, and constraint costs obtained during
            evaluation.
        """
        # Optionally restore saved parameters before evaluating.
        if state_dict is not None:
            self.policy.load_state_dict(state_dict)

        # The chosen mode is left in place after this method returns.
        if train_mode:
            self.policy.train()
        else:
            self.policy.eval()

        eval_collector = FastCollector(self.policy, test_envs)
        result = eval_collector.collect(n_episode=eval_episodes, render=render)
        return result["rew"], result["len"], result["cost"]

    @property
    def state_dict(self):
        """Return the policy's state_dict."""
        return self.policy.state_dict()
class OffpolicyAgent(BaseAgent):
    """The base class for an off-policy agent.

    The :meth:`~fsrl.agent.OffpolicyAgent.learn` function is customized to work with
    the off-policy trainer. See :class:`~fsrl.agent.BaseAgent` for more details.
    """

    name = "OffpolicyAgent"

    def __init__(self) -> None:
        super().__init__()

    def learn(
        self,
        train_envs: Union[gym.Env, BaseVectorEnv],
        test_envs: Optional[Union[gym.Env, BaseVectorEnv]] = None,
        epoch: int = 300,
        episode_per_collect: int = 5,
        step_per_epoch: int = 3000,
        update_per_step: float = 0.1,
        buffer_size: int = 100000,
        testing_num: int = 2,
        batch_size: int = 256,
        reward_threshold: float = 450,
        save_interval: int = 4,
        resume: bool = False,  # TODO
        save_ckpt: bool = True,
        verbose: bool = True,
        show_progress: bool = True
    ) -> Tuple[int, Any, Any]:
        """Train the policy on a set of training environments.

        :param Union[gym.Env, BaseVectorEnv] train_envs: A single environment or a
            vectorized environment to train the policy on.
        :param Union[gym.Env, BaseVectorEnv] test_envs: A single environment or a
            vectorized environment to evaluate the policy on, defaults to None.
        :param int epoch: The number of training epochs, defaults to 300.
        :param int episode_per_collect: The number of episodes to collect before each
            policy update, defaults to 5.
        :param int step_per_epoch: The number of environment steps per epoch,
            defaults to 3000.
        :param float update_per_step: The ratio of policy updates to environment
            steps, defaults to 0.1.
        :param int buffer_size: The maximum size of the replay buffer, defaults to
            100000.
        :param int testing_num: The number of episodes to use for evaluation,
            defaults to 2.
        :param int batch_size: The batch size for each policy update, defaults to
            256.
        :param float reward_threshold: The reward threshold for early stopping,
            defaults to 450.
        :param int save_interval: The interval (in epochs) for saving the policy
            model, defaults to 4.
        :param bool resume: Whether to resume training from the last checkpoint,
            defaults to False.
        :param bool save_ckpt: Whether to save the policy model, defaults to True.
        :param bool verbose: Whether to print progress information during training,
            defaults to True.
        :param bool show_progress: Whether to show the tqdm training progress bar,
            defaults to True.
        :return: A tuple ``(epoch, epoch_stat, info)`` holding the values of the last
            item yielded by the trainer. If the trainer yields nothing, ``(0, None,
            None)`` is returned.
        """
        assert self.policy is not None, "The policy is not initialized"
        # set policy to train mode
        self.policy.train()

        # collector: a single env gets a plain buffer, a vectorized env gets one
        # sub-buffer per environment
        if isinstance(train_envs, gym.Env):
            buffer = ReplayBuffer(buffer_size)
        else:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        train_collector = FastCollector(
            self.policy,
            train_envs,
            buffer,
            exploration_noise=True,
        )
        test_collector = FastCollector(
            self.policy, test_envs
        ) if test_envs is not None else None

        def stop_fn(reward, cost):
            # Stop only when the reward target is exceeded while the cost stays
            # strictly below the agent's cost limit.
            return reward > reward_threshold and cost < self.cost_limit

        def checkpoint_fn():
            return {"model": self.state_dict}

        if save_ckpt:
            self.logger.setup_checkpoint_fn(checkpoint_fn)

        # trainer
        trainer = OffpolicyTrainer(
            policy=self.policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=epoch,
            batch_size=batch_size,
            cost_limit=self.cost_limit,
            step_per_epoch=step_per_epoch,
            update_per_step=update_per_step,
            episode_per_test=testing_num,
            episode_per_collect=episode_per_collect,
            stop_fn=stop_fn,
            logger=self.logger,
            resume_from_log=resume,
            save_model_interval=save_interval,
            verbose=verbose,
            show_progress=show_progress
        )

        # Pre-bind the results so an empty trainer (e.g. ``epoch=0``) does not raise
        # UnboundLocalError at the return statement. A separate loop variable is used
        # so the ``epoch`` argument is not shadowed.
        last_epoch, epoch_stat, info = 0, None, None
        for last_epoch, epoch_stat, info in trainer:
            self.logger.store(tab="train", cost_limit=self.cost_limit)
            if verbose:
                print(f"Epoch: {last_epoch}", info)

        return last_epoch, epoch_stat, info
class OnpolicyAgent(BaseAgent):
    """The base class for an on-policy agent.

    The :meth:`~fsrl.agent.OnpolicyAgent.learn` function is customized to work with
    the on-policy trainer. See :class:`~fsrl.agent.BaseAgent` for more details.
    """

    name = "OnpolicyAgent"

    def __init__(self) -> None:
        super().__init__()

    def learn(
        self,
        train_envs: Union[gym.Env, BaseVectorEnv],
        test_envs: Optional[Union[gym.Env, BaseVectorEnv]] = None,
        epoch: int = 300,
        episode_per_collect: int = 20,
        step_per_epoch: int = 10000,
        repeat_per_collect: int = 4,
        buffer_size: int = 100000,
        testing_num: int = 2,
        batch_size: int = 512,
        reward_threshold: float = 450,
        save_interval: int = 4,
        resume: bool = False,
        save_ckpt: bool = True,
        verbose: bool = True,
        show_progress: bool = True
    ) -> Tuple[int, Any, Any]:
        """Train the policy on a set of training environments.

        :param Union[gym.Env, BaseVectorEnv] train_envs: A single environment or a
            vectorized environment to train the policy on.
        :param Union[gym.Env, BaseVectorEnv] test_envs: A single environment or a
            vectorized environment to evaluate the policy on, defaults to None.
        :param int epoch: The number of training epochs, defaults to 300
        :param int episode_per_collect: The number of episodes collected per data
            collection, defaults to 20
        :param int step_per_epoch: The number of steps per training epoch, defaults
            to 10000
        :param int repeat_per_collect: The number of repeats of policy update for one
            episode collection, defaults to 4
        :param int buffer_size: The size of the replay buffer, defaults to 100000
        :param int testing_num: The number of episodes to evaluate during testing,
            defaults to 2
        :param int batch_size: The batch size for training, default is 99999 for
            :class:`~fsrl.agent.TRPOLagAgent` :class:`~fsrl.agent.CPOLagAgent`, and
            is 512 for others
        :param float reward_threshold: The threshold for stopping training when the
            mean reward exceeds it, defaults to 450
        :param int save_interval: The number of epochs to save the policy, defaults
            to 4
        :param bool resume: Whether to resume training from the saved checkpoint,
            defaults to False
        :param bool save_ckpt: Whether to save the policy model, defaults to True
        :param bool verbose: Whether to print the training information, defaults to
            True
        :param bool show_progress: Whether to show the tqdm training progress bar,
            defaults to True
        :return: A tuple ``(epoch, epoch_stat, info)`` holding the values of the last
            item yielded by the trainer. If the trainer yields nothing, ``(0, None,
            None)`` is returned.
        """
        assert self.policy is not None, "The policy is not initialized"
        # set policy to train mode
        self.policy.train()

        # collector: a single env gets a plain buffer, a vectorized env gets one
        # sub-buffer per environment
        if isinstance(train_envs, gym.Env):
            buffer = ReplayBuffer(buffer_size)
        else:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        train_collector = FastCollector(
            self.policy,
            train_envs,
            buffer,
            exploration_noise=True,
        )
        test_collector = FastCollector(
            self.policy, test_envs
        ) if test_envs is not None else None

        def stop_fn(reward, cost):
            # Stop only when the reward target is exceeded while the cost stays
            # strictly below the agent's cost limit.
            return reward > reward_threshold and cost < self.cost_limit

        def checkpoint_fn():
            return {"model": self.state_dict}

        if save_ckpt:
            self.logger.setup_checkpoint_fn(checkpoint_fn)

        # trainer
        trainer = OnpolicyTrainer(
            policy=self.policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=epoch,
            batch_size=batch_size,
            cost_limit=self.cost_limit,
            step_per_epoch=step_per_epoch,
            repeat_per_collect=repeat_per_collect,
            episode_per_test=testing_num,
            episode_per_collect=episode_per_collect,
            stop_fn=stop_fn,
            logger=self.logger,
            resume_from_log=resume,
            save_model_interval=save_interval,
            verbose=verbose,
            show_progress=show_progress
        )

        # Pre-bind the results so an empty trainer (e.g. ``epoch=0``) does not raise
        # UnboundLocalError at the return statement. A separate loop variable is used
        # so the ``epoch`` argument is not shadowed.
        last_epoch, epoch_stat, info = 0, None, None
        for last_epoch, epoch_stat, info in trainer:
            self.logger.store(tab="train", cost_limit=self.cost_limit)
            if verbose:
                print(f"Epoch: {last_epoch}", info)

        return last_epoch, epoch_stat, info