Source code for fsrl.utils.net.continuous

# flake8: noqa

from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
from tianshou.utils.net.common import MLP
from tianshou.utils.net.continuous import Critic
from torch import nn


class DoubleCritic(nn.Module):
    """Double critic network.

    Holds two critics operating in continuous action space, each with the
    structure preprocess_net ---> 1(q value).

    :param preprocess_net1: a self-defined preprocess_net which output a flattened
        hidden state; feeds the first Q head.
    :param preprocess_net2: a self-defined preprocess_net which output a flattened
        hidden state; feeds the second Q head.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains only a
        single linear layer).
    :param device: device the inputs are moved to in :meth:`forward`; also passed
        to the MLP heads.
    :param int preprocess_net_output_dim: the output dimension of preprocess_net,
        used when ``preprocess_net1`` has no ``output_dim`` attribute.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param bool flatten_input: whether to flatten input data for the last layer.
        Default to True.

    For advanced usage (how to customize the network), please refer to tianshou's
    `build_the_network tutorial
    <https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network>`_.

    .. seealso::

        Please refer to tianshou's `Net
        <https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.net.common.Net>`_
        class as an instance of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net1: nn.Module,
        preprocess_net2: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess1 = preprocess_net1
        self.preprocess2 = preprocess_net2
        self.output_dim = 1
        input_dim1 = getattr(preprocess_net1, "output_dim", preprocess_net_output_dim)
        # If the second net does not declare its width, assume it matches the
        # first one (this is what the head was always sized from).
        input_dim2 = getattr(preprocess_net2, "output_dim", input_dim1)
        self.last1 = MLP(
            input_dim1,  # type: ignore
            1,
            hidden_sizes,
            device=self.device,
            linear_layer=linear_layer,
            flatten_input=flatten_input,
        )
        if input_dim2 == input_dim1:
            # Same width: the second head is a deep copy, i.e. separate
            # parameters that start from the same initial values.
            self.last2 = deepcopy(self.last1)
        else:
            # A copied head would not fit preprocess_net2's output, so build
            # one sized for it instead of failing later in forward().
            self.last2 = MLP(
                input_dim2,  # type: ignore
                1,
                hidden_sizes,
                device=self.device,
                linear_layer=linear_layer,
                flatten_input=flatten_input,
            )

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> list:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :param obs: observation batch; converted to a float32 tensor on
            ``self.device`` and flattened from dimension 1 onwards.
        :param act: optional action batch; when given it is converted the same
            way and concatenated to ``obs`` along dimension 1.
        :param info: accepted for API compatibility; not used (and never
            mutated) here.

        :return: a list ``[q1, q2]`` with the outputs of the two critics.
        """
        obs = torch.as_tensor(
            obs,
            device=self.device,  # type: ignore
            dtype=torch.float32,
        ).flatten(1)
        if act is not None:
            act = torch.as_tensor(
                act,
                device=self.device,  # type: ignore
                dtype=torch.float32,
            ).flatten(1)
            obs = torch.cat([obs, act], dim=1)
        # The preprocess nets return (features, state); the state is not needed.
        logits1, _ = self.preprocess1(obs)
        logits1 = self.last1(logits1)
        logits2, _ = self.preprocess2(obs)
        logits2 = self.last2(logits2)
        return [logits1, logits2]

    def predict(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, list]:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: q value (the element-wise minimum of the two critics' outputs),
            and a list of two q values (used for Bellman backup)
        """
        q_list = self(obs, act, info)
        q = torch.min(q_list[0], q_list[1])
        return q, q_list
class SingleCritic(Critic):
    """Simple critic network.

    Creates a critic operating in continuous action space with structure of
    preprocess_net ---> 1(q value). It differs from tianshou's original Critic in
    that the output will be a list to make the API consistent with
    :class:`~fsrl.utils.net.continuous.DoubleCritic`.

    :param preprocess_net: a self-defined preprocess_net which output a flattened
        hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains only a
        single linear layer).
    :param device: forwarded unchanged to tianshou's ``Critic``.
    :param int preprocess_net_output_dim: the output dimension of preprocess_net.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
    :param bool flatten_input: whether to flatten input data for the last layer.
        Default to True.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        preprocess_net_output_dim: Optional[int] = None,
        linear_layer: Type[nn.Linear] = nn.Linear,
        flatten_input: bool = True
    ) -> None:
        # All arguments are passed positionally, in this order, to the parent.
        super().__init__(
            preprocess_net, hidden_sizes, device, preprocess_net_output_dim,
            linear_layer, flatten_input
        )

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> list:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: a one-element list wrapping the parent critic's output.
        """
        logits = super().forward(obs, act, info)
        return [logits]

    def predict(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        act: Optional[Union[np.ndarray, torch.Tensor]] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, list]:
        """Mapping: (s, a) -> logits -> Q(s, a).

        :return: q value, and a one-element list holding that same q value
            (used for Bellman backup)
        """
        q = self(obs, act, info)[0]
        return q, [q]