Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDILab开源决策智能平台
DI-engine
提交
7acdb671
D
DI-engine
项目概览
OpenDILab开源决策智能平台
/
DI-engine
上一次同步 2 年多
通知
56
Star
321
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-engine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
7acdb671
编写于
11月 16, 2021
作者:
N
niuyazhe
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish(nyz): add torch1.1.0 compatibility for nn.Flatten
上级
171dddc4
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
46 addition
and
9 deletion
+46
-9
ding/__init__.py
ding/__init__.py
+5
-0
ding/model/common/encoder.py
ding/model/common/encoder.py
+2
-2
ding/torch_utils/network/__init__.py
ding/torch_utils/network/__init__.py
+1
-1
ding/torch_utils/network/nn_module.py
ding/torch_utils/network/nn_module.py
+21
-0
ding/torch_utils/network/resnet.py
ding/torch_utils/network/resnet.py
+2
-1
ding/torch_utils/network/tests/test_nn_module.py
ding/torch_utils/network/tests/test_nn_module.py
+14
-1
ding/utils/data/collate_fn.py
ding/utils/data/collate_fn.py
+1
-4
未找到文件。
ding/__init__.py
浏览文件 @
7acdb671
import
os
import
os
import
torch
__TITLE__
=
'DI-engine'
__TITLE__
=
'DI-engine'
__VERSION__
=
'v0.2.0'
__VERSION__
=
'v0.2.0'
...
@@ -10,3 +11,7 @@ __version__ = __VERSION__
...
@@ -10,3 +11,7 @@ __version__ = __VERSION__
enable_hpc_rl
=
False
enable_hpc_rl
=
False
enable_linklink
=
os
.
environ
.
get
(
'ENABLE_LINKLINK'
,
'false'
).
lower
()
==
'true'
enable_linklink
=
os
.
environ
.
get
(
'ENABLE_LINKLINK'
,
'false'
).
lower
()
==
'true'
enable_numba
=
True
enable_numba
=
True
def
torch_gt_131
():
return
int
(
""
.
join
(
list
(
filter
(
str
.
isdigit
,
torch
.
__version__
))))
>=
131
ding/model/common/encoder.py
浏览文件 @
7acdb671
...
@@ -2,7 +2,7 @@ from typing import Optional
...
@@ -2,7 +2,7 @@ from typing import Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
ding.torch_utils
import
ResFCBlock
,
ResBlock
from
ding.torch_utils
import
ResFCBlock
,
ResBlock
,
Flatten
from
ding.utils
import
SequenceType
from
ding.utils
import
SequenceType
...
@@ -49,7 +49,7 @@ class ConvEncoder(nn.Module):
...
@@ -49,7 +49,7 @@ class ConvEncoder(nn.Module):
assert
len
(
set
(
hidden_size_list
[
3
:
-
1
]))
<=
1
,
"Please indicate the same hidden size for res block parts"
assert
len
(
set
(
hidden_size_list
[
3
:
-
1
]))
<=
1
,
"Please indicate the same hidden size for res block parts"
for
i
in
range
(
3
,
len
(
self
.
hidden_size_list
)
-
1
):
for
i
in
range
(
3
,
len
(
self
.
hidden_size_list
)
-
1
):
layers
.
append
(
ResBlock
(
self
.
hidden_size_list
[
i
],
activation
=
self
.
act
,
norm_type
=
norm_type
))
layers
.
append
(
ResBlock
(
self
.
hidden_size_list
[
i
],
activation
=
self
.
act
,
norm_type
=
norm_type
))
layers
.
append
(
nn
.
Flatten
())
layers
.
append
(
Flatten
())
self
.
main
=
nn
.
Sequential
(
*
layers
)
self
.
main
=
nn
.
Sequential
(
*
layers
)
flatten_size
=
self
.
_get_flatten_size
()
flatten_size
=
self
.
_get_flatten_size
()
...
...
ding/torch_utils/network/__init__.py
浏览文件 @
7acdb671
from
.activation
import
build_activation
,
Swish
from
.activation
import
build_activation
,
Swish
from
.res_block
import
ResBlock
,
ResFCBlock
from
.res_block
import
ResBlock
,
ResFCBlock
from
.nn_module
import
fc_block
,
conv2d_block
,
one_hot
,
deconv2d_block
,
BilinearUpsample
,
NearestUpsample
,
\
from
.nn_module
import
fc_block
,
conv2d_block
,
one_hot
,
deconv2d_block
,
BilinearUpsample
,
NearestUpsample
,
\
binary_encode
,
NoiseLinearLayer
,
noise_block
,
MLP
binary_encode
,
NoiseLinearLayer
,
noise_block
,
MLP
,
Flatten
from
.normalization
import
build_normalization
from
.normalization
import
build_normalization
from
.rnn
import
get_lstm
,
sequence_mask
from
.rnn
import
get_lstm
,
sequence_mask
from
.soft_argmax
import
SoftArgmax
from
.soft_argmax
import
SoftArgmax
...
...
ding/torch_utils/network/nn_module.py
浏览文件 @
7acdb671
...
@@ -4,6 +4,7 @@ import torch.nn as nn
...
@@ -4,6 +4,7 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.init
import
xavier_normal_
,
kaiming_normal_
,
orthogonal_
from
torch.nn.init
import
xavier_normal_
,
kaiming_normal_
,
orthogonal_
from
typing
import
Union
,
Tuple
,
List
,
Callable
from
typing
import
Union
,
Tuple
,
List
,
Callable
from
ding
import
torch_gt_131
from
.normalization
import
build_normalization
from
.normalization
import
build_normalization
...
@@ -577,3 +578,23 @@ def noise_block(
...
@@ -577,3 +578,23 @@ def noise_block(
if
use_dropout
:
if
use_dropout
:
block
.
append
(
nn
.
Dropout
(
dropout_probability
))
block
.
append
(
nn
.
Dropout
(
dropout_probability
))
return
sequential_pack
(
block
)
return
sequential_pack
(
block
)
class
NaiveFlatten
(
nn
.
Module
):
def
__init__
(
self
,
start_dim
:
int
=
1
,
end_dim
:
int
=
-
1
)
->
None
:
super
(
NaiveFlatten
,
self
).
__init__
()
self
.
start_dim
=
start_dim
self
.
end_dim
=
end_dim
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
end_dim
!=
-
1
:
return
x
.
view
(
*
x
.
shape
[:
self
.
start_dim
],
-
1
,
*
x
.
shape
[
self
.
end_dim
+
1
:])
else
:
return
x
.
view
(
*
x
.
shape
[:
self
.
start_dim
],
-
1
)
if
torch_gt_131
():
Flatten
=
nn
.
Flatten
else
:
Flatten
=
NaiveFlatten
ding/torch_utils/network/resnet.py
浏览文件 @
7acdb671
...
@@ -6,6 +6,7 @@ import math
...
@@ -6,6 +6,7 @@ import math
import
numpy
as
np
import
numpy
as
np
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
.nn_module
import
Flatten
def
to_2tuple
(
item
):
def
to_2tuple
(
item
):
...
@@ -94,7 +95,7 @@ class ClassifierHead(nn.Module):
...
@@ -94,7 +95,7 @@ class ClassifierHead(nn.Module):
self
.
drop_rate
=
drop_rate
self
.
drop_rate
=
drop_rate
self
.
global_pool
,
num_pooled_features
=
_create_pool
(
in_chs
,
num_classes
,
pool_type
,
use_conv
=
use_conv
)
self
.
global_pool
,
num_pooled_features
=
_create_pool
(
in_chs
,
num_classes
,
pool_type
,
use_conv
=
use_conv
)
self
.
fc
=
_create_fc
(
num_pooled_features
,
num_classes
,
use_conv
=
use_conv
)
self
.
fc
=
_create_fc
(
num_pooled_features
,
num_classes
,
use_conv
=
use_conv
)
self
.
flatten
=
nn
.
Flatten
(
1
)
if
use_conv
and
pool_type
else
nn
.
Identity
()
self
.
flatten
=
Flatten
(
1
)
if
use_conv
and
pool_type
else
nn
.
Identity
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
global_pool
(
x
)
x
=
self
.
global_pool
(
x
)
...
...
ding/torch_utils/network/tests/test_nn_module.py
浏览文件 @
7acdb671
...
@@ -2,7 +2,7 @@ import torch
...
@@ -2,7 +2,7 @@ import torch
import
pytest
import
pytest
from
ding.torch_utils
import
build_activation
,
build_normalization
from
ding.torch_utils
import
build_activation
,
build_normalization
from
ding.torch_utils.network.nn_module
import
conv1d_block
,
conv2d_block
,
fc_block
,
deconv2d_block
,
ChannelShuffle
,
\
from
ding.torch_utils.network.nn_module
import
conv1d_block
,
conv2d_block
,
fc_block
,
deconv2d_block
,
ChannelShuffle
,
\
one_hot
,
NearestUpsample
,
BilinearUpsample
,
binary_encode
,
weight_init_
one_hot
,
NearestUpsample
,
BilinearUpsample
,
binary_encode
,
weight_init_
,
NaiveFlatten
batch_size
=
2
batch_size
=
2
in_channels
=
2
in_channels
=
2
...
@@ -148,3 +148,16 @@ class TestNnModule:
...
@@ -148,3 +148,16 @@ class TestNnModule:
max_val
=
torch
.
tensor
(
8
)
max_val
=
torch
.
tensor
(
8
)
output
=
binary_encode
(
input
,
max_val
)
output
=
binary_encode
(
input
,
max_val
)
assert
torch
.
equal
(
output
,
torch
.
tensor
([[
0
,
1
,
0
,
0
]]))
assert
torch
.
equal
(
output
,
torch
.
tensor
([[
0
,
1
,
0
,
0
]]))
@
pytest
.
mark
.
tmp
def
test_flatten
(
self
):
inputs
=
torch
.
randn
(
4
,
3
,
8
,
8
)
model1
=
NaiveFlatten
()
output1
=
model1
(
inputs
)
assert
output1
.
shape
==
(
4
,
3
*
8
*
8
)
model2
=
NaiveFlatten
(
1
,
2
)
output2
=
model2
(
inputs
)
assert
output2
.
shape
==
(
4
,
3
*
8
,
8
)
model3
=
NaiveFlatten
(
1
,
3
)
output3
=
model2
(
inputs
)
assert
output1
.
shape
==
(
4
,
3
*
8
*
8
)
ding/utils/data/collate_fn.py
浏览文件 @
7acdb671
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
import
re
import
re
from
torch._six
import
string_classes
from
torch._six
import
string_classes
import
collections.abc
as
container_abcs
import
collections.abc
as
container_abcs
from
ding
import
torch_gt_131
int_classes
=
int
int_classes
=
int
np_str_obj_array_pattern
=
re
.
compile
(
r
'[SaUO]'
)
np_str_obj_array_pattern
=
re
.
compile
(
r
'[SaUO]'
)
...
@@ -15,10 +16,6 @@ default_collate_err_msg_format = (
...
@@ -15,10 +16,6 @@ default_collate_err_msg_format = (
)
)
def
torch_gt_131
():
return
int
(
""
.
join
(
list
(
filter
(
str
.
isdigit
,
torch
.
__version__
))))
>=
131
def
default_collate
(
batch
:
Sequence
,
def
default_collate
(
batch
:
Sequence
,
cat_1dim
:
bool
=
True
,
cat_1dim
:
bool
=
True
,
ignore_prefix
:
list
=
[
'collate_ignore'
])
->
Union
[
torch
.
Tensor
,
Mapping
,
Sequence
]:
ignore_prefix
:
list
=
[
'collate_ignore'
])
->
Union
[
torch
.
Tensor
,
Mapping
,
Sequence
]:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录