Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
673b295d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
673b295d
编写于
3月 11, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/amp): remove conv_format and bn param_dim configs
GitOrigin-RevId: 848d34f63da1262d5c37fa0f7f30c13af454a52e
上级
7e9aa742
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
157 addition
and
186 deletion
+157
-186
imperative/python/megengine/amp/autocast.py
imperative/python/megengine/amp/autocast.py
+0
-2
imperative/python/megengine/core/_config.py
imperative/python/megengine/core/_config.py
+1
-47
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+4
-23
imperative/python/megengine/functional/quantized.py
imperative/python/megengine/functional/quantized.py
+0
-4
imperative/python/megengine/functional/vision.py
imperative/python/megengine/functional/vision.py
+14
-15
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+1
-1
imperative/python/test/unit/core/test_formatted_tensor.py
imperative/python/test/unit/core/test_formatted_tensor.py
+67
-64
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+0
-15
imperative/src/impl/transformations/format.cpp
imperative/src/impl/transformations/format.cpp
+70
-15
未找到文件。
imperative/python/megengine/amp/autocast.py
浏览文件 @
673b295d
...
...
@@ -75,8 +75,6 @@ class autocast:
amp
.
_set_amp_high_prec_dtype
(
self
.
_origin_high
)
amp
.
_set_amp_low_prec_dtype
(
self
.
_origin_low
)
_config
.
_reset_execution_config
(
*
self
.
_origin_configs
)
def
__call__
(
self
,
func
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
...
...
imperative/python/megengine/core/_config.py
浏览文件 @
673b295d
...
...
@@ -12,8 +12,6 @@ from ._imperative_rt.core2 import (
# use "default" to distinguish it from None in _reset_execution_config
__compute_mode
=
"default"
__conv_format
=
"default"
__bn_format
=
"default"
_benchmark_kernel
=
False
_deterministic_kernel
=
False
...
...
@@ -23,8 +21,6 @@ __all__ = [
"async_level"
,
"disable_memory_forwarding"
,
"_compute_mode"
,
"_conv_format"
,
"_bn_format"
,
"_auto_format_convert"
,
"_override"
,
]
...
...
@@ -138,35 +134,6 @@ def _compute_mode(mod, _compute_mode: str):
__compute_mode
=
_compute_mode
@
property
def
_conv_format
(
mod
):
r
"""Get or set convolution data/filter/output layout format. The default option is None,
which means that no special format will be placed on. There are all layout definitions
``NCHW`` layout: ``{N, C, H, W}``
``NHWC`` layout: ``{N, H, W, C}``
``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
``NHWCD4I`` layout: with ``align_axis = 2``
``NCHW4`` layout: ``{N, C/4, H, W, 4}``
``NCHW88`` layout: ``{N, C/8, H, W, 8}``
``CHWN4`` layout: ``{C/4, H, W, N, 4}``
``NCHW64`` layout: ``{N, C/64, H, W, 64}``
Examples:
.. code-block::
import megengine as mge
mge.config._conv_format = "NHWC"
"""
return
__conv_format
@
_conv_format
.
setter
def
_conv_format
(
mod
,
format
:
str
):
global
__conv_format
__conv_format
=
format
@
property
def
_bn_format
(
mod
):
...
...
@@ -215,18 +182,15 @@ def _reset_execution_config(
deterministic_kernel
=
None
,
async_level
=
None
,
compute_mode
=
None
,
conv_format
=
None
,
bn_format
=
None
,
auto_format_convert
=
None
,
):
global
_benchmark_kernel
,
_deterministic_kernel
,
__compute_mode
,
__conv_format
,
__bn_format
global
_benchmark_kernel
,
_deterministic_kernel
,
__compute_mode
orig_flags
=
(
_benchmark_kernel
,
_deterministic_kernel
,
get_option
(
"async_level"
),
__compute_mode
,
__conv_format
,
__bn_format
,
get_auto_format_convert
(),
)
if
benchmark_kernel
is
not
None
:
...
...
@@ -237,10 +201,6 @@ def _reset_execution_config(
set_option
(
"async_level"
,
async_level
)
if
compute_mode
is
not
None
:
__compute_mode
=
compute_mode
if
conv_format
is
not
None
:
__conv_format
=
conv_format
if
bn_format
is
not
None
:
__bn_format
=
bn_format
if
auto_format_convert
is
not
None
:
set_auto_format_convert
(
auto_format_convert
)
...
...
@@ -253,8 +213,6 @@ def _override(
deterministic_kernel
=
None
,
async_level
=
None
,
compute_mode
=
None
,
conv_format
=
None
,
bn_format
=
None
,
auto_format_convert
=
None
,
):
r
"""A context manager that users can opt in by attaching the decorator to set
...
...
@@ -271,8 +229,6 @@ def _override(
deterministic_kernel = Fasle,
async_level=2,
compute_mode="float32",
conv_format="NHWC",
bn_format="dim_111c",
auto_format_convert=True,
)
def train():
...
...
@@ -282,8 +238,6 @@ def _override(
deterministic_kernel
,
async_level
,
compute_mode
,
conv_format
,
bn_format
,
auto_format_convert
,
)
try
:
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
673b295d
...
...
@@ -178,7 +178,6 @@ def conv1d(
dilate_h
=
dilation
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
op
=
builtin
.
Convolution
(
stride_h
=
stride_h
,
...
...
@@ -191,7 +190,6 @@ def conv1d(
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
sparse
=
sparse_type
,
format
=
conv_format
,
)
(
output
,)
=
apply
(
op
,
inp
,
weight
)
if
bias
is
not
None
:
...
...
@@ -247,7 +245,6 @@ def conv2d(
sparse_type
=
"dense"
if
groups
==
1
else
"group"
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
op
=
builtin
.
Convolution
(
stride_h
=
stride_h
,
stride_w
=
stride_w
,
...
...
@@ -259,7 +256,6 @@ def conv2d(
mode
=
conv_mode
,
compute_mode
=
compute_mode
,
sparse
=
sparse_type
,
format
=
conv_format
,
)
(
output
,)
=
apply
(
op
,
inp
,
weight
)
if
bias
is
not
None
:
...
...
@@ -603,7 +599,6 @@ def max_pool2d(
window_h
,
window_w
=
expand_hw
(
kernel_size
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
padding_h
,
padding_w
=
expand_hw
(
padding
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
op
=
builtin
.
Pooling
(
window_h
=
window_h
,
...
...
@@ -614,7 +609,6 @@ def max_pool2d(
pad_w
=
padding_w
,
mode
=
"max"
,
strategy
=
get_execution_strategy
(),
format
=
conv_format
,
)
(
output
,)
=
apply
(
op
,
inp
)
return
output
...
...
@@ -648,7 +642,6 @@ def avg_pool2d(
window_h
,
window_w
=
expand_hw
(
kernel_size
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
padding_h
,
padding_w
=
expand_hw
(
padding
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
op
=
builtin
.
Pooling
(
window_h
=
window_h
,
...
...
@@ -659,7 +652,6 @@ def avg_pool2d(
pad_w
=
padding_w
,
mode
=
mode
,
strategy
=
get_execution_strategy
(),
format
=
conv_format
,
)
(
output
,)
=
apply
(
op
,
inp
)
return
output
...
...
@@ -1181,7 +1173,6 @@ def batch_norm(
momentum
:
float
=
0.9
,
eps
:
float
=
1e-5
,
inplace
:
bool
=
True
,
param_dim
=
"dim_1c11"
):
r
"""Applies batch normalization to the input.
...
...
@@ -1210,14 +1201,8 @@ def batch_norm(
if
x_ndim
is
not
None
and
x_ndim
!=
1
:
return
x
if
param_dim
==
"dim_1c11"
:
C
=
inp
.
shape
[
1
]
pshape
=
(
1
,
C
,
1
,
1
)
elif
param_dim
==
"dim_111c"
:
C
=
inp
.
shape
[
3
]
pshape
=
(
1
,
1
,
1
,
C
)
else
:
raise
ValueError
(
"Invalid param_dim {}"
.
format
(
param_dim
))
if
x
is
None
:
x
=
Const
(
value
,
inp
.
dtype
,
inp
.
device
)
...
...
@@ -1241,16 +1226,12 @@ def batch_norm(
bias
=
make_full_if_none
(
bias
,
0
)
if
not
training
:
op
=
builtin
.
BatchNorm
(
fwd_mode
=
BatchNorm
.
FwdMode
.
INFERENCE
,
epsilon
=
eps
,
param_dim
=
param_dim
)
op
=
builtin
.
BatchNorm
(
fwd_mode
=
BatchNorm
.
FwdMode
.
INFERENCE
,
epsilon
=
eps
)
ret
=
apply
(
op
,
inp
,
weight
,
bias
,
running_mean
,
running_var
)[
-
1
]
return
ret
else
:
op
=
builtin
.
BatchNorm
(
avg_factor
=
1
-
momentum
,
epsilon
=
eps
,
param_dim
=
param_dim
)
op
=
builtin
.
BatchNorm
(
avg_factor
=
1
-
momentum
,
epsilon
=
eps
)
if
has_mean
or
has_var
:
running_mean
=
make_full_if_none
(
running_mean
,
0
)
running_var
=
make_full_if_none
(
running_var
,
1
)
...
...
imperative/python/megengine/functional/quantized.py
浏览文件 @
673b295d
...
...
@@ -50,7 +50,6 @@ def conv_bias_activation(
dh
,
dw
=
_pair_nonzero
(
dilation
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
op
=
builtin
.
ConvBias
(
stride_h
=
sh
,
stride_w
=
sw
,
...
...
@@ -59,7 +58,6 @@ def conv_bias_activation(
dilate_h
=
dh
,
dilate_w
=
dw
,
dtype
=
dtype
,
format
=
conv_format
,
strategy
=
get_execution_strategy
(),
nonlineMode
=
nonlinear_mode
,
mode
=
conv_mode
,
...
...
@@ -111,7 +109,6 @@ def batch_conv_bias_activation(
dh
,
dw
=
_pair_nonzero
(
dilation
)
sparse_type
=
"dense"
if
groups
==
1
else
"group"
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
op
=
builtin
.
BatchConvBias
(
stride_h
=
sh
,
stride_w
=
sw
,
...
...
@@ -120,7 +117,6 @@ def batch_conv_bias_activation(
dilate_h
=
dh
,
dilate_w
=
dw
,
dtype
=
dtype
,
format
=
conv_format
,
strategy
=
get_execution_strategy
(),
nonlineMode
=
nonlinear_mode
,
mode
=
conv_mode
,
...
...
imperative/python/megengine/functional/vision.py
浏览文件 @
673b295d
...
...
@@ -146,11 +146,11 @@ def correlation(
pad_size: int (non-negative), optional, default=0) – pad for Correlation
is_multiply: boolean, optional, default=True) – operation type is either multiplication or absolute difference
"""
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
assert
conv_format
==
"NCHW"
,
"Currently correlation only support NCHW mode
"
# Currently correlation only support NCHW mode
format
=
"NCHW
"
op
=
builtin
.
Correlation
(
format
=
conv_
format
,
format
=
format
,
kernel_size
=
kernel_size
,
max_displacement
=
max_displacement
,
stride1
=
stride1
,
...
...
@@ -209,12 +209,13 @@ def roi_align(
sample_points
=
(
sample_points
,
sample_points
)
sample_height
,
sample_width
=
sample_points
offset
=
0.5
if
aligned
else
0.0
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
assert
conv_format
==
"NCHW"
,
"Currently roi_align only support NCHW mode"
# Currently roi_align only support NCHW mode
format
=
"NCHW"
op
=
builtin
.
ROIAlign
(
mode
=
mode
,
format
=
conv_
format
,
format
=
format
,
spatial_scale
=
spatial_scale
,
offset
=
offset
,
pooled_height
=
pooled_height
,
...
...
@@ -321,10 +322,10 @@ def remap(
array([[[[1., 4.],
[4., 4.]]]], dtype=float32)
"""
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
format
=
"NCHW"
op
=
builtin
.
Remap
(
imode
=
interp_mode
,
border_type
=
border_mode
,
format
=
conv_
format
,
scalar
=
scalar
imode
=
interp_mode
,
border_type
=
border_mode
,
format
=
format
,
scalar
=
scalar
)
assert
isinstance
(
inp
,
(
Tensor
,
megbrain_graph
.
VarNode
)),
"inp must be Tensor type"
(
result
,)
=
apply
(
op
,
inp
,
map_xy
)
...
...
@@ -364,12 +365,10 @@ def warp_affine(
On different platforms, different combinations are supported.
``warp_affine`` only support forward inference, Please refer to ``warp_perspective`` if backward is needed.
"""
conv_format
=
_config
.
_get_actual_op_param
(
format
,
_config
.
__conv_format
)
op
=
builtin
.
WarpAffine
(
border_mode
=
border_mode
,
border_val
=
border_val
,
format
=
conv_
format
,
format
=
format
,
imode
=
interp_mode
,
)
out_shape
=
utils
.
astensor1d
(
out_shape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
...
...
@@ -437,9 +436,8 @@ def warp_perspective(
mat
=
mat
.
astype
(
"float32"
)
if
inp
.
dtype
==
np
.
float16
:
inp
=
inp
.
astype
(
"float32"
)
conv_format
=
_config
.
_get_actual_op_param
(
format
,
_config
.
__conv_format
)
op
=
builtin
.
WarpPerspective
(
imode
=
interp_mode
,
bmode
=
border_mode
,
format
=
conv_
format
,
border_val
=
border_val
imode
=
interp_mode
,
bmode
=
border_mode
,
format
=
format
,
border_val
=
border_val
)
out_shape
=
astensor1d
(
out_shape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
if
mat_idx
is
not
None
:
...
...
@@ -563,8 +561,9 @@ def interpolate(
}
if
inp
.
dtype
==
np
.
float16
:
inp
=
inp
.
astype
(
"float32"
)
conv_format
=
_config
.
_get_actual_op_param
(
"NCHW"
,
_config
.
__conv_format
)
op
=
builtin
.
Resize
(
imode
=
mode_map
[
mode
],
format
=
conv_format
)
# Currently resize only support NCHW mode
format
=
"NCHW"
op
=
builtin
.
Resize
(
imode
=
mode_map
[
mode
],
format
=
format
)
shape
=
astensor1d
(
dsize
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
ret
,)
=
apply
(
op
,
inp
,
shape
)
else
:
...
...
imperative/python/src/transformation.h
浏览文件 @
673b295d
...
...
@@ -18,8 +18,8 @@ public:
ModuleTrace
,
DTypePromote
,
DimExpansion
,
Grad
,
Format
,
Grad
,
Scalar
,
Symbol
,
Trace
,
...
...
imperative/python/test/unit/core/test_formatted_tensor.py
浏览文件 @
673b295d
...
...
@@ -32,13 +32,13 @@ def test_basic():
def
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
=
None
):
x1
=
tensor
(
data
,
format
=
"nchw"
)
x1
=
tensor
(
data
)
x2
=
tensor
(
data
.
transpose
(
0
,
2
,
3
,
1
),
format
=
"nhwc"
)
if
is_symbolic
is
not
None
:
func
=
trace
(
func
,
symbolic
=
is_symbolic
)
out1
=
func
(
x1
)
#
out1 = func(x1)
out2
=
func
(
x2
)
np
.
testing
.
assert_almost_equal
(
out1
,
out2
,
decimal
=
5
)
#
np.testing.assert_almost_equal(out1, out2, decimal=5)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
...
...
@@ -57,8 +57,7 @@ def test_reshape(is_symbolic):
# maintain NHWC format
def
func
(
x
):
out
=
F
.
reshape
(
x
,
(
1
,
2
,
6
,
2
))
if
x
.
format
==
"nhwc"
:
assert
out
.
format
==
"nhwc"
assert
out
.
format
==
x
.
format
return
out
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
1
,
2
,
3
,
4
))
...
...
@@ -87,8 +86,7 @@ def test_broadcast(is_symbolic):
# maintain NHWC format
def
func
(
x
):
out
=
F
.
broadcast_to
(
x
,
(
4
,
3
,
2
,
3
))
if
x
.
format
==
"nhwc"
:
assert
out
.
format
==
"nhwc"
assert
out
.
format
==
x
.
format
return
out
.
numpy
()
data
=
np
.
arange
(
0
,
24
).
reshape
((
4
,
3
,
2
,
1
))
...
...
@@ -213,24 +211,32 @@ def test_concat(is_symbolic):
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
def
test_interpolate
(
mode
,
is_symbolic
):
def
func
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
rst
=
F
.
vision
.
interpolate
(
x
,
scale_factor
=
3
,
mode
=
mode
)
assert
rst
.
format
==
"nhwc"
assert
rst
.
format
==
x
.
format
return
rst
.
numpy
()
else
:
return
F
.
vision
.
interpolate
(
x
,
scale_factor
=
3
,
mode
=
mode
).
numpy
()
# NHWC interpolate only suppoted channel is 1 or 3
data
=
np
.
arange
(
0
,
48
).
reshape
((
1
,
3
,
4
,
4
)).
astype
(
"float32"
)
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
)
@
pytest
.
mark
.
skip
(
"not implemented"
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
def
test_warp_perspective
(
is_symbolic
):
def
func
(
x
):
m_shape
=
(
1
,
3
,
3
)
m
=
tensor
(
np
.
random
.
randn
(
3
,
3
),
dtype
=
np
.
float32
).
reshape
(
m_shape
)
rst
=
F
.
vision
.
warp_perspective
(
x
,
m
,
(
2
,
2
),
format
=
"NHWC"
)
return
rst
.
numpy
()
data
=
np
.
arange
(
0
,
48
).
reshape
((
1
,
3
,
4
,
4
)).
astype
(
"float32"
)
_compare_nchw_nhwc
(
data
,
func
,
is_symbolic
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
])
def
test_conv2d
(
is_symbolic
):
def
conv2d
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
x
=
F
.
conv2d
(
x
,
weight
=
mge
.
tensor
(
np
.
ones
((
3
,
1
,
1
,
2
)),
format
=
"nhwc"
),
...
...
@@ -249,7 +255,6 @@ def test_conv2d(is_symbolic):
def
test_group_conv2d
(
is_symbolic
):
def
conv2d
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
x
=
F
.
conv2d
(
x
,
weight
=
mge
.
tensor
(
np
.
ones
((
2
,
2
,
1
,
1
,
2
)),
format
=
"nhwc"
),
...
...
@@ -271,7 +276,6 @@ def test_group_conv2d(is_symbolic):
def
test_bn
(
is_symbolic
):
def
func
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
bn_format
=
"dim_111c"
):
oups
=
F
.
batch_norm
(
x
.
astype
(
"float32"
),
running_mean
=
mge
.
tensor
(
np
.
ones
((
1
,
1
,
1
,
2
)),
format
=
"nhwc"
),
...
...
@@ -308,7 +312,6 @@ def test_bn(is_symbolic):
def
test_pooling2d
(
pooling
,
is_symbolic
):
def
func
(
x
):
if
x
.
format
==
"nhwc"
:
with
mge
.
config
.
_override
(
conv_format
=
"NHWC"
):
x
=
pooling
(
x
.
astype
(
"float32"
),
2
)
assert
x
.
format
==
"nhwc"
return
x
.
numpy
()
...
...
@@ -331,17 +334,17 @@ def test_backward(is_symbolic):
return
F
.
conv2d
(
x
,
w
,
b
)
with
gm
:
with
mge
.
config
.
_override
(
auto_format_convert
=
True
,
conv_format
=
"NHWC"
):
if
is_symbolic
is
not
None
:
func
=
trace
(
func
,
symbolic
=
is_symbolic
)
x
=
func
(
x
,
w
,
b
)
# TODO: fix manually convert to NHWC, usually used in detection head
# x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
assert
x
.
format
==
"nhwc"
# test manually convert to NHWC, usually used in detection head
x
=
x
.
transpose
(
0
,
2
,
3
,
1
).
reshape
(
1
,
18
,
2
)
gm
.
backward
(
x
)
print
(
"finish backward"
,
x
.
format
)
# backward grad has no format
np
.
testing
.
assert_equal
(
w
.
grad
.
numpy
(),
np
.
array
([
66
,
210
,
66
,
210
,
66
,
210
]).
reshape
((
3
,
1
,
1
,
2
)),
w
.
grad
.
numpy
(),
np
.
array
([
66
,
210
,
66
,
210
,
66
,
210
]).
reshape
((
3
,
1
,
1
,
2
)),
)
np
.
testing
.
assert_equal
(
b
.
grad
.
numpy
(),
np
.
array
([
12
,
12
,
12
]).
reshape
((
1
,
1
,
1
,
3
))
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
673b295d
...
...
@@ -1280,21 +1280,6 @@ def test_set_conv2d_config():
np
.
testing
.
assert_allclose
(
context_out
.
numpy
(),
expected
.
numpy
())
def
test_set_warp_perspective_config
():
config
.
_conv_format
=
"NHWC"
inp_shape
=
(
1
,
1
,
4
,
4
)
inp
=
Tensor
(
np
.
arange
(
16
,
dtype
=
np
.
float32
).
reshape
(
inp_shape
))
M_shape
=
(
1
,
3
,
3
)
M
=
Tensor
(
np
.
random
.
randn
(
3
,
3
),
dtype
=
np
.
float32
).
reshape
(
M_shape
)
config_out
=
F
.
vision
.
warp_perspective
(
inp
,
M
,
(
2
,
2
))
config
.
_conv_format
=
"default"
with
config
.
_override
(
conv_format
=
"NHWC"
):
context_out
=
F
.
vision
.
warp_perspective
(
inp
,
M
,
(
2
,
2
))
expected
=
F
.
vision
.
warp_perspective
(
inp
,
M
,
(
2
,
2
),
format
=
"NHWC"
)
np
.
testing
.
assert_allclose
(
config_out
.
numpy
(),
expected
.
numpy
())
np
.
testing
.
assert_allclose
(
context_out
.
numpy
(),
expected
.
numpy
())
@
pytest
.
mark
.
parametrize
(
"stride"
,
[(
1
,
1
)])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[(
1
,
1
)])
@
pytest
.
mark
.
parametrize
(
"dilation"
,
[(
1
,
1
)])
...
...
imperative/src/impl/transformations/format.cpp
浏览文件 @
673b295d
...
...
@@ -278,10 +278,10 @@ ValueRefList setsubtensor_rule(
inline
FT
get_inputs_format
(
Span
<
ValueRef
>&
inputs
,
const
FormatTransformation
&
t
)
{
FT
format
(
FT
::
DEFAULT
);
for
(
auto
&
inp
:
inputs
)
{
auto
&
inp_format
=
inp
.
cast
(
t
.
value_type
()).
format
(
);
if
(
inp_
format
!=
FT
::
DEFAULT
)
{
mgb_assert
(
format
==
FT
::
DEFAULT
||
inp_
format
==
format
);
format
=
inp_
format
.
type
();
auto
&
&
inp_ref
=
inp
.
as_ref
(
t
.
value_type
()
);
if
(
inp_
ref
&&
inp_ref
->
format
()
!=
FT
::
DEFAULT
)
{
mgb_assert
(
format
==
FT
::
DEFAULT
||
inp_
ref
->
format
()
==
format
);
format
=
inp_
ref
->
format
()
.
type
();
}
}
return
format
;
...
...
@@ -323,30 +323,82 @@ ValueRefList identity_rule_helper(
imperative
::
apply
(
op
,
t
.
unwrap_inputs
(
inputs
)),
src
.
format
().
type
());
}
ValueRefList
batchnorm_rule
(
const
BatchNorm
&
op
,
Span
<
ValueRef
>&
inputs
,
const
bool
&
auto_convert
,
const
FormatTransformation
&
t
)
{
auto
&&
inp_format
=
inputs
[
0
].
cast
(
t
.
value_type
()).
format
();
if
(
inp_format
==
FT
::
NHWC
)
{
auto
&&
new_param
=
op
.
param
();
new_param
.
param_dim
=
BatchNorm
::
ParamDim
::
DIM_111C
;
auto
new_op
=
BatchNorm
::
make
(
new_param
);
return
identity_rule_helper
(
*
new_op
,
inputs
,
t
);
}
return
identity_rule_helper
(
op
,
inputs
,
t
);
}
// clang-format off
#define FOREACH_IDENTITY_OP(cb) \
cb(Copy) \
cb(FastpathCopy) \
cb(TypeCvt) \
cb(Pooling) \
cb(AdaptivePooling) \
cb(Dropout) \
cb(Convolution) \
cb(BatchNorm) \
cb(Resize) \
cb(Identity)
#define FOREACH_FORMAT_OP(cb) \
cb(AdaptivePooling) \
cb(WarpAffine) \
cb(Resize)
#define FOREACH_FORMAT_POLICY_OP(cb)\
cb(Pooling) \
cb(Convolution)
// clang-format on
#define CREATE_IDENTITY_OP_RULE(op) \
ValueRefList op##_rule( \
const op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
// identity op
#define CREATE_IDENTITY_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
return identity_rule_helper(_op, inputs, t); \
}
FOREACH_IDENTITY_OP
(
CREATE_IDENTITY_OP_RULE
)
#undef CREATE_IDENTITY_OP_RULE
#define REGISTER_IDENTITY_OP_RULE(op) register_format_rule(op##_rule);
// identity op with Format param
#define CREATE_FORMAT_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
if (inp_format == FT::NHWC) { \
auto&& new_param = _op.param(); \
new_param.format = Op::Format::NHWC; \
auto new_op = Op::make(new_param); \
return identity_rule_helper(*new_op, inputs, t); \
} \
return identity_rule_helper(_op, inputs, t); \
}
FOREACH_FORMAT_OP
(
CREATE_FORMAT_OP_RULE
)
#undef CREATE_FORMAT_OP_RULE
// identity op with Format and policy param
#define CREATE_FORMAT_POLICY_OP_RULE(Op) \
ValueRefList Op##_rule( \
const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
const FormatTransformation& t) { \
auto&& inp_format = inputs[0].cast(t.value_type()).format(); \
if (inp_format == FT::NHWC) { \
auto&& new_param = _op.param(); \
new_param.format = Op::Format::NHWC; \
auto new_op = Op::make(new_param, _op.policy()); \
return identity_rule_helper(*new_op, inputs, t); \
} \
return identity_rule_helper(_op, inputs, t); \
}
FOREACH_FORMAT_POLICY_OP
(
CREATE_FORMAT_POLICY_OP_RULE
)
#undef CREATE_FORMAT_OP_RULE
#define REGISTER_OP_RULE(op) register_format_rule(op##_rule);
struct
FormatRuleRegistry
{
FormatRuleRegistry
()
{
register_format_rule
(
dimshuffle_rule
);
...
...
@@ -358,10 +410,13 @@ struct FormatRuleRegistry {
register_format_rule
(
setsubtensor_rule
<
IndexingSetMultiAxisVec
>
);
register_format_rule
(
concat_rule
);
register_format_rule
(
elemwise_rule
);
FOREACH_IDENTITY_OP
(
REGISTER_IDENTITY_OP_RULE
)
register_format_rule
(
batchnorm_rule
);
FOREACH_IDENTITY_OP
(
REGISTER_OP_RULE
)
FOREACH_FORMAT_OP
(
REGISTER_OP_RULE
)
FOREACH_FORMAT_POLICY_OP
(
REGISTER_OP_RULE
)
}
}
_
;
#undef REGISTER_
IDENTITY_
OP_RULE
#undef REGISTER_OP_RULE
}
// namespace
ValueRefList
FormatTransformation
::
apply_transformation
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录