Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3f3a256e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3f3a256e
编写于
7月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/functional): fix conv* dtype promotion
GitOrigin-RevId: 3f03790cfc2ecf2f2c05e1ea5a68be0bc0e84bb2
上级
536506c3
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
75 addition
and
8 deletion
+75
-8
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+41
-5
imperative/python/megengine/module/conv.py
imperative/python/megengine/module/conv.py
+3
-3
imperative/python/test/unit/module/test_conv.py
imperative/python/test/unit/module/test_conv.py
+31
-0
未找到文件。
imperative/python/megengine/functional/nn.py
浏览文件 @
3f3a256e
...
...
@@ -9,7 +9,7 @@
# pylint: disable=too-many-lines
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
..core._imperative_rt.core2
import
apply
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core.ops
import
builtin
from
..core.ops.builtin
import
BatchNorm
,
Elemwise
from
..core.ops.special
import
Const
...
...
@@ -157,6 +157,12 @@ def conv1d(
if
amp
.
_enabled
:
compute_mode
=
"float32"
inp
,
weight
,
bias
=
cast_tensors
(
inp
,
weight
,
bias
)
else
:
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
inp
=
expand_dims
(
inp
,
3
)
weight
=
expand_dims
(
weight
,
3
)
...
...
@@ -234,6 +240,12 @@ def conv2d(
if
amp
.
_enabled
:
compute_mode
=
"float32"
inp
,
weight
,
bias
=
cast_tensors
(
inp
,
weight
,
bias
)
else
:
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
...
...
@@ -297,6 +309,12 @@ def conv3d(
stride
=
_triple_nonzero
(
stride
)
dilate
=
_triple_nonzero
(
dilation
)
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
op
=
builtin
.
Convolution3D
(
pad_d
=
pad
[
D
],
...
...
@@ -364,6 +382,12 @@ def conv_transpose2d(
if
amp
.
_enabled
:
compute_mode
=
"float32"
inp
,
weight
,
bias
=
cast_tensors
(
inp
,
weight
,
bias
)
else
:
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
if
groups
!=
1
:
raise
NotImplementedError
(
"group transposed conv2d is not supported yet."
)
...
...
@@ -482,6 +506,12 @@ def local_conv2d(
pad_h
,
pad_w
=
expand_hw
(
padding
)
dilate_h
,
dilate_w
=
expand_hw
(
dilation
)
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
op
=
builtin
.
GroupLocal
(
stride_h
=
stride_h
,
stride_w
=
stride_w
,
...
...
@@ -527,6 +557,12 @@ def conv_transpose3d(
stride
=
_triple_nonzero
(
stride
)
dilate
=
_triple_nonzero
(
dilation
)
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
op
=
builtin
.
Convolution3DBackwardData
(
pad_d
=
pad
[
D
],
pad_h
=
pad
[
H
],
...
...
imperative/python/megengine/module/conv.py
浏览文件 @
3f3a256e
...
...
@@ -939,7 +939,7 @@ class ConvTranspose3d(_ConvNd):
ichl
=
self
.
in_channels
ochl
=
self
.
out_channels
kt
,
kh
,
kw
=
self
.
kernel_size
return
(
ochl
,
i
chl
,
kt
,
kh
,
kw
)
return
(
ichl
,
o
chl
,
kt
,
kh
,
kw
)
def
_infer_bias_shape
(
self
):
# Assume format is NCTHW
...
...
imperative/python/test/unit/module/test_conv.py
浏览文件 @
3f3a256e
...
...
@@ -9,11 +9,41 @@
import
itertools
import
numpy
as
np
import
pytest
import
megengine.module
as
M
from
megengine
import
Parameter
,
tensor
from
megengine.functional.debug_param
import
(
get_execution_strategy
,
set_execution_strategy
,
)
from
megengine.module
import
ConvTranspose2d
,
ConvTranspose3d
,
LocalConv2d
@
pytest
.
fixture
def
reproducible
():
old
=
get_execution_strategy
()
set_execution_strategy
(
"HEURISTIC_REPRODUCIBLE"
)
yield
set_execution_strategy
(
old
)
# NOTE: test in module for convenience. should really test in functional
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"Conv1d"
,
"Conv2d"
,
"Conv3d"
,
"ConvTranspose2d"
,
"ConvTranspose3d"
,
"LocalConv2d"
],
)
def
test_conv_dtype_promotion
(
name
,
reproducible
):
N
,
Ci
,
Co
,
K
=
2
,
16
,
32
,
3
S
=
(
7
,)
*
int
(
name
[
-
2
])
if
"Local"
in
name
:
m
=
getattr
(
M
,
name
)(
Ci
,
Co
,
*
S
,
K
)
else
:
m
=
getattr
(
M
,
name
)(
Ci
,
Co
,
K
)
x
=
tensor
(
np
.
random
.
random
(
size
=
(
N
,
Ci
)
+
S
).
astype
(
"float16"
))
np
.
testing
.
assert_equal
(
m
(
x
).
numpy
(),
m
(
x
.
astype
(
"float32"
)).
numpy
())
def
test_conv_transpose2d
():
SH
,
SW
=
3
,
1
PH
,
PW
=
2
,
0
...
...
@@ -163,6 +193,7 @@ def test_conv_transpose3d():
)
out_np
=
out_np
[:,
:,
PD
:
OD
-
PD
,
PH
:
OH
-
PH
,
PW
:
OW
-
PW
]
assert
conv_transpose3d
.
weight
.
numpy
().
shape
==
weight
.
shape
conv_transpose3d
.
weight
=
Parameter
(
weight
)
out_meg
=
conv_transpose3d
.
forward
(
tensor
(
inp
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录