提交 097ee4db 编写于 作者: N niuyazhe

hotfix(nyz): fix qacd model style

上级 6e8a746d
......@@ -2,11 +2,11 @@ from typing import Union, Dict, Optional
import torch
import torch.nn as nn
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
FCEncoder, ConvEncoder
@MODEL_REGISTRY.register('qacd')
class QACD(nn.Module):
r"""
......@@ -62,27 +62,22 @@ class QACD(nn.Module):
)
self.actor_encoder = encoder_cls(
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
)
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
)
self.critic_encoder = encoder_cls(
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
)
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
)
self.critic_head = RegressionHead(
critic_head_hidden_size,action_shape,critic_head_layer_num,activation=activation,norm_type=norm_type
critic_head_hidden_size, action_shape, critic_head_layer_num, activation=activation, norm_type=norm_type
)
self.actor_head = DiscreteHead(
actor_head_hidden_size,
action_shape,
actor_head_layer_num,
activation=activation,
norm_type=norm_type
actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
)
self.actor = [self.actor_encoder, self.actor_head]
self.critic = [self.critic_encoder, self.critic_head]
self.actor = nn.ModuleList(self.actor)
self.critic = nn.ModuleList(self.critic)
def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
r"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册