Commit ff0e572e authored by niuyazhe

optim(nyz): optim ddpg with treetensor

Parent commit: fd46fc1a
......@@ -249,7 +249,7 @@ class QAC(nn.Module):
"""
if self.actor_head_type == 'regression':
x = self.actor(inputs)
return ttorch.as_tensor({'action': x['pred']})
return ttorch.Tensor({'action': x['pred']})
elif self.actor_head_type == 'reparameterization':
x = self.actor(inputs)
return {'logit': [x['mu'], x['sigma']]}
......@@ -306,4 +306,4 @@ class QAC(nn.Module):
x = [m(x)['pred'] for m in self.critic]
else:
x = self.critic(x)['pred']
return ttorch.as_tensor({'q_value': x})
return ttorch.Tensor({'q_value': x})
......@@ -214,7 +214,7 @@ class DDPGPolicy(Policy):
"""
for d in data:
d['replay_unique_id'] = 0 # TODO
data = [ttorch.as_tensor(d) for d in data]
data = [ttorch.Tensor(d) for d in data]
data = ttorch.stack(data)
data.action = data.action.float() # TODO
data.reward = data.reward.squeeze(1)
......@@ -249,7 +249,7 @@ class DDPGPolicy(Policy):
# target q value.
with torch.no_grad():
next_actor_action = self._target_model.forward(next_obs, mode='compute_actor').action
next_actor_data = ttorch.as_tensor({'obs': next_obs, 'action': next_actor_action})
next_actor_data = ttorch.Tensor({'obs': next_obs, 'action': next_actor_action})
target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic').q_value
if self._twin_critic:
# TD3: two critic networks
......@@ -364,7 +364,7 @@ class DDPGPolicy(Policy):
- optional: ``logit``
"""
data_id = list(data.keys())
data = [ttorch.as_tensor(item) for item in data.values()]
data = [ttorch.Tensor(item) for item in data.values()]
data = ttorch.stack(data)
if self._cuda:
data = data.cuda(self._device)
......@@ -429,7 +429,7 @@ class DDPGPolicy(Policy):
- optional: ``logit``
"""
data_id = list(data.keys())
data = [ttorch.as_tensor(item) for item in data.values()]
data = [ttorch.Tensor(item) for item in data.values()]
data = ttorch.stack(data)
if self._cuda:
data = data.cuda(self._device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment