# Source code for fsrl.data.basic_collector

import time
from typing import Any, Callable, Dict, List, Optional, Union

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy

from fsrl.data.traj_buf import TrajectoryBuffer
from fsrl.policy import BasePolicy


class BasicCollector:
    """A basic collector for a single environment.

    This collector does not support vectorized environments and serves an
    experimental purpose. It can store the collected data in a
    :class:`~fsrl.data.TrajectoryBuffer` with a grid filter, which can be used to
    memory-efficiently collect a trajectory-wise interaction dataset.

    Example of data saving:
    ::

        traj_buffer = TrajectoryBuffer(max_traj_num)
        collector = BasicCollector(policy, env, traj_buffer=traj_buffer)
        collector.collect(n_episodes)
        traj_buffer.save(logdir)

    :param policy: an instance of the :class:`~fsrl.policy.BasePolicy` class.
    :param env: a single ``gym.Env`` environment.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        If set to None, a ``ReplayBuffer(1)`` is created as a placeholder.
        Default to None.
    :param bool exploration_noise: determine whether the action needs to be
        modified with the corresponding policy's exploration noise. If so,
        ``policy.exploration_noise(act, batch)`` will be called automatically to
        add the exploration noise into the action. Default to False.
    :param TrajectoryBuffer traj_buffer: the buffer used to store trajectories.
        If None, no trajectory data is stored. Default to None.

    .. note::

        Please make sure the given environment has a time limit (i.e. every
        episode eventually ends), because only the ``n_episode`` collect option
        is supported.
    """

    def __init__(
        self,
        policy: BasePolicy,
        env: gym.Env,
        buffer: Optional[ReplayBuffer] = None,
        exploration_noise: Optional[bool] = False,
        traj_buffer: Optional[TrajectoryBuffer] = None,
    ):
        self.env = env
        self.policy = policy
        if buffer is None:
            # collect() always calls self.buffer.add, so fall back to a
            # size-one placeholder buffer when the caller supplies none.
            buffer = ReplayBuffer(1)
        self.buffer = buffer
        self.exploration_noise = exploration_noise
        self._action_space = self.env.action_space
        self.traj_buffer = traj_buffer
        # Initialise self.data, the environment and the counters, but leave the
        # (possibly user-supplied, pre-filled) replay buffer untouched.
        self.reset(False)

    def reset(
        self,
        reset_buffer: bool = True,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Reset the environment, statistics, current data and possibly replay memory.

        :param bool reset_buffer: if true, reset the replay buffer that is attached
            to the collector.
        :param gym_reset_kwargs: extra keyword arguments to pass into the
            environment's reset function. Defaults to None (no extra keyword
            arguments).
        """
        # Every field starts as an empty Batch placeholder; reset_env() and
        # collect() fill them in afterwards.
        self.data = Batch(
            obs={},
            act={},
            rew={},
            cost={},
            terminated={},
            truncated={},
            done={},
            obs_next={},
            info={},
        )
        self.reset_env(gym_reset_kwargs)
        if reset_buffer:
            self.reset_buffer()
        self.reset_stat()

    def reset_buffer(self, keep_statistics: bool = False) -> None:
        """Reset the data buffer.

        :param bool keep_statistics: forwarded unchanged to ``self.buffer.reset``.
        """
        self.buffer.reset(keep_statistics=keep_statistics)

    def reset_stat(self) -> None:
        """Reset the statistic variables (step, episode and time counters)."""
        self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0

    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
        """Reset the environment and store its initial observation.

        :param gym_reset_kwargs: extra keyword arguments to pass into the
            environment's reset function. Defaults to None (no extra keyword
            arguments).
        """
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        rval = self.env.reset(**gym_reset_kwargs)
        # Accept both reset conventions: ``(obs, info)`` and a bare ``obs``.
        returns_info = (
            isinstance(rval, (tuple, list))
            and len(rval) == 2
            and isinstance(rval[1], dict)
        )
        if returns_info:
            obs, info = rval
            self.data.info = [info]
        else:
            obs = rval
        # Fields of self.data carry a leading axis of length one.
        self.data.obs = [obs]

    def collect(
        self,
        n_episode: int = 0,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Collect a specified number of episodes from the single environment.

        The loop always performs at least one environment step and stops as soon
        as ``n_episode`` episodes have finished. With ``n_episode <= 0`` it
        therefore stops after a single step, and the episodic statistics are
        reported as ``0.0`` if no episode finished.

        :param int n_episode: how many episodes you want to collect.
        :param bool random: whether to use a random policy for collecting data.
            Default to False.
        :param float render: the sleep time between rendering consecutive frames.
            Default to None (no rendering).
        :param bool no_grad: whether to disable gradient tracking in
            policy.forward(). Default to True (no gradient retaining).
        :param gym_reset_kwargs: extra keyword arguments to pass into the
            environment's reset function when an episode finishes. Defaults to
            None (no extra keyword arguments).

        .. note::

            We do not support the `n_step` collection method in Tianshou, because
            using `n_episode` only can facilitate the episodic cost computation
            and better evaluate the agent.

        :raises ValueError: if ``env.step`` returns a tuple whose length is
            neither 4 nor 5.

        :return: A dict including the following keys

            * ``n/ep`` collected number of episodes.
            * ``n/st`` collected number of steps.
            * ``rew`` mean of episodic rewards.
            * ``len`` mean of episodic lengths.
            * ``total_cost`` cumulative costs in this collect.
            * ``cost`` mean of episodic costs.
            * ``truncated`` fraction of episode endings that were truncations.
            * ``terminated`` fraction of episode endings that were terminations.
        """
        start_time = time.time()

        step_count = 0
        total_cost = 0
        termination_count = 0
        truncation_count = 0
        episode_count = 0
        episode_rews = []
        episode_lens = []

        while True:
            # get the next action
            if random:
                act_sample = self._action_space.sample()
                act_sample = self.policy.map_action_inverse(act_sample)
                self.data.update(act=[act_sample])
            else:
                policy_input = Batch(obs=self.data.obs, info=self.data.info)
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        policy_output = self.policy(policy_input)
                else:
                    policy_output = self.policy(policy_input)
                # strip the leading length-one axis to get the single action
                act = to_numpy(policy_output.act)[0]
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)
                self.data.update(act=[act])

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = np.squeeze(self.policy.map_action(self.data.act))

            # step in env; support both the 5-tuple and the legacy 4-tuple API
            step_result = self.env.step(action_remap)
            if len(step_result) == 5:
                obs_next, rew, terminated, truncated, info = step_result
                done = np.logical_or(terminated, truncated)
            elif len(step_result) == 4:
                obs_next, rew, done, info = step_result
                if isinstance(info, dict):
                    # The key may be absent; treat that as "not truncated"
                    # rather than raising KeyError.
                    truncated = info.get("TimeLimit.truncated", False)
                else:
                    truncated = np.array(
                        [
                            info_item.get("TimeLimit.truncated", False)
                            for info_item in info
                        ]
                    )
                # np.logical_not instead of ``~``: on a plain Python bool ``~``
                # is integer inversion (~True == -2), which is always truthy and
                # would mark every truncated episode as terminated as well.
                terminated = np.logical_and(done, np.logical_not(truncated))
            else:
                raise ValueError(
                    "env.step() must return 4 or 5 values, "
                    f"got {len(step_result)}"
                )

            cost = info.get("cost", 0)

            self.data.update(
                obs_next=[obs_next],
                rew=[rew],
                terminated=[terminated],
                truncated=[truncated],
                done=[done],
                cost=[cost],
                info=[info],
            )

            termination_count += terminated
            truncation_count += truncated
            total_cost += cost

            if render:
                self.env.render()
                if render > 0 and not np.isclose(render, 0):
                    time.sleep(render)

            # add data into the buffer
            # NOTE(review): the second positional argument is presumably
            # tianshou's ``buffer_ids``, marking self.data as already stacked
            # along a leading axis - confirm against ReplayBuffer.add.
            _ptr, ep_rew, ep_len, _ep_idx = self.buffer.add(self.data, 1)

            if self.traj_buffer is not None:
                traj_data = Batch(
                    observations=self.data.obs,
                    next_observations=self.data.obs_next,
                    actions=[action_remap],
                    rewards=self.data.rew,
                    costs=self.data.cost,
                    terminals=self.data.terminated,
                    timeouts=self.data.truncated,
                )
                self.traj_buffer.store(traj_data)

            step_count += 1

            if done:
                episode_count += 1
                episode_lens.append(ep_len)
                episode_rews.append(ep_rew)
                # reset_env() stores the fresh initial observation in
                # self.data.obs; it must not be overwritten with obs_next, which
                # still belongs to the episode that has just finished.
                self.reset_env(gym_reset_kwargs)
            else:
                self.data.obs = self.data.obs_next

            if episode_count >= n_episode:
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += max(time.time() - start_time, 1e-9)

        # leave the environment in a freshly reset state for the next collect()
        self.reset_env()

        done_count = truncation_count + termination_count

        # Guard every ratio: with ``n_episode <= 0`` the loop can exit before any
        # episode has finished, which would otherwise divide by zero.
        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rew": np.mean(episode_rews) if episode_rews else 0.0,
            "len": np.mean(episode_lens) if episode_lens else 0.0,
            "total_cost": total_cost,
            "cost": total_cost / episode_count if episode_count else 0.0,
            "truncated": truncation_count / done_count if done_count else 0.0,
            "terminated": termination_count / done_count if done_count else 0.0,
        }