Source code for fsrl.policy.ppo_lag

from argparse import Action
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from torch import nn

from fsrl.policy.lagrangian_base import LagrangianPolicy
from fsrl.utils import BaseLogger
from fsrl.utils.net.common import ActorCritic


[docs]class PPOLagrangian(LagrangianPolicy): """Implementation of the Proximal Policy Optimization (PPO) with PID Lagrangian. More details, please refer to https://arxiv.org/abs/1707.06347 (PPO) and https://arxiv.org/abs/2007.03964 (PID Lagrangian). :param torch.nn.Module actor: the actor network following the rules in :class:`~fsrl.policy.BasePolicy`. (s -> logits) :param Union[nn.Module, List[nn.Module]] critics: the critic network(s). (s -> V(s)) :param torch.optim.Optimizer optim: the optimizer for the actor and critic networks. :param Type[torch.distributions.Distribution] dist_fn: the distribution class for action sampling. :param BaseLogger logger: the logger instance for logging training information. Default is BaseLogger(). :param float target_kl: the target KL divergence for the PPO update. Default is 0.02. :param float vf_coef: the value function coefficient for the loss function. Default is 0.25. :param Optional[float] max_grad_norm: the maximum gradient norm for gradient clipping (None for no clipping). Default is None. :param float gae_lambda: the Generalized Advantage Estimation (GAE) parameter. Default is 0.95. :param float eps_clip: the PPO clipping parameter for the policy update. Default is 0.2. :param Optional[float] dual_clip: the PPO dual clipping parameter (None for no dual clipping). Default is None. :param bool value_clip: whether to clip the value function update. Default is False. :param bool advantage_normalization: whether to normalize the advantages. Default is True. :param bool recompute_advantage: whether to recompute the advantages during the optimization process. Default is False. :param bool use_lagrangian: whether to use the Lagrangian constraint optimization. Default is True. :param List lagrangian_pid: the PID coefficients for the Lagrangian constraint optimization. Default is [0.05, 0.0005, 0.1]. :param Union[List, float] cost_limit: the constraint limit(s) for the Lagrangian optimization. Default is np.inf. :param bool rescaling: whether use the rescaling trick for Lagrangian multiplier, see Alg. 1 in http://proceedings.mlr.press/v119/stooke20a/stooke20a.pdf :param float gamma: the discount factor for future rewards. Default is 0.99. :param int max_batchsize: the maximum size of the batch when computing GAE, depends on the size of available memory and the memory cost of the model; should be as large as possible within the memory constraint. Default to 99999. :param bool reward_normalization: whether to normalize the rewards. Default is False. :param bool deterministic_eval: whether to use deterministic actions during evaluation. Default is True. :param bool action_scaling: whether to scale actions based on the action space. Default is True. :param str action_bound_method: the method used to handle out-of-bound actions. Default is "clip". :param Optional[gym.Space] observation_space: the observation space of the environment. Default is None. :param Optional[gym.Space] action_space: the action space of the environment. Default is None. :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: the learning rate scheduler. Default is None. .. seealso:: Please refer to :class:`~fsrl.policy.BasePolicy` and :class:`~fsrl.policy.LagrangianPolicy` for more detailed \ hyperparameter explanations and usage. """ def __init__( self, actor: nn.Module, critics: Union[nn.Module, List[nn.Module]], optim: torch.optim.Optimizer, dist_fn: Type[torch.distributions.Distribution], logger: BaseLogger = BaseLogger(), # PPO specific argumentsß target_kl: float = 0.02, vf_coef: float = 0.25, max_grad_norm: Optional[float] = None, gae_lambda: float = 0.95, eps_clip: float = 0.2, dual_clip: Optional[float] = None, value_clip: bool = False, advantage_normalization: bool = True, recompute_advantage: bool = False, # Lagrangian specific arguments use_lagrangian: bool = True, lagrangian_pid: Tuple = (0.05, 0.0005, 0.1), cost_limit: Union[List, float] = np.inf, rescaling: bool = True, # Base policy common arguments gamma: float = 0.99, max_batchsize: int = 99999, reward_normalization: bool = False, deterministic_eval: bool = True, action_scaling: bool = True, action_bound_method: str = "clip", observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None ) -> None: super().__init__( actor, critics, dist_fn, logger, use_lagrangian, lagrangian_pid, cost_limit, rescaling, gamma, max_batchsize, reward_normalization, deterministic_eval, action_scaling, action_bound_method, observation_space, action_space, lr_scheduler ) self.optim = optim self._lambda = gae_lambda self._weight_vf = vf_coef self._grad_norm = max_grad_norm self._target_kl = target_kl self._eps_clip = eps_clip assert dual_clip is None or dual_clip > 1.0, \ "Dual-clip PPO parameter should greater than 1.0." self._dual_clip = dual_clip self._value_clip = value_clip if not self._rew_norm: assert not self._value_clip, \ "value clip is available only when `reward_normalization` is True" self._norm_adv = advantage_normalization self._recompute_adv = recompute_advantage self._actor_critic: ActorCritic def process_fn( self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray ) -> Batch: if self._recompute_adv: # buffer input `buffer` and `indices` to be used in `learn()`. self._buffer, self._indices = buffer, indices # batch get 3 new keys: values, rets, advs batch = self.compute_gae_returns(batch, buffer, indices, self._lambda) batch.act = to_torch_as(batch.act, batch.values[..., 0]) old_log_prob = [] with torch.no_grad(): for minibatch in batch.split( self._max_batchsize, shuffle=False, merge_last=True ): old_log_prob.append(self(minibatch).dist.log_prob(minibatch.act)) batch.logp_old = torch.cat(old_log_prob, dim=0) return batch def critics_loss(self, minibatch): critic_losses = 0 stats = {} for i, critic in enumerate(self.critics): value = critic(minibatch.obs).flatten() ret = minibatch.rets[..., i] if self._value_clip: value_target = minibatch.values[..., i] v_clip = value_target + \ (value - value_target).clamp(-self._eps_clip, self._eps_clip) vf1 = (ret - value).pow(2) vf2 = (ret - v_clip).pow(2) vf_loss = torch.max(vf1, vf2).mean() else: vf_loss = (ret - value).pow(2).mean() critic_losses += vf_loss stats["loss/vf" + str(i)] = vf_loss.item() stats["loss/vf_total"] = critic_losses.item() return critic_losses, stats def policy_loss(self, batch: Batch, dist: Type[torch.distributions.Distribution]): log_p = dist.log_prob(batch.act) ratio = (log_p - batch.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) if self._norm_adv: for i in range(self.critics_num): adv = batch.advs[..., i] mean, std = adv.mean(), adv.std() batch.advs[..., i] = (adv - mean) / std # per-batch norm # compute normal ppo loss rew_adv = batch.advs[..., 0] surr1 = ratio * rew_adv surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * rew_adv if self._dual_clip: clip1 = torch.min(surr1, surr2) clip2 = torch.max(clip1, self._dual_clip * rew_adv) loss_actor_rew = -torch.where(rew_adv < 0, clip2, clip1).mean() else: loss_actor_rew = -torch.min(surr1, surr2).mean() # compute safety loss values = [ratio * batch.advs[..., i] for i in range(1, self.critics_num)] if self.use_lagrangian else [] loss_actor_safety, stats_actor = self.safety_loss(values) rescaling = stats_actor["loss/rescaling"] loss_actor_total = rescaling * (loss_actor_rew + loss_actor_safety) # compute approx KL for early stop approx_kl = (batch.logp_old - log_p).mean().item() stats_actor.update( { "loss/actor_rew": loss_actor_rew.item(), "loss/actor_total": loss_actor_total.item(), "loss/kl": approx_kl } ) return loss_actor_total, stats_actor def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any) -> Dict[str, List[float]]: for step in range(repeat): if self._recompute_adv and step > 0: batch = self.compute_gae_returns( batch, self._buffer, self._indices, self._lambda ) iter_counts, approx_kl = 0, 0.0 for minibatch in batch.split(batch_size, merge_last=True): # obtain the action distribution dist = self.forward(minibatch).dist # calculate policy loss loss_actor, stats_actor = self.policy_loss(minibatch, dist) approx_kl += stats_actor["loss/kl"] iter_counts += 1 # calculate loss for critic loss_vf, stats_critic = self.critics_loss(minibatch) loss = loss_actor + self._weight_vf * loss_vf self.optim.zero_grad() loss.backward() if self._grad_norm: # clip large gradient nn.utils.clip_grad_norm_( self._actor_critic.parameters(), max_norm=self._grad_norm ) self.optim.step() self.gradient_steps += 1 ent = dist.entropy().mean() self.logger.store(**stats_actor) self.logger.store(**stats_critic) self.logger.store(total=loss.item(), entropy=ent.item(), tab="loss") # trick in # https://github.com/liuzuxin/robust-safe-rl/blob/main/rsrl/policy/robust_ppo.py # noqa approx_kl /= iter_counts + 1e-7 if approx_kl > 1.5 * self._target_kl: # early stop self.logger.print("Early stop at step %d due to reaching max kl." % step) break self.logger.store(gradient_steps=self.gradient_steps, tab="update")