from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.exploration import BaseNoise
from torch import nn
from torch.distributions import Independent, Normal
from fsrl.policy.lagrangian_base import LagrangianPolicy
from fsrl.utils import BaseLogger
class SACLagrangian(LagrangianPolicy):
    """Implementation of the Soft Actor-Critic (SAC) with PID Lagrangian.

    More details, please refer to https://arxiv.org/abs/1801.01290 (SAC) and
    https://arxiv.org/abs/2007.03964 (PID Lagrangian).

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~fsrl.policy.BasePolicy`. It is called as
        ``actor(obs, state=..., info=...)`` and must return a tuple of Gaussian
        parameters ``(loc, scale)`` together with the hidden state.
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s). Each
        critic is called as ``critic(obs, act)`` and must return two Q estimates
        (double-Q). Index 0 is used as the reward critic, the remaining indices as
        cost critics.
    :param Optional[torch.optim.Optimizer] actor_optim: the optimizer for the actor
        network.
    :param Optional[torch.optim.Optimizer] critic_optim: the optimizer for the critic
        network(s).
    :param Optional[BaseLogger] logger: the logger instance for logging training
        information. If None, a fresh :class:`BaseLogger` is created for this
        policy. (default: None)
    :param Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] alpha:
        initial temperature for entropy regularization. If a tuple
        (target_entropy, log_alpha, alpha_optim) is provided, then alpha is
        automatically tuned. (default: 0.005)
    :param float tau: target smoothing coefficient for soft update of target
        networks, must be in [0, 1]. (default: 0.05)
    :param Optional[BaseNoise] exploration_noise: the exploration noise.
        (default: None)
    :param int n_step: number of steps for multi-step learning. (default: 2)
    :param bool use_lagrangian: whether to use the Lagrangian constraint
        optimization. (default: True)
    :param Tuple lagrangian_pid: the PID coefficients for the Lagrangian constraint
        optimization. (default: (0.05, 0.0005, 0.1))
    :param Union[List, float] cost_limit: the constraint limit(s) for the Lagrangian
        optimization. (default: np.inf)
    :param bool rescaling: whether use the rescaling trick for Lagrangian
        multiplier, see Alg. 1 in
        http://proceedings.mlr.press/v119/stooke20a/stooke20a.pdf (default: True)
    :param float gamma: the discount factor for future rewards. (default: 0.99)
    :param bool reward_normalization: normalize rewards if True. (default: False)
    :param bool deterministic_eval: whether to use deterministic action selection
        during evaluation. (default: True)
    :param bool action_scaling: whether to scale the actions according to the action
        space bounds. (default: True)
    :param str action_bound_method: the method for handling actions that exceed the
        action space bounds ("clip" or other custom methods). (default: "clip")
    :param Optional[gym.Space] observation_space: the observation space of the
        environment. (default: None)
    :param Optional[gym.Space] action_space: the action space of the environment.
        (default: None)
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning rate
        scheduler for the optimizer. (default: None)

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` and
        :class:`~fsrl.policy.LagrangianPolicy` for more detailed
        hyperparameter explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        actor_optim: Optional[torch.optim.Optimizer],
        critic_optim: Optional[torch.optim.Optimizer],
        logger: Optional[BaseLogger] = None,
        # SAC specific arguments
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.005,
        tau: float = 0.05,
        exploration_noise: Optional[BaseNoise] = None,
        n_step: int = 2,
        # Lagrangian specific arguments
        use_lagrangian: bool = True,
        lagrangian_pid: Tuple = (0.05, 0.0005, 0.1),
        cost_limit: Union[List, float] = np.inf,
        rescaling: bool = True,
        # Base policy common arguments
        gamma: float = 0.99,
        reward_normalization: bool = False,
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None
    ) -> None:
        # Build the default logger per instance: a default evaluated in the
        # signature would be one object shared by every policy.
        if logger is None:
            logger = BaseLogger()
        # NOTE(review): the third positional argument (None) and the literal 10000
        # are forwarded as-is; their meaning is defined by
        # LagrangianPolicy.__init__ -- confirm against that signature.
        super().__init__(
            actor, critics, None, logger, use_lagrangian, lagrangian_pid, cost_limit,
            rescaling, gamma, 10000, reward_normalization, deterministic_eval,
            action_scaling, action_bound_method, observation_space, action_space,
            lr_scheduler
        )
        self.actor_optim = actor_optim
        # Target critics: a frozen-in-eval-mode copy that is soft-updated in
        # sync_weight().
        self.critics_old = deepcopy(self.critics)
        self.critics_old.eval()
        self.critics_optim = critic_optim
        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        self.tau = tau

        self._is_auto_alpha = False
        self._alpha: Union[float, torch.Tensor]
        if isinstance(alpha, tuple):
            # (target_entropy, log_alpha, alpha_optim): alpha is learned.
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        self._noise = exploration_noise
        self._n_step = n_step
        # Small constant keeping the log in the tanh correction finite.
        self.__eps = np.finfo(np.float32).eps.item()

    def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
        """Set the exploration noise."""
        self._noise = noise

    def train(self, mode: bool = True):
        """Set the module in training mode, except for the target network."""
        self.training = mode
        self.actor.train(mode)
        self.critics.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self.soft_update(self.critics_old, self.critics, self.tau)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> List[torch.Tensor]:
        """Compute the entropy-regularized target value for every critic.

        :param ReplayBuffer buffer: the buffer the transitions are read from.
        :param np.ndarray indices: the indices of the transitions in the buffer.
        :return: one tensor per critic, each being the target critic's prediction
            at the next observation minus ``alpha * log_prob`` of the sampled
            next action.
        """
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        next_result = self(batch, input='obs_next')
        next_act = next_result.act
        next_log_prob = next_result.log_prob
        targets = []
        for i in range(self.critics_num):
            target_q, _ = self.critics_old[i].predict(batch.obs_next, next_act)
            targets.append(target_q - self._alpha * next_log_prob)
        return targets

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach the n-step returns (using :meth:`_target_q`) to the batch."""
        return self.compute_nstep_returns(
            batch, buffer, indices, self._target_q, self._n_step
        )

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Compute a tanh-squashed Gaussian action for the given batch.

        :param Batch batch: the input batch; ``batch[input]`` is fed to the actor.
        :param state: the hidden state forwarded to the actor. (default: None)
        :param str input: the key of ``batch`` used as observation.
            (default: "obs")
        :return: a Batch with ``logits`` (the actor's Gaussian parameters), ``act``
            (the squashed action), ``state`` (the actor's hidden state), ``dist``
            (the unsquashed Gaussian) and ``log_prob`` (log-probability of the
            squashed action, with a trailing dimension of size 1).
        """
        obs = batch[input]
        logits, hidden = self.actor(obs, state=state, info=batch.info)
        assert isinstance(logits, tuple)
        dist = Independent(Normal(*logits), 1)
        if self._deterministic_eval and not self.training:
            # Evaluation: use the mean of the Gaussian.
            act = logits[0]
        else:
            # Training / stochastic evaluation: reparameterized sample.
            act = dist.rsample()
        log_prob = dist.log_prob(act).unsqueeze(-1)
        # apply correction for Tanh squashing when computing logprob from Gaussian.
        # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. in
        # appendix C to get some understanding of this equation.
        squashed_action = torch.tanh(act)
        log_prob = log_prob - torch.log(
            (1 - squashed_action.pow(2)) + self.__eps
        ).sum(-1, keepdim=True)
        return Batch(
            logits=logits,
            act=squashed_action,
            state=hidden,
            dist=dist,
            log_prob=log_prob
        )

    def critics_loss(
        self, batch: Batch, critics: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Run one optimization step on all critics.

        :param Batch batch: the batch holding ``obs``, ``act``, ``rets`` and
            optionally ``weight`` (per-sample loss weight, 1.0 if absent).
        :param torch.nn.Module critics: the critics to update, indexable by the
            critic index.
        :param torch.optim.Optimizer optimizer: the optimizer stepping the critics.
        :return: a tuple of the TD error averaged over every critic and both of its
            Q heads, and a dict of scalar loss statistics.
        """
        weight = getattr(batch, "weight", 1.0)
        loss_critic = 0
        td_sum = 0
        stats_critic = {}
        for i in range(self.critics_num):
            target_q = batch.rets[..., i].flatten()
            # double q network
            current_q_list = critics[i](batch.obs, batch.act)
            loss_i = 0
            for j in range(2):
                td = current_q_list[j].flatten() - target_q
                # The averaged TD error is only reported back (e.g. as buffer
                # priority), so keep it out of the autograd graph.
                td_sum = td_sum + td.detach()
                loss_i += (td.pow(2) * weight).mean()
            loss_critic += loss_i
            stats_critic["loss/q" + str(i)] = loss_i.item()
        optimizer.zero_grad()
        loss_critic.backward()
        optimizer.step()
        td_average = td_sum / (self.critics_num * 2)
        stats_critic["loss/q_total"] = loss_critic.item()
        return td_average, stats_critic

    def policy_loss(self, batch: Batch, **kwarg):
        """Run one optimization step on the actor (and on alpha when auto-tuned).

        :param Batch batch: the batch holding the observations ``obs``.
        :return: a tuple of the total actor loss tensor and a dict of scalar
            statistics.
        """
        obs_result = self(batch)
        act = obs_result.act
        # normal loss
        current_q_list = self.critics[0](batch.obs, act)
        current_q = torch.min(current_q_list[0], current_q_list[1]).flatten()
        loss_actor_rew = (self._alpha * obs_result.log_prob.flatten() - current_q).mean()
        # compute safety loss
        values = []
        if self.use_lagrangian:
            # critics[1:] are the cost critics.
            for i in range(1, self.critics_num):
                safety_q_list = self.critics[i](batch.obs, act)
                safety_q = torch.min(safety_q_list[0], safety_q_list[1]).flatten()
                values.append(safety_q)
        loss_actor_safety, stats_actor = self.safety_loss(values)
        rescaling = stats_actor["loss/rescaling"]
        loss_actor_total = rescaling * (loss_actor_rew + loss_actor_safety)
        self.actor_optim.zero_grad()
        loss_actor_total.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_prob = obs_result.log_prob.detach() + self._target_entropy
            # please take a look at issue #258 if you'd like to change this line
            alpha_loss = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            alpha_loss.backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()
            stats_actor.update(
                {
                    "loss/alpha_loss": alpha_loss.item(),
                    "loss/alpha_value": self._alpha.item()
                }
            )

        stats_actor.update(
            {
                "loss/actor_rew": loss_actor_rew.item(),
                "loss/actor_total": loss_actor_total.item()
            }
        )
        return loss_actor_total, stats_actor

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Update critics and actor once, then soft-update the target critics.

        The statistics are stored in the logger and also returned.

        :param Batch batch: the processed batch; its ``weight`` is overwritten with
            the averaged TD error.
        :return: the merged critic and actor statistics.
        """
        # critic
        td, stats_critic = self.critics_loss(batch, self.critics, self.critics_optim)
        batch.weight = td  # prio-buffer
        # actor
        _, stats_actor = self.policy_loss(batch)

        self.sync_weight()

        self.logger.store(**stats_actor)
        self.logger.store(**stats_critic)
        return {**stats_critic, **stats_actor}

    def exploration_noise(self, act: Union[np.ndarray, Batch],
                          batch: Batch) -> Union[np.ndarray, Batch]:
        """Add the exploration noise to the action.

        The action is returned unchanged when no noise is set or when it is not a
        ``np.ndarray``.

        :param act: the action to perturb.
        :param Batch batch: the input batch (unused here).
        :return: the action, with noise added where applicable.
        """
        if self._noise is None:
            return act
        if isinstance(act, np.ndarray):
            return act + self._noise(act.shape)
        return act