Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
fd46fc1a
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 接近 3 年
通知
65
Star
322
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
fd46fc1a
编写于
11月 16, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'dev-treetensor' of
https://github.com/opendilab/DI-engine
into dev-treetensor
上级
86a86b22
6e22e735
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
13 addition
and
13 deletion
+13
-13
ding/model/template/q_learning.py
ding/model/template/q_learning.py
+2
-1
ding/policy/dqn.py
ding/policy/dqn.py
+8
-9
ding/utils/type_helper.py
ding/utils/type_helper.py
+3
-3
未找到文件。
ding/model/template/q_learning.py
浏览文件 @
fd46fc1a
from
typing
import
Union
,
Optional
,
Dict
,
Callable
,
List
from
typing
import
Union
,
Optional
,
Dict
,
Callable
,
List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
treetensor.torch
as
ttorch
import
treetensor.torch
as
ttorch
...
@@ -95,7 +96,7 @@ class DQN(nn.Module):
...
@@ -95,7 +96,7 @@ class DQN(nn.Module):
"""
"""
x
=
self
.
encoder
(
x
)
x
=
self
.
encoder
(
x
)
x
=
self
.
head
(
x
)
x
=
self
.
head
(
x
)
return
ttorch
.
as_t
ensor
(
x
)
return
ttorch
.
T
ensor
(
x
)
@
MODEL_REGISTRY
.
register
(
'c51dqn'
)
@
MODEL_REGISTRY
.
register
(
'c51dqn'
)
...
...
ding/policy/dqn.py
浏览文件 @
fd46fc1a
from
typing
import
List
,
Dict
,
Any
,
Tuple
from
collections
import
namedtuple
import
copy
import
copy
from
collections
import
namedtuple
from
typing
import
List
,
Dict
,
Any
,
Tuple
import
torch
import
torch
import
treetensor.torch
as
ttorch
import
treetensor.torch
as
ttorch
from
ding.torch_utils
import
Adam
,
to_device
from
ding.rl_utils
import
q_nstep_td_data
,
q_nstep_td_error
,
get_nstep_return_data
,
get_train_sample
from
ding.model
import
model_wrap
from
ding.model
import
model_wrap
from
ding.rl_utils
import
q_nstep_td_data
,
q_nstep_td_error
,
get_nstep_return_data
,
get_train_sample
from
ding.torch_utils
import
Adam
from
ding.utils
import
POLICY_REGISTRY
from
ding.utils
import
POLICY_REGISTRY
from
ding.utils.data
import
default_collate
,
default_decollate
from
.base_policy
import
Policy
from
.base_policy
import
Policy
from
.common_utils
import
default_preprocess_learn
@
POLICY_REGISTRY
.
register
(
'dqn'
)
@
POLICY_REGISTRY
.
register
(
'dqn'
)
...
@@ -150,7 +149,7 @@ class DQNPolicy(Policy):
...
@@ -150,7 +149,7 @@ class DQNPolicy(Policy):
"""
"""
for
d
in
data
:
for
d
in
data
:
d
[
'replay_unique_id'
]
=
0
# TODO
d
[
'replay_unique_id'
]
=
0
# TODO
data
=
[
ttorch
.
as_t
ensor
(
d
)
for
d
in
data
]
data
=
[
ttorch
.
T
ensor
(
d
)
for
d
in
data
]
data
=
ttorch
.
stack
(
data
)
data
=
ttorch
.
stack
(
data
)
data
.
action
.
squeeze_
(
1
)
data
.
action
.
squeeze_
(
1
)
if
self
.
_cfg
.
learn
.
ignore_done
:
if
self
.
_cfg
.
learn
.
ignore_done
:
...
@@ -268,7 +267,7 @@ class DQNPolicy(Policy):
...
@@ -268,7 +267,7 @@ class DQNPolicy(Policy):
- necessary: ``logit``, ``action``
- necessary: ``logit``, ``action``
"""
"""
data_id
=
list
(
data
.
keys
())
data_id
=
list
(
data
.
keys
())
data
=
[
ttorch
.
as_t
ensor
(
item
)
for
item
in
data
.
values
()]
data
=
[
ttorch
.
T
ensor
(
item
)
for
item
in
data
.
values
()]
data
=
ttorch
.
stack
(
data
)
data
=
ttorch
.
stack
(
data
)
if
self
.
_cuda
:
if
self
.
_cuda
:
data
=
data
.
cuda
(
self
.
_device
)
data
=
data
.
cuda
(
self
.
_device
)
...
@@ -346,7 +345,7 @@ class DQNPolicy(Policy):
...
@@ -346,7 +345,7 @@ class DQNPolicy(Policy):
- necessary: ``action``
- necessary: ``action``
"""
"""
data_id
=
list
(
data
.
keys
())
data_id
=
list
(
data
.
keys
())
data
=
[
ttorch
.
as_t
ensor
(
item
)
for
item
in
data
.
values
()]
data
=
[
ttorch
.
T
ensor
(
item
)
for
item
in
data
.
values
()]
data
=
ttorch
.
stack
(
data
)
data
=
ttorch
.
stack
(
data
)
if
self
.
_cuda
:
if
self
.
_cuda
:
data
=
data
.
cuda
(
self
.
_device
)
data
=
data
.
cuda
(
self
.
_device
)
...
...
ding/utils/type_helper.py
浏览文件 @
fd46fc1a
import
typing
import
treetensor
from
collections
import
namedtuple
from
collections
import
namedtuple
from
typing
import
List
,
Dict
,
Tuple
,
TypeVar
,
Type
from
typing
import
List
,
Dict
,
Tuple
,
TypeVar
import
treetensor
SequenceType
=
TypeVar
(
'SequenceType'
,
List
,
Tuple
,
namedtuple
)
SequenceType
=
TypeVar
(
'SequenceType'
,
List
,
Tuple
,
namedtuple
)
NestedType
=
TypeVar
(
'NestedType'
,
Dict
,
treetensor
.
torch
.
Tensor
)
NestedType
=
TypeVar
(
'NestedType'
,
Dict
,
treetensor
.
torch
.
Tensor
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录