# Source code for fsrl.utils.net.common
from typing import List, Union
import torch.nn as nn
class ActorCritic(nn.Module):
    """Container that bundles an actor network with one or more critic networks.

    Registering the networks as submodules of a single ``nn.Module`` lets their
    parameters be collected together, e.g. via ``self.parameters()``.

    :param nn.Module actor: the actor network.
    :param critics: the critic network(s). A single ``nn.Module`` is stored
        as-is; a list or tuple of modules is wrapped in an ``nn.ModuleList``.
    """

    def __init__(self, actor: nn.Module, critics: Union[List, nn.Module]):
        super().__init__()
        self.actor = actor
        # Modules held in a plain Python list/tuple are NOT registered by
        # nn.Module, so their parameters would be missing from parameters(),
        # state_dict() and .to(); wrapping in nn.ModuleList registers them.
        if isinstance(critics, (list, tuple)):
            critics = nn.ModuleList(critics)
        self.critics = critics