提交 fd46fc1a 编写于 作者: N niuyazhe

Merge branch 'dev-treetensor' of https://github.com/opendilab/DI-engine into dev-treetensor

from typing import Union, Optional, Dict, Callable, List
import torch
import torch.nn as nn
import treetensor.torch as ttorch
......@@ -95,7 +96,7 @@ class DQN(nn.Module):
"""
x = self.encoder(x)
x = self.head(x)
return ttorch.as_tensor(x)
return ttorch.Tensor(x)
@MODEL_REGISTRY.register('c51dqn')
......
from typing import List, Dict, Any, Tuple
from collections import namedtuple
import copy
from collections import namedtuple
from typing import List, Dict, Any, Tuple
import torch
import treetensor.torch as ttorch
from ding.torch_utils import Adam, to_device
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.model import model_wrap
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
from ding.torch_utils import Adam
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('dqn')
......@@ -150,7 +149,7 @@ class DQNPolicy(Policy):
"""
for d in data:
d['replay_unique_id'] = 0 # TODO
data = [ttorch.as_tensor(d) for d in data]
data = [ttorch.Tensor(d) for d in data]
data = ttorch.stack(data)
data.action.squeeze_(1)
if self._cfg.learn.ignore_done:
......@@ -268,7 +267,7 @@ class DQNPolicy(Policy):
- necessary: ``logit``, ``action``
"""
data_id = list(data.keys())
data = [ttorch.as_tensor(item) for item in data.values()]
data = [ttorch.Tensor(item) for item in data.values()]
data = ttorch.stack(data)
if self._cuda:
data = data.cuda(self._device)
......@@ -346,7 +345,7 @@ class DQNPolicy(Policy):
- necessary: ``action``
"""
data_id = list(data.keys())
data = [ttorch.as_tensor(item) for item in data.values()]
data = [ttorch.Tensor(item) for item in data.values()]
data = ttorch.stack(data)
if self._cuda:
data = data.cuda(self._device)
......
import typing
import treetensor
from collections import namedtuple
from typing import List, Dict, Tuple, TypeVar, Type
from typing import List, Dict, Tuple, TypeVar
import treetensor
SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)
NestedType = TypeVar('NestedType', Dict, treetensor.torch.Tensor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册