diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py
index 07fa112e0c2f5c92b78ee60e878e12555e09bba7..0bf51c25908f8ffece79c7b53c9424947ded360b 100644
--- a/ding/model/template/q_learning.py
+++ b/ding/model/template/q_learning.py
@@ -1,4 +1,5 @@
 from typing import Union, Optional, Dict, Callable, List
+
 import torch
 import torch.nn as nn
 import treetensor.torch as ttorch
@@ -95,7 +96,7 @@ class DQN(nn.Module):
         """
         x = self.encoder(x)
         x = self.head(x)
-        return ttorch.as_tensor(x)
+        return ttorch.Tensor(x)
 
 
 @MODEL_REGISTRY.register('c51dqn')
diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py
index c3e860f0805ce55d9683e52dca6d5d8cf9869bcc..2dc4618b3d6bb1f1ced0bf2af09b99c5467e79b0 100644
--- a/ding/policy/dqn.py
+++ b/ding/policy/dqn.py
@@ -1,16 +1,15 @@
-from typing import List, Dict, Any, Tuple
-from collections import namedtuple
 import copy
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple
+
 import torch
 import treetensor.torch as ttorch
 
-from ding.torch_utils import Adam, to_device
-from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
 from ding.model import model_wrap
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
+from ding.torch_utils import Adam
 from ding.utils import POLICY_REGISTRY
-from ding.utils.data import default_collate, default_decollate
 from .base_policy import Policy
-from .common_utils import default_preprocess_learn
 
 
 @POLICY_REGISTRY.register('dqn')
@@ -150,7 +149,7 @@ class DQNPolicy(Policy):
         """
         for d in data:
             d['replay_unique_id'] = 0  # TODO
-        data = [ttorch.as_tensor(d) for d in data]
+        data = [ttorch.Tensor(d) for d in data]
         data = ttorch.stack(data)
         data.action.squeeze_(1)
         if self._cfg.learn.ignore_done:
@@ -268,7 +267,7 @@ class DQNPolicy(Policy):
             - necessary: ``logit``, ``action``
         """
         data_id = list(data.keys())
-        data = [ttorch.as_tensor(item) for item in data.values()]
+        data = [ttorch.Tensor(item) for item in data.values()]
         data = ttorch.stack(data)
         if self._cuda:
             data = data.cuda(self._device)
@@ -346,7 +345,7 @@ class DQNPolicy(Policy):
             - necessary: ``action``
         """
         data_id = list(data.keys())
-        data = [ttorch.as_tensor(item) for item in data.values()]
+        data = [ttorch.Tensor(item) for item in data.values()]
         data = ttorch.stack(data)
         if self._cuda:
             data = data.cuda(self._device)
diff --git a/ding/utils/type_helper.py b/ding/utils/type_helper.py
index b959f35201b7df0053a0438132753d3156e9be8f..ae2d9800a0fd3963c91d73e97ab1372b49439ce6 100644
--- a/ding/utils/type_helper.py
+++ b/ding/utils/type_helper.py
@@ -1,7 +1,7 @@
-import typing
-import treetensor
 from collections import namedtuple
-from typing import List, Dict, Tuple, TypeVar, Type
+from typing import List, Dict, Tuple, TypeVar
+
+import treetensor
 
 SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)
 NestedType = TypeVar('NestedType', Dict, treetensor.torch.Tensor)