Source code for fsrl.policy.focops

import time
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from torch import nn
from torch.distributions import kl_divergence

from fsrl.policy import BasePolicy
from fsrl.utils import BaseLogger
from fsrl.utils.net.common import ActorCritic


class FOCOPS(BasePolicy):
    """Implementation of the First Order Constrained Optimization in Policy Space.

    More details, please refer to https://arxiv.org/pdf/2002.06506.pdf

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~fsrl.policy.BasePolicy`. (s -> logits)
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s).
        (s -> V(s))
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.optim.Optimizer critic_optim: the optimizer for the critic
        network(s).
    :param Type[torch.distributions.Distribution] dist_fn: the probability
        distribution function for sampling actions.
    :param BaseLogger logger: the logger instance for logging training
        information.
    :param float cost_limit: the constraint limit for the optimization. Default
        value is 10.
    :param Union[float, Tuple[float, float, torch.Tensor]] nu: cost coefficient.
        A plain float is used as a fixed coefficient. A tuple is unpacked as
        ``(nu_max, nu_lr, nu)`` and enables the automatic update of ``nu`` in
        :meth:`nu_loss`. Default value is 0.01.
    :param float l2_reg: L2 regularization rate. Default value is 1e-3.
    :param float delta: early stop KL bound. Default value is 0.02.
    :param float eta: KL bound for indicator function. Default value is 0.02.
    :param float tem_lambda: inverse temperature lambda. Default value is 0.95.
    :param float gae_lambda: GAE (Generalized Advantage Estimation) lambda for
        advantage computation. Default value is 0.95.
    :param Optional[float] max_grad_norm: maximum gradient norm for gradient
        clipping, if specified. Default value is 0.5.
    :param bool advantage_normalization: normalize advantage if True. Default
        value is True.
    :param bool recompute_advantage: recompute advantage using the updated value
        function. Default value is False.
    :param float gamma: the discount factor for future rewards. Default value is
        0.99.
    :param int max_batchsize: maximum batch size for the optimization. Default
        value is 99999.
    :param bool reward_normalization: normalize the rewards if True. Default
        value is False.
    :param bool deterministic_eval: whether to use deterministic action selection
        during evaluation. Default value is True.
    :param bool action_scaling: whether to scale the actions according to the
        action space bounds. Default value is True.
    :param str action_bound_method: the method for handling actions that exceed
        the action space bounds ("clip" or other custom methods). Default value
        is "clip".
    :param Optional[gym.Space] observation_space: the observation space of the
        environment. Default value is None.
    :param Optional[gym.Space] action_space: the action space of the environment.
        Default value is None.
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning
        rate scheduler for the optimizer. Default value is None.

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` for more detailed
        hyperparameter explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        actor_optim: torch.optim.Optimizer,
        critic_optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        logger: BaseLogger = BaseLogger(),
        cost_limit: float = 10,
        nu: Union[float, Tuple[float, float, torch.Tensor]] = 0.01,
        l2_reg: float = 1e-3,
        delta: float = 0.02,
        eta: float = 0.02,
        tem_lambda: float = 0.95,
        gae_lambda: float = 0.95,
        max_grad_norm: Optional[float] = 0.5,
        advantage_normalization: bool = True,
        recompute_advantage: bool = False,
        gamma: float = 0.99,
        max_batchsize: int = 99999,
        reward_normalization: bool = False,
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None
    ) -> None:
        super().__init__(
            actor, critics, dist_fn, logger, gamma, max_batchsize,
            reward_normalization, deterministic_eval, action_scaling,
            action_bound_method, observation_space, action_space, lr_scheduler
        )
        self.actor_optim = actor_optim
        self.critics_optim = critic_optim
        self.cost_limit = cost_limit
        self._gae_lambda = gae_lambda
        self._tem_lambda = tem_lambda
        self._grad_norm = max_grad_norm
        # A tuple `nu` carries (upper bound, step size, initial value) and turns
        # on the automatic update performed in `nu_loss`; a float stays fixed.
        self._is_auto_nu = False
        if isinstance(nu, tuple):
            self._is_auto_nu = True
            self._nu_max, self._nu_lr, self._nu = nu
        else:
            self._nu = nu
        self._l2_reg = l2_reg
        self._delta = delta
        self._eta = eta
        self._norm_adv = advantage_normalization
        self._recompute_adv = recompute_advantage
        self._actor_critic: ActorCritic

    def pre_update_fn(self, stats_train: Dict, **kwarg) -> Any:
        """Cache the ``"cost"`` entry of the training statistics.

        The cached value is later compared against the cost limit in
        :meth:`nu_loss`.

        :param Dict stats_train: training statistics; must contain key ``"cost"``.
        """
        self._ave_cost_return = stats_train["cost"]

    def update_cost_limit(self, cost_limit: float) -> None:
        """Update the cost limit threshold.

        A scalar is replicated into a list with ``critics_num - 1`` entries; any
        other value is stored unchanged.

        :param float cost_limit: new cost threshold
        """
        self.cost_limit = [cost_limit] * (self.critics_num - 1) \
            if np.isscalar(cost_limit) else cost_limit

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach GAE returns and the old policy's statistics to ``batch``.

        Adds ``logp_old``, ``mean_old`` and ``std_old`` (computed without
        gradient tracking) on top of the keys produced by
        ``compute_gae_returns``.

        :param Batch batch: the sampled data.
        :param ReplayBuffer buffer: the buffer the data was sampled from.
        :param np.ndarray indices: the indices of ``batch`` inside ``buffer``.
        :return: the same batch, augmented in place.
        """
        if self._recompute_adv:
            # buffer input `buffer` and `indices` to be used in `learn()`.
            self._buffer, self._indices = buffer, indices
        # batch get 3 new keys: values, rets, advs
        batch = self.compute_gae_returns(batch, buffer, indices, self._gae_lambda)
        batch.act = to_torch_as(batch.act, batch.values[..., 0])
        old_log_prob, old_mean, old_std = [], [], []
        with torch.no_grad():
            for minibatch in batch.split(
                self._max_batchsize, shuffle=False, merge_last=True
            ):
                res = self.forward(minibatch)
                old_log_prob.append(res.dist.log_prob(minibatch.act))
                # NOTE(review): presumably `logits` stacks (mean, std) along the
                # first axis -- confirm against BasePolicy.forward.
                old_mean.append(res.logits[0, ...])
                old_std.append(res.logits[1, ...])
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        batch.mean_old = torch.cat(old_mean, dim=0)
        batch.std_old = torch.cat(old_std, dim=0)
        return batch

    def nu_loss(self, batch: Batch):
        """Update the cost coefficient ``nu`` and report its statistics.

        ``nu`` is only modified when it was configured as a
        ``(nu_max, nu_lr, nu)`` tuple; a plain float ``nu`` is kept fixed, so the
        default configuration no longer fails on the missing ``_nu_lr`` /
        ``_nu_max`` attributes.

        :param Batch batch: unused; kept for interface compatibility.
        :return: ``(loss_nu, stats_nu)`` where ``loss_nu`` is the cost limit
            minus the cached average cost return and ``stats_nu`` holds the
            entries ``"loss/nu_loss"`` and ``"loss/nu_value"``.
        """
        cost_limit = self.cost_limit
        if isinstance(cost_limit, (list, tuple)):
            # `update_cost_limit` stores a list for scalar input, and
            # `list - float` raises TypeError. Only one cost advantage
            # (advs[..., 1]) enters the policy loss, so use the first limit.
            cost_limit = cost_limit[0]
        loss_nu = cost_limit - self._ave_cost_return
        if self._is_auto_nu:
            self._nu += -self._nu_lr * loss_nu
            self._nu = torch.clamp(self._nu, 0, self._nu_max)
        if isinstance(self._nu, torch.Tensor):
            nu_value = self._nu.detach().item()
        else:
            nu_value = float(self._nu)
        stats_nu = {"loss/nu_loss": loss_nu, "loss/nu_value": nu_value}
        return loss_nu, stats_nu

    def critics_loss(self, minibatch: Batch):
        """Run one optimizer step on all critics.

        Each critic is regressed onto its own column of ``minibatch.rets`` with
        a mean squared error plus an L2 penalty on its parameters; the summed
        loss is back-propagated through ``critics_optim``.

        :param Batch minibatch: data with ``obs`` and ``rets``.
        :return: ``(critic_losses, stats)`` with per-critic ``"loss/vf<i>"``
            entries and ``"loss/vf_total"``.
        """
        critic_losses = 0
        stats = {}
        for i, critic in enumerate(self.critics):
            value = critic(minibatch.obs).flatten()
            ret = minibatch.rets[..., i]
            vf_loss = (ret - value).pow(2).mean()
            for param in critic.parameters():
                vf_loss += param.pow(2).sum() * self._l2_reg
            critic_losses += vf_loss
            stats["loss/vf" + str(i)] = vf_loss.item()
        self.critics_optim.zero_grad()
        critic_losses.backward()
        self.critics_optim.step()
        stats["loss/vf_total"] = critic_losses.item()
        return critic_losses, stats

    def policy_loss(self, minibatch: Batch):
        """Run one optimizer step on the actor with the FOCOPS objective.

        The loss is the mean over samples of
        ``(KL - ratio * (rew_adv - nu * cost_adv) / tem_lambda)`` masked by the
        indicator ``KL <= eta``, where ``KL`` is the divergence between the
        current and the old action distribution.

        :param Batch minibatch: data with ``act``, ``advs``, ``logp_old``,
            ``mean_old`` and ``std_old``.
        :return: ``(loss_actor, stats_actor)`` with entries
            ``"loss/actor_loss"``, ``"loss/kl"`` and ``"loss/entropy"``.
        """
        # obtain the action distribution
        dist = self.forward(minibatch).dist
        ent = dist.entropy().mean()
        log_p = dist.log_prob(minibatch.act)
        ratio = (log_p - minibatch.logp_old).exp()
        dist_old = self.dist_fn(*(minibatch.mean_old, minibatch.std_old))
        kl_new_old = kl_divergence(dist, dist_old)
        if self._norm_adv:
            # Small constant so a constant advantage column (std == 0) yields
            # zeros instead of NaN.
            eps = 1e-8
            for i in range(self.critics_num):
                adv = minibatch.advs[..., i]
                mean, std = adv.mean(), adv.std()
                minibatch.advs[..., i] = (adv - mean) / (std + eps)  # per-batch norm
        rew_adv = minibatch.advs[..., 0]
        cost_adv = minibatch.advs[..., 1]
        loss_actor = (
            (
                kl_new_old - 1 / self._tem_lambda * ratio *
                (rew_adv - self._nu * cost_adv)
            ) * (kl_new_old.detach() <= self._eta)
        ).mean()
        self.actor_optim.zero_grad()
        loss_actor.backward()
        if self._grad_norm:
            # clip large gradient
            nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self._grad_norm)
        self.actor_optim.step()
        stats_actor = {
            "loss/actor_loss": loss_actor.item(),
            "loss/kl": kl_new_old.mean().item(),
            "loss/entropy": ent.item()
        }
        return loss_actor, stats_actor

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        """Update ``nu`` once, then train critics and actor for ``repeat`` epochs.

        An epoch iterates over minibatches of ``batch``; training stops early
        once the mean KL of an epoch exceeds ``delta``. Statistics are pushed to
        the logger; no value is returned explicitly.

        :param Batch batch: data prepared by :meth:`process_fn`.
        :param int batch_size: minibatch size.
        :param int repeat: maximum number of epochs over ``batch``.
        """
        # update nu
        loss_nu, stats_nu = self.nu_loss(batch)
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                batch = self.compute_gae_returns(
                    batch, self._buffer, self._indices, self._gae_lambda
                )
            iter_counts, approx_kl = 0, 0.0
            for minibatch in batch.split(batch_size, merge_last=True):
                # update critic
                loss_vf, stats_critic = self.critics_loss(minibatch)
                # update actor
                loss_actor, stats_actor = self.policy_loss(minibatch)
                approx_kl += stats_actor["loss/kl"]
                iter_counts += 1
                self.gradient_steps += 1

                self.logger.store(**stats_nu)
                self.logger.store(**stats_actor)
                self.logger.store(**stats_critic)
            # trick in
            # https://github.com/liuzuxin/robust-safe-rl/blob/main/rsrl/policy/robust_ppo.py  # noqa: E501
            approx_kl /= iter_counts + 1e-7
            if approx_kl > self._delta:
                # early stop
                self.logger.print("Early stop at step %d due to reaching max kl." % step)
                break
        self.logger.store(gradient_steps=self.gradient_steps, tab="update")