from argparse import Action
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from numba import njit
from tianshou.data import Batch, ReplayBuffer, to_torch_as
from torch import nn
from torch.distributions import kl_divergence
from fsrl.policy import BasePolicy
from fsrl.utils import BaseLogger
from fsrl.utils.net.common import ActorCritic
class CPO(BasePolicy):
    """Implementation of Constrained Policy Optimization (CPO).

    For more details, please refer to https://arxiv.org/abs/1705.10528.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~fsrl.policy.BasePolicy`. (s -> logits)
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s).
        (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer stepped by
        :meth:`critics_loss` to fit the critic network(s). The actor is not updated
        through it; its parameters are set directly by the CPO line search.
    :param Type[torch.distributions.Distribution] dist_fn: the distribution
        function for the policy.
    :param BaseLogger logger: the logger instance for logging training
        information. (default: a new ``BaseLogger()`` per policy)
    :param float target_kl: the trust-region radius, i.e. the maximum mean KL
        divergence accepted by the line search. (default: 0.01)
    :param float backtrack_coeff: the factor the step size is multiplied by on each
        backtracking iteration of the line search. (default: 0.8)
    :param float damping_coeff: the damping added to the KL Hessian-vector product.
        (default: 0.1)
    :param int max_backtracks: the maximum number of backtracks allowed during the
        line search. (default: 10)
    :param int optim_critic_iters: the number of optimization iterations for the
        critic network(s) per minibatch. (default: 20)
    :param float l2_reg: the L2 regularization coefficient for the critic network.
        (default: 0.001)
    :param float gae_lambda: the GAE lambda value. (default: 0.95)
    :param bool advantage_normalization: normalize advantages if True.
        (default: True)
    :param Union[List, float] cost_limit: the threshold of the cost constraint. CPO
        handles a single cost, so if a sequence is given only its first entry is
        used. (default: np.inf)
    :param float gamma: the discount factor for future rewards. (default: 0.99)
    :param int max_batchsize: the maximum batch size used when evaluating the
        policy on a whole batch. (default: 99999)
    :param bool reward_normalization: normalize rewards if True. (default: False)
    :param bool deterministic_eval: whether to use deterministic action selection
        during evaluation. (default: True)
    :param bool action_scaling: whether to scale the actions according to the
        action space bounds. (default: True)
    :param str action_bound_method: the method for handling actions that exceed the
        action space bounds ("clip" or other custom methods). (default: "clip")
    :param Optional[gym.Space] observation_space: the observation space of the
        environment. (default: None)
    :param Optional[gym.Space] action_space: the action space of the environment.
        (default: None)
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning rate
        scheduler for the optimizer. (default: None)

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` for more detailed
        hyperparameter explanations and usage.
    """

    # Small constant that guards divisions (advantage std, dual variables, norms).
    _EPS: float = 1e-8

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        logger: Optional[BaseLogger] = None,
        # CPO specific arguments
        target_kl: float = 0.01,
        backtrack_coeff: float = 0.8,
        damping_coeff: float = 0.1,
        max_backtracks: int = 10,
        optim_critic_iters: int = 20,
        l2_reg: float = 0.001,
        gae_lambda: float = 0.95,
        advantage_normalization: bool = True,
        cost_limit: Union[List, float] = np.inf,
        # Base policy common arguments
        gamma: float = 0.99,
        max_batchsize: int = 99999,
        reward_normalization: bool = False,
        deterministic_eval: bool = True,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        observation_space: Optional[gym.Space] = None,
        action_space: Optional[gym.Space] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None
    ) -> None:
        # A fresh logger per policy; a default instance would be shared by all of
        # them and mix their statistics.
        if logger is None:
            logger = BaseLogger()
        super().__init__(
            actor, critics, dist_fn, logger, gamma, max_batchsize, reward_normalization,
            deterministic_eval, action_scaling, action_bound_method, observation_space,
            action_space, lr_scheduler
        )
        self.optim = optim
        # policy_loss subtracts this from a scalar tensor, so it must be a scalar.
        self._cost_limit = self._single_cost_limit(cost_limit)
        self._lambda = gae_lambda
        self._norm_adv = advantage_normalization
        self._max_backtracks = max_backtracks
        self._optim_critic_iters = optim_critic_iters
        self._l2_reg = l2_reg
        self._delta = target_kl
        self._backtrack_coeff = backtrack_coeff
        self._damping_coeff = damping_coeff

    @staticmethod
    def _single_cost_limit(cost_limit: Union[List, float]) -> float:
        """Return the scalar threshold of CPO's single cost constraint.

        Scalars are converted to ``float``. For a sequence, the first entry is used.

        :raises ValueError: if ``cost_limit`` is an empty sequence.
        """
        if np.isscalar(cost_limit):
            return float(cost_limit)  # type: ignore[arg-type]
        limits = np.asarray(cost_limit, dtype=float).reshape(-1)
        if limits.size == 0:
            raise ValueError("cost_limit must not be empty")
        return float(limits[0])

    def pre_update_fn(self, stats_train: Dict, **kwarg) -> Any:
        """Cache the latest average episodic cost before an update.

        The value is used as the constant offset of the cost surrogate in
        :meth:`_get_cost_surrogate`. It is only set here, so this hook must run
        before :meth:`learn`.

        :param Dict stats_train: training statistics; must contain a ``"cost"`` key.
        """
        self._ave_cost_return = stats_train["cost"]

    def update_cost_limit(self, cost_limit: float) -> None:
        """Update the cost limit threshold.

        ``self.cost_limit`` keeps its previous per-cost-critic list form. The scalar
        ``self._cost_limit``, which :meth:`policy_loss` actually reads, is updated as
        well, so the new threshold takes effect on the next update.

        :param float cost_limit: new cost threshold. A sequence is also accepted;
            CPO uses only its first entry.
        """
        self.cost_limit = [cost_limit] * (
            self.critics_num - 1
        ) if np.isscalar(cost_limit) else cost_limit
        self._cost_limit = self._single_cost_limit(cost_limit)

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Compute GAE targets and cache the behaviour policy's statistics.

        Adds ``logp_old``, ``mean_old`` and ``std_old`` to ``batch``. These are later
        used by :meth:`policy_loss` for the importance ratio and the KL term.
        """
        batch = self.compute_gae_returns(batch, buffer, indices, self._lambda)
        if self._norm_adv:
            for i in range(self.critics_num):
                adv = batch.advs[..., i]
                mean, std = adv.mean(), adv.std()
                # The epsilon keeps a constant advantage column (std == 0) from
                # turning into NaNs.
                batch.advs[..., i] = (adv - mean) / (std + self._EPS)
        batch.act = to_torch_as(batch.act, batch.values[..., 0])
        old_log_prob, old_mean, old_std = [], [], []
        with torch.no_grad():
            for minibatch in batch.split(
                self._max_batchsize, shuffle=False, merge_last=True
            ):
                res = self.forward(minibatch)
                old_log_prob.append(res.dist.log_prob(minibatch.act))
                # logits[0] / logits[1] are fed back to dist_fn as (mean, std) in
                # policy_loss; assumes a Gaussian-style dist_fn -- TODO confirm.
                old_mean.append(res.logits[0, ...])
                old_std.append(res.logits[1, ...])
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        batch.mean_old = torch.cat(old_mean, dim=0)
        batch.std_old = torch.cat(old_std, dim=0)
        return batch

    def critics_loss(self, minibatch: Batch) -> Tuple[torch.Tensor, dict]:
        """Take one optimizer step on all critics.

        Critic ``i`` regresses ``minibatch.rets[..., i]`` with an MSE loss, plus
        ``l2_reg`` times the squared norm of its parameters.

        :return: the summed critic loss and a dict of per-critic loss values.
        """
        vf_losses: List[torch.Tensor] = []
        stats: Dict[str, float] = {}
        for i, critic in enumerate(self.critics):
            value = critic(minibatch.obs).flatten()
            ret = minibatch.rets[..., i]
            vf_loss = (ret - value).pow(2).mean()
            for param in critic.parameters():
                vf_loss = vf_loss + param.pow(2).sum() * self._l2_reg
            vf_losses.append(vf_loss)
            stats["loss/vf" + str(i)] = vf_loss.item()
        # Sum the per-critic losses directly so the total stays on the critics'
        # device; a CPU ``torch.zeros(1)`` accumulator fails when they are on a GPU.
        critic_losses = sum(vf_losses)
        self.optim.zero_grad()
        critic_losses.backward()  # type: ignore[union-attr]
        self.optim.step()
        stats["loss/vf_total"] = critic_losses.item()  # type: ignore[union-attr]
        return critic_losses, stats  # type: ignore[return-value]

    def _get_objective(
        self, logp: torch.Tensor, logp_old: torch.Tensor, adv: torch.Tensor
    ) -> torch.Tensor:
        """Importance-sampled reward surrogate: ``mean(exp(logp - logp_old) * adv)``."""
        return torch.mean(torch.exp(logp - logp_old) * adv)

    def _get_cost_surrogate(
        self, logp: torch.Tensor, logp_old: torch.Tensor, cadv: torch.Tensor
    ) -> torch.Tensor:
        """Estimate the episodic cost of the current policy.

        When ``logp == logp_old`` (ratio 1) the two advantage terms cancel, so the
        surrogate equals the cached average cost from :meth:`pre_update_fn`.
        """
        cost_surrogate = self._ave_cost_return + torch.mean(
            torch.exp(logp - logp_old) * cadv
        ) - torch.mean(cadv)
        return cost_surrogate

    def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor:
        """Return the damped Hessian-vector product ``(H + damping * I) @ v``.

        ``H`` is the Hessian of the mean KL w.r.t. the actor parameters. It is
        obtained by differentiating ``flat_kl_grad . v``, so ``flat_kl_grad`` must
        have been built with ``create_graph=True``.
        """
        kl_v = torch.dot(flat_kl_grad, v)
        flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, retain_graph=True)
        return flat_kl_grad_grad + v * self._damping_coeff

    def _conjugate_gradients(
        self,
        g: torch.Tensor,
        flat_kl_grad: torch.Tensor,
        nsteps: int = 10,
        residual_tol: float = 1e-8
    ) -> torch.Tensor:
        """Approximately solve ``(H + damping * I) x = g`` with conjugate gradients.

        Runs at most ``nsteps`` iterations. It stops early once the squared residual
        norm drops below ``residual_tol``.
        """
        x = torch.zeros_like(g)
        r, p = g.clone(), g.clone()
        rs_old = torch.sum(r * r)
        for _ in range(nsteps):
            z = self._MVP(p, flat_kl_grad)
            alpha = rs_old / torch.sum(p * z)
            x += alpha * p
            r -= alpha * z
            rs_new = torch.sum(r * r)
            if rs_new < residual_tol:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def _get_flat_grad(
        self,
        y: torch.Tensor,
        model: nn.Module,
        retain_graph: bool = False,
        create_graph: bool = False
    ) -> torch.Tensor:
        """Return ``dy/dparams`` of ``model``, flattened into a single 1-D tensor.

        ``create_graph=True`` implies ``retain_graph=True``, so the result can be
        differentiated again (used by :meth:`_MVP`).
        """
        grads = torch.autograd.grad(
            y,
            model.parameters(),  # type: ignore
            retain_graph=retain_graph or create_graph,
            create_graph=create_graph
        )
        return torch.cat([grad.view(-1) for grad in grads])

    def _get_flat_params(self, model: nn.Module) -> torch.Tensor:
        """Concatenate all parameters of ``model`` into a single 1-D tensor."""
        return torch.cat([p.view(-1) for p in model.parameters()])

    def _set_from_flat_params(self, model: nn.Module, new_params: torch.Tensor) -> None:
        """Overwrite ``model``'s parameters in order from a flat 1-D tensor.

        Each parameter's ``.data`` becomes a reshaped slice of ``new_params``.
        """
        n = 0
        for param in model.parameters():
            numel = param.numel()
            param.data = new_params[n:n + numel].view(param.size())
            n += numel

    def policy_loss(self, minibatch: Batch) -> Tuple[torch.Tensor, dict]:
        """Compute one CPO step for ``minibatch`` and apply it to the actor in place.

        The search direction comes from the closed-form dual of the CPO subproblem
        (``optim_case`` 0-4). A backtracking line search then accepts the first trial
        step that:

        - keeps the mean KL within ``target_kl``;
        - does not raise the cost surrogate by more than ``max(-c, 0)``;
        - for cases > 1, also improves the reward objective.

        If no trial step is accepted, the actor parameters are restored.

        :return: ``objective + cost_surrogate`` (for reporting only; the actor is
            updated inside this method) and a dict of diagnostics.
            ``loss/step_size`` is the accepted line-search coefficient, or 0.0 when
            no step was applied.
        """
        self.actor.train()
        eps = self._EPS
        # get objective & KL & cost surrogate
        dist = self.forward(minibatch).dist
        ent = dist.entropy().mean()
        logp = dist.log_prob(minibatch.act)
        dist_old = self.dist_fn(*(minibatch.mean_old, minibatch.std_old))  # type: ignore
        kl = kl_divergence(dist_old, dist).mean()
        objective = self._get_objective(logp, minibatch.logp_old, minibatch.advs[..., 0])
        cost_surrogate = self._get_cost_surrogate(
            logp, minibatch.logp_old, minibatch.advs[..., 1]
        )
        loss_actor_total = objective + cost_surrogate

        # grad_g ascends the reward objective. grad_b is the gradient of the
        # *negated* cost surrogate, so stepping along H^-1 grad_b lowers the cost.
        grad_g = self._get_flat_grad(objective, self.actor, retain_graph=True)
        grad_b = self._get_flat_grad(-cost_surrogate, self.actor, retain_graph=True)
        flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
        H_inv_g = self._conjugate_gradients(grad_g, flat_kl_grad)
        approx_g = self._MVP(H_inv_g, flat_kl_grad)
        # c >= 0 means the current cost estimate already meets/exceeds the limit
        c_value = cost_surrogate - self._cost_limit

        # solve the dual of the CPO subproblem
        if torch.dot(grad_b, grad_b) <= eps and c_value < 0:
            # No usable cost gradient and the constraint holds: plain trust-region
            # step. Placeholders live on grad_g's device/dtype so they can be mixed
            # with GPU tensors below.
            zero = torch.zeros((), dtype=grad_g.dtype, device=grad_g.device)
            H_inv_b = torch.zeros_like(grad_g)
            scalar_r = scalar_s = A_value = B_value = zero
            scalar_q = torch.dot(approx_g, H_inv_g)
            optim_case = 4
        else:
            H_inv_b = self._conjugate_gradients(grad_b, flat_kl_grad)
            approx_b = self._MVP(H_inv_b, flat_kl_grad)
            scalar_q = torch.dot(approx_g, H_inv_g)
            scalar_r = torch.dot(approx_g, H_inv_b)
            scalar_s = torch.dot(approx_b, H_inv_b)
            # should be always positive (Cauchy-Schwarz)
            A_value = scalar_q - scalar_r**2 / scalar_s
            # does safety boundary intersect trust region? (positive = yes)
            B_value = 2 * self._delta - c_value**2 / scalar_s
            if c_value < 0 and B_value < 0:
                optim_case = 3  # whole trust region is feasible
            elif c_value < 0 and B_value >= 0:
                optim_case = 2  # feasible, boundary intersects trust region
            elif c_value >= 0 and B_value >= 0:
                optim_case = 1  # infeasible, but recovery possible
            else:
                optim_case = 0  # infeasible, pure cost-descent recovery

        if optim_case in (3, 4):
            lam = torch.sqrt(scalar_q / (2 * self._delta))
            nu = torch.zeros_like(lam)
        elif optim_case in (1, 2):
            r_over_c = scalar_r / c_value
            LA, LB = [0, r_over_c], [r_over_c, np.inf]
            LA, LB = (LA, LB) if c_value < 0 else (LB, LA)

            def proj(x, bounds):
                # clamp x into [bounds[0], bounds[1]]
                return max(bounds[0], min(bounds[1], x))

            def f_a(lam_):
                return -0.5 * (A_value / (lam_ + eps) + B_value * lam_
                               ) - scalar_r * c_value / (scalar_s + eps)

            def f_b(lam_):
                return -0.5 * (scalar_q / (lam_ + eps) + 2 * self._delta * lam_)

            lam_a = proj(torch.sqrt(A_value / B_value), LA)
            lam_b = proj(torch.sqrt(scalar_q / (2 * self._delta)), LB)
            lam = lam_a if f_a(lam_a) >= f_b(lam_b) else lam_b
            # proj may return a Python number or a tensor. Normalise it to a
            # detached tensor on the solver's device (torch.tensor(tensor) copies
            # and warns).
            lam = torch.as_tensor(
                lam, dtype=scalar_q.dtype, device=scalar_q.device
            ).detach()
            nu = max(0, (lam * c_value - scalar_r).item()) / (scalar_s + eps)
        else:
            nu = torch.sqrt(2 * self._delta / (scalar_s + eps))
            lam = torch.zeros_like(nu)

        # line search
        step_size = 0.0
        with torch.no_grad():
            if optim_case > 0:
                delta_theta = (1. / (lam + eps)) * (H_inv_g + nu * H_inv_b)
            else:
                delta_theta = nu * H_inv_b
            # Rescaled to unit Euclidean norm: (lam, nu) only shape the direction and
            # the backtracking below picks the length. The clamp avoids 0/0 -> NaN.
            delta_theta = delta_theta / torch.norm(delta_theta).clamp_min(eps)
            # Sometimes scalar_q is negative, making lam NaN. Any non-finite
            # direction would corrupt the actor, so the update is skipped then.
            if not torch.isnan(lam) and bool(torch.isfinite(delta_theta).all()):
                init_theta = self._get_flat_params(self.actor).clone().detach()
                init_objective = objective.clone().detach()
                init_cost_surrogate = cost_surrogate.clone().detach()
                cost_slack = max(-c_value.item(), 0)
                beta = 1.0
                for _ in range(self._max_backtracks):
                    theta = beta * delta_theta + init_theta
                    self._set_from_flat_params(self.actor, theta)
                    dist = self.forward(minibatch).dist
                    logp = dist.log_prob(minibatch.act)
                    new_kl = kl_divergence(dist_old, dist).mean().item()
                    new_objective = self._get_objective(
                        logp, minibatch.logp_old, minibatch.advs[..., 0]
                    )
                    new_cost_surrogate = self._get_cost_surrogate(
                        logp, minibatch.logp_old, minibatch.advs[..., 1]
                    )
                    improves = new_objective > init_objective \
                        if optim_case > 1 else True
                    if new_kl <= self._delta and improves and \
                            new_cost_surrogate - init_cost_surrogate <= cost_slack:
                        step_size = beta
                        break
                    beta *= self._backtrack_coeff
                else:
                    # Every trial step was rejected. Keep the old parameters rather
                    # than the last rejected step.
                    self._set_from_flat_params(self.actor, init_theta)

        stats_actor = {
            "loss/kl": kl.item(),
            "loss/entropy": ent.item(),
            "loss/rew_loss": objective.item(),
            "loss/cost_loss": cost_surrogate.item(),
            "loss/optim_A": A_value.item(),
            "loss/optim_B": B_value.item(),
            "loss/optim_C": c_value.item(),
            "loss/optim_Q": scalar_q.item(),
            "loss/optim_R": scalar_r.item(),
            "loss/optim_S": scalar_s.item(),
            "loss/optim_lam": lam.item(),
            "loss/optim_nu": nu.item(),
            "loss/optim_case": optim_case,
            "loss/step_size": step_size
        }
        return loss_actor_total, stats_actor

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        """Run ``repeat`` passes over ``batch`` in minibatches of ``batch_size``.

        For each minibatch, the critics are fitted for ``optim_critic_iters`` steps,
        then one CPO actor step is taken. Statistics are reported through
        ``self.logger``; the method itself returns None.
        """
        # Defined up front so ``optim_critic_iters == 0`` does not raise NameError.
        stats_critic: Dict[str, float] = {}
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                for _ in range(self._optim_critic_iters):
                    _, stats_critic = self.critics_loss(minibatch)
                # calculate policy loss (this also updates the actor in place)
                _, stats_actor = self.policy_loss(minibatch)
                self.gradient_steps += 1
                self.logger.store(**stats_actor)
                self.logger.store(**stats_critic)
        self.logger.store(gradient_steps=self.gradient_steps, tab="update")