Commit fc633ce4
Authored Mar 21, 2022 by Megvii Engine Team

fix(imperative/amp): fix custom grad in Subgraph

GitOrigin-RevId: 1c728d6ab97e8a49f84bf7e309a288938111d7be
Parent commit: 673b295d
Showing 15 changed files with 280 additions and 104 deletions (+280, -104)
imperative/python/megengine/amp/autocast.py                            +19  -19
imperative/python/megengine/amp/convert_format.py                       +7   -4
imperative/python/megengine/amp/grad_scaler.py                          +3   -1
imperative/python/megengine/core/_config.py                             +5   -6
imperative/python/megengine/core/autodiff/grad.py                       +2   -0
imperative/python/megengine/core/tensor/utils.py                        +1   -0
imperative/python/megengine/functional/nn.py                            +6   -2
imperative/python/megengine/module/batchnorm.py                         +0   -3
imperative/python/megengine/optimizer/optimizer.py                      +5   -3
imperative/python/megengine/optimizer/sgd.py                            +1   -0
imperative/python/test/unit/amp/test_convert_format.py                 +17   -1
imperative/python/test/unit/core/test_formatted_tensor.py              +88  -26
imperative/src/impl/transformations/format.cpp                        +120  -38
imperative/src/include/megbrain/imperative/transformations/format.h     +4   -1
imperative/src/include/megbrain/imperative/transformations/grad.h       +2   -0
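The commit teaches the format transformation to preserve tensor formats across custom-gradient ops (AttachGrad/SetGrad), such as those that arise when a Subgraph op carries a hand-written backward. For orientation, a minimal sketch of that kind of custom gradient, using megengine.autodiff.Function; the op MulConstant and its constant are illustrative, not taken from the commit:

import megengine as mge
from megengine.autodiff import Function, GradManager

class MulConstant(Function):
    # toy op with a hand-written backward; its gradient callback is what
    # AttachGrad/SetGrad route through the format transformation changed below
    def __init__(self, c):
        super().__init__()
        self.c = c

    def forward(self, x):
        return x * self.c

    def backward(self, output_grad):
        return output_grad * self.c

x = mge.tensor([1.0, 2.0])
gm = GradManager().attach([x])
with gm:
    y = MulConstant(3.0)(x)
    gm.backward(y.sum())
print(x.grad)  # 3.0 for every element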
imperative/python/megengine/amp/autocast.py
@@ -50,36 +50,36 @@ class autocast:
         self._origin_enabled = None
         self._origin_high = None
         self._origin_low = None
         self._origin_compute_mode = None
         self._origin_configs = None
 
     def __enter__(self):
-        self._origin_enabled = amp._enabled
-        amp._enabled = self.enabled
-        amp._set_amp_dtype_autocast(self.enabled)
-        if not self.enabled:
-            return
-        self._origin_high = amp._get_amp_high_prec_dtype()
-        self._origin_low = amp._get_amp_low_prec_dtype()
-        amp._set_amp_high_prec_dtype(self.high_prec_dtype)
-        amp._set_amp_low_prec_dtype(self.low_prec_dtype)
-        self._origin_configs = _config._reset_execution_config(compute_mode="float32")
+        if self.enabled:
+            self._origin_enabled = amp._enabled
+            self._origin_high = amp._get_amp_high_prec_dtype()
+            self._origin_low = amp._get_amp_low_prec_dtype()
+            amp._enabled = self.enabled
+            amp._set_amp_dtype_autocast(self.enabled)
+            amp._set_amp_high_prec_dtype(self.high_prec_dtype)
+            amp._set_amp_low_prec_dtype(self.low_prec_dtype)
+            self._origin_configs = _config._reset_execution_config(compute_mode="float32")
 
     def __exit__(self, *args):
-        amp._enabled = self._origin_enabled
-        amp._set_amp_dtype_autocast(self._origin_enabled)
-        if not self.enabled:
-            return
-        amp._set_amp_high_prec_dtype(self._origin_high)
-        amp._set_amp_low_prec_dtype(self._origin_low)
+        if self.enabled:
+            amp._enabled = self._origin_enabled
+            amp._set_amp_dtype_autocast(self._origin_enabled)
+            amp._set_amp_high_prec_dtype(self._origin_high)
+            amp._set_amp_low_prec_dtype(self._origin_low)
+            _config._reset_execution_config(*self._origin_compute_mode)
 
     def __call__(self, func):
        @functools.wraps(func)
         def wrapper(*args, **kwargs):
             if not self.enabled:
                 return func(*args, **kwargs)
             with self:
                 return func(*args, **kwargs)
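With this change, a disabled autocast no longer touches global amp state at all; previously __enter__ wrote amp._enabled and the dtype-autocast flag before bailing out. A minimal usage sketch of the two entry points the hunk covers (values are illustrative):

from megengine import amp

# context-manager form: __enter__/__exit__ save and restore global amp state
with amp.autocast(enabled=True, low_prec_dtype="float16", high_prec_dtype="float32"):
    ...  # ops here may run in float16

# decorator form: goes through the __call__ wrapper shown above
@amp.autocast(enabled=True)
def forward_step(x):
    return x * 2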
imperative/python/megengine/amp/convert_format.py
@@ -10,6 +10,7 @@ from copy import deepcopy
 from .. import functional as F
 from ..module import Module
 from ..tensor import Tensor
+from ..core import _config
 
 
 def _is_nchw_format(param: Tensor):
@@ -26,10 +27,12 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     else:
         raise ValueError("Unsupport tensor ndim {}".format(x.ndim))
     # TODO: use initialization from tensor after fixing format setting
-    if inplace:
-        x[...] = Tensor(x.numpy().transpose(*pattern), format="nhwc")
-    else:
-        x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+    if x.format != "nhwc":
+        if inplace:
+            data = x.numpy().transpose(*pattern)
+            x[...] = Tensor(data, format="nhwc")
+        else:
+            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
     return x
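The new format != "nhwc" guard makes convert_tensor_format idempotent: an already-NHWC tensor passes through untouched instead of being transposed a second time. A short sketch (shapes are illustrative):

import numpy as np
import megengine as mge

x = mge.tensor(np.ones((1, 2, 3, 4), dtype="float32"))   # plain NCHW tensor
y = mge.amp.convert_tensor_format(x, inplace=False)
assert y.format == "nhwc"
# safe to call again: the guard above turns this into a no-op
z = mge.amp.convert_tensor_format(y, inplace=False)
assert z.format == "nhwc"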
imperative/python/megengine/amp/grad_scaler.py
@@ -144,7 +144,9 @@ class GradScaler:
     def _check_gradients(self, grads, scale):
         if len(grads) == 0:
             return False
-        return _check_non_finite(grads, scale)
+        rst = _check_non_finite(grads, scale)
+        rst = rst.numpy()
+        return rst
 
     def update(self, new_scale: float = None):
         r"""Update the scale factor according to whether encountered overflow grad.
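_check_gradients now pulls the CheckNonFinite result back to the host with .numpy(), so update() can branch on a concrete value rather than a lazy device tensor. For reference, a typical GradScaler training-step sketch, assuming a model, loss_fn, optimizer opt, and a batch (data, label) already exist:

from megengine.amp import GradScaler
from megengine.autodiff import GradManager

gm = GradManager().attach(model.parameters())  # `model`, `opt`, `data`, `label` assumed
scaler = GradScaler()
with gm:
    loss = loss_fn(model(data), label)         # `loss_fn` assumed
    scaler.backward(gm, loss)                  # scales, backprops, unscales, updates scale
opt.step().clear_grad()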
imperative/python/megengine/core/_config.py
@@ -182,7 +182,6 @@ def _reset_execution_config(
     deterministic_kernel=None,
     async_level=None,
     compute_mode=None,
-    bn_format=None,
     auto_format_convert=None,
 ):
     global _benchmark_kernel, _deterministic_kernel, __compute_mode
@@ -234,11 +233,11 @@ def _override(
         def train():
     """
     orig_flags = _reset_execution_config(
-        benchmark_kernel,
-        deterministic_kernel,
-        async_level,
-        compute_mode,
-        auto_format_convert,
+        benchmark_kernel=benchmark_kernel,
+        deterministic_kernel=deterministic_kernel,
+        async_level=async_level,
+        compute_mode=compute_mode,
+        auto_format_convert=auto_format_convert,
     )
     try:
         yield
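Because bn_format disappeared from _reset_execution_config's signature, the old positional call in _override would have bound every argument after compute_mode to the wrong flag; keyword arguments remove that hazard. The same helper is what the tests below use to read raw shapes:

from megengine import config

# temporarily disable automatic NCHW<->NHWC conversion while inspecting tensors
with config._override(auto_format_convert=False):
    ...  # tensor .shape reports the physical (stored) layout here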
imperative/python/megengine/core/autodiff/grad.py
@@ -64,7 +64,9 @@ class Grad:
                 continue
             grad.suppress()
+        print("before backward")
         self._impl.backward(ys, dys)
+        print("after backward")
         for grad in group:
             if grad is self:
imperative/python/megengine/core/tensor/utils.py
@@ -24,6 +24,7 @@ from .._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from .._imperative_rt.ops import jit_supported
 from .._wrap import as_device
 from ..autodiff.grad import Function
+from .. import _config
 from ..ops import builtin
 from .amp import _get_amp_high_prec_dtype, _get_amp_low_prec_dtype
 from .dtype import is_dtype_equal, is_quantize
imperative/python/megengine/functional/nn.py
@@ -1226,12 +1226,16 @@ def batch_norm(
     bias = make_full_if_none(bias, 0)
 
     if not training:
-        op = builtin.BatchNorm(fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps)
+        op = builtin.BatchNorm(
+            fwd_mode=BatchNorm.FwdMode.INFERENCE, param_dim="dim_1c11", epsilon=eps
+        )
         ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
         return ret
 
     else:
-        op = builtin.BatchNorm(avg_factor=1 - momentum, epsilon=eps)
+        op = builtin.BatchNorm(
+            avg_factor=1 - momentum, param_dim="dim_1c11", epsilon=eps
+        )
         if has_mean or has_var:
             running_mean = make_full_if_none(running_mean, 0)
             running_var = make_full_if_none(running_var, 1)
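batch_norm now always passes param_dim="dim_1c11" to the builtin op, the layout module/batchnorm.py used to request explicitly (see the next file): weight, bias, and running statistics shaped (1, C, 1, 1) rather than flat (C,), which is also what the expected shapes in the tests below reflect. A hedged call sketch with that layout (values are illustrative):

import numpy as np
import megengine as mge
import megengine.functional as F

inp = mge.tensor(np.random.randn(2, 3, 4, 4).astype("float32"))
# parameters and running stats in the (1, C, 1, 1) "dim_1c11" layout
weight = mge.tensor(np.ones((1, 3, 1, 1), dtype="float32"))
bias = mge.tensor(np.zeros((1, 3, 1, 1), dtype="float32"))
running_mean = mge.tensor(np.zeros((1, 3, 1, 1), dtype="float32"))
running_var = mge.tensor(np.ones((1, 3, 1, 1), dtype="float32"))
out = F.batch_norm(
    inp, running_mean, running_var, weight, bias,
    training=True, momentum=0.9, eps=1e-5,
)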
imperative/python/megengine/module/batchnorm.py
@@ -19,7 +19,6 @@ class _BatchNorm(Module):
         affine=True,
         track_running_stats=True,
         freeze=False,
-        param_dim="dim_1c11",
         **kwargs
     ):
         super(_BatchNorm, self).__init__(**kwargs)
@@ -30,7 +29,6 @@ class _BatchNorm(Module):
         self.track_running_stats = track_running_stats
         self._track_running_stats_saved = track_running_stats
         self.freeze = freeze
-        self.param_dim = param_dim
         if self.freeze:
             assert (
                 self._track_running_stats_saved
@@ -104,7 +102,6 @@ class _BatchNorm(Module):
             or ((self.running_mean is None) and (self.running_var is None)),
             momentum=exponential_average_factor,
             eps=self.eps,
-            param_dim=self.param_dim,
         )
         return output
imperative/python/megengine/optimizer/optimizer.py
@@ -8,6 +8,7 @@ from typing import Union
 import numpy as np
 
+from ..core import _config
 from ..core._imperative_rt.core2 import (
     get_auto_format_convert,
     pop_scope,
@@ -96,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
+            param._reset(Tensor(param, no_cache=True))
 
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
@@ -119,10 +120,11 @@ class Optimizer(metaclass=ABCMeta):
     def _add_state(self, param, state_name, initializer=None):
         if initializer is None:
-            initializer = np.zeros(param.shape, dtype=np.float32)
+            with _config._override(auto_format_convert=False):
+                initializer = np.zeros(param.shape, dtype=np.float32)
         state_dict = self._state.setdefault(param, {})
         assert state_name not in state_dict
-        state = Tensor(initializer, no_cache=True)
+        state = Tensor(initializer, no_cache=True, format=param.format)
         state_dict[state_name] = state
 
     @abstractmethod
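Two details matter in _add_state: param.shape must be read with auto format conversion off so the zeros buffer matches the parameter's physical (possibly NHWC) layout, and the state tensor must be created with the parameter's format so later updates see consistent layouts. The same recipe outside the optimizer, as a sketch (names are illustrative):

import numpy as np
import megengine as mge
from megengine import config

p = mge.Parameter(np.ones((2, 3, 1, 1), dtype="float32"))
p = mge.amp.convert_tensor_format(p)                 # parameter now carries format "nhwc"
with config._override(auto_format_convert=False):
    raw_shape = p.shape                               # physical NHWC shape, no implicit transpose
state = mge.Tensor(np.zeros(raw_shape, dtype=np.float32), format=p.format)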
imperative/python/megengine/optimizer/sgd.py
@@ -5,6 +5,7 @@ from typing import Iterable, Union
 from ..functional.inplace import _inplace_add_
 from ..tensor import Parameter, tensor
 from .optimizer import Optimizer
+from ..core import _config
 
 
 class SGD(Optimizer):
imperative/python/test/unit/amp/test_convert_format.py
@@ -10,7 +10,7 @@ import pytest
 import megengine.functional as F
 import megengine.module as M
-from megengine import Parameter, Tensor, amp, tensor
+from megengine import Parameter, Tensor, amp, config
 
 
 class MyModule(M.Module):
@@ -39,6 +39,22 @@ class MyModule(M.Module):
 @pytest.mark.parametrize("is_inplace", [False, True])
 def test_convert_module(is_inplace):
     m = MyModule()
+    expected_shape = {
+        "i.bn.weight": (1, 1, 1, 4),
+        "i.bn.bias": (1, 1, 1, 4),
+        "i.bn.running_mean": (1, 1, 1, 4),
+        "i.bn.running_var": (1, 1, 1, 4),
+        "conv.weight": (2, 2, 4, 4, 2),
+        "conv.bias": (1, 1, 1, 4),
+        "bn.weight": (1, 1, 1, 4),
+        "bn.bias": (1, 1, 1, 4),
+        "bn.running_mean": (1, 1, 1, 4),
+        "bn.running_var": (1, 1, 1, 4),
+        "param": (1, 1, 1, 3),
+        "buff": (1, 1, 1, 3),
+    }
     m = amp.convert_module_format(m, is_inplace)
     for name, param in m.named_tensors():
         assert param.format == "nhwc"
+        with config._override(auto_format_convert=False):
+            assert param.shape == expected_shape[name], name
imperative/python/test/unit/core/test_formatted_tensor.py
@@ -3,6 +3,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
+import megengine.module as M
 from megengine import tensor
 from megengine.autodiff import GradManager
 from megengine.jit import trace
@@ -36,9 +37,9 @@ def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
     if is_symbolic is not None:
         func = trace(func, symbolic=is_symbolic)
-    # out1 = func(x1)
+    out1 = func(x1)
     out2 = func(x2)
-    # np.testing.assert_almost_equal(out1, out2, decimal=5)
+    np.testing.assert_almost_equal(out1, out2, decimal=5)
 
 
 @pytest.mark.parametrize("is_symbolic", [None])
@@ -322,30 +323,91 @@ def test_pooling2d(pooling, is_symbolic):
     _compare_nchw_nhwc(data, func, is_symbolic)
 
 
-@pytest.mark.parametrize("is_symbolic", [None])
-def test_backward(is_symbolic):
-    data = np.arange(0, 24).reshape((1, 2, 3, 4))
-    x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
-    w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
-    b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
-    gm = GradManager().attach([w, b])
-
-    def func(x, w, b):
-        return F.conv2d(x, w, b)
-
-    with gm:
-        if is_symbolic is not None:
-            func = trace(func, symbolic=is_symbolic)
-        x = func(x, w, b)
-        assert x.format == "nhwc"
-        # test manually convert to NHWC, usually used in detection head
-        x = x.transpose(0, 2, 3, 1).reshape(1, 18, 2)
-        gm.backward(x)
-        print("finish backward", x.format)
-        # backward grad has no format
-        np.testing.assert_equal(
-            w.grad.numpy(),
-            np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
-        )
-        np.testing.assert_equal(
-            b.grad.numpy(), np.array([12, 12, 12]).reshape((1, 1, 1, 3))
-        )
+def _compare_backward(inps, model, is_symbolic=None):
+    def func(*inps):
+        return model(*inps)
+
+    if is_symbolic is not None:
+        func = trace(func, symbolic=is_symbolic)
+
+    gm = GradManager().attach(model.parameters())
+    with gm:
+        rst = func(*inps)
+        gm.backward(rst)
+    expected_grads = [param.grad for param in model.parameters()]
+
+    inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
+    model = mge.amp.convert_module_format(model)
+    gm = GradManager().attach(model.parameters())
+    with gm:
+        rst = func(*inps)
+        gm.backward(rst)
+    actual_grads = [param.grad for param in model.parameters()]
+
+    for expected, actual in zip(expected_grads, actual_grads):
+        # print(param.grad)
+        np.testing.assert_equal(expected.numpy(), actual.numpy())
+
+
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_backward_conv2d_dimshuffle(is_symbolic):
+    class Net(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = M.Conv2d(2, 3, 1)
+
+        def forward(self, inp):
+            # test manually convert to NHWC, usually used in detection head
+            return F.transpose(self.conv(inp), (0, 2, 3, 1)).reshape(1, 18, 2)
+
+    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
+    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
+    # w = mge.tensor(np.ones((3, 1, 1, 2)), format="nhwc")
+    # b = mge.tensor(np.ones((1, 1, 1, 3)), format="nhwc")
+    # grads = [
+    #     np.array([66, 210, 66, 210, 66, 210]).reshape((3, 1, 1, 2)),
+    #     np.array([12, 12, 12]).reshape((1, 1, 1, 3)),
+    # ]
+    _compare_backward([inp], Net(), is_symbolic)
+
+
+@pytest.mark.parametrize("is_symbolic", [None])
+def test_backward_groupconv2d_bn(is_symbolic):
+    class Net(M.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = M.Conv2d(2, 2, 1, groups=2)
+            self.bn = M.BatchNorm2d(2)
+
+        def forward(self, inp):
+            # test manually convert to NHWC, usually used in detection head
+            return self.bn(self.conv(inp))
+
+    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
+    _compare_backward([inp], Net(), is_symbolic)
+    # def func(x, w, b, bn_w, bn_b):
+    #     x = F.conv2d(x, w, b, groups=2)
+    #     x = F.batch_norm(
+    #         x,
+    #         running_mean=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
+    #         running_var=mge.tensor(np.ones((1, 1, 1, 2)), format="nhwc"),
+    #         weight=bn_w,
+    #         bias=bn_b,
+    #         training=True,
+    #         inplace=True,
+    #     )
+    #     return x
+    # data = np.arange(0, 24).reshape((1, 2, 3, 4))
+    # x = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
+    # w = tensor(np.ones((2, 1, 1, 1, 1)), format="nhwc")
+    # b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
+    # bn_w = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
+    # bn_b = tensor(np.ones((1, 1, 1, 2)), format="nhwc")
+    # grads = [
+    #     np.array([66, 210]).reshape((2, 1, 1, 1, 1)),
+    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
+    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
    #     np.array([12, 12]).reshape((1, 1, 1, 2)),
    # ]
    # _compare_backward(x, func, [w, b, bn_w, bn_b], grads, is_symbolic)
imperative/src/impl/transformations/format.cpp
 #include "megbrain/imperative/transformations/format.h"
+#include "megbrain/imperative/transformations/grad.h"
+
 #include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 
 namespace mgb {
 namespace imperative {
@@ -17,7 +19,12 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
         const std::string& scope) const {
     std::vector<int32_t> pattern;
     if (tensor.format() == FT::NHWC && target == FT::NCHW) {
-        pattern = {0, 3, 1, 2};
+        // FIXME(czh): temporary fast path for group conv 5D weight.
+        if (tensor.value().shape().cast<ShapeValue>().ndim == 5) {
+            pattern = {0, 1, 4, 2, 3};
+        } else {
+            pattern = {0, 3, 1, 2};
+        }
     } else if (tensor.format() == FT::NCHW && target == FT::NHWC) {
         pattern = {0, 2, 3, 1};
     } else {
@@ -65,12 +72,22 @@ inline ValueRefList FormatTransformation::wrap_outputs(
 namespace {
 
 ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
-    mgb_assert(shape.ndim == 4);
     auto out = ValueShape(shape);
-    out[3] = shape[2];
-    out[2] = shape[1];
-    out[1] = shape[3];
-    return out;
+    if (shape.ndim == 4) {
+        out[1] = shape[3];
+        out[2] = shape[1];
+        out[3] = shape[2];
+        return out;
+    } else if (shape.ndim == 5) {
+        out[2] = shape[4];
+        out[3] = shape[2];
+        out[4] = shape[3];
+        return out;
+    } else {
+        mgb_throw(
+                MegBrainError, "Unsupported shape ndim %u in GetAttr(Shape).",
+                shape.ndim);
+    }
 }
 
 using FormatRule = std::function<ValueRefList(
@@ -278,10 +295,10 @@ ValueRefList setsubtensor_rule(
 inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation& t) {
     FT format(FT::DEFAULT);
     for (auto& inp : inputs) {
-        auto&& inp_ref = inp.as_ref(t.value_type());
-        if (inp_ref && inp_ref->format() != FT::DEFAULT) {
-            mgb_assert(format == FT::DEFAULT || inp_ref->format() == format);
-            format = inp_ref->format().type();
+        auto&& inp_format = inp.cast(t.value_type()).format();
+        if (inp_format != FT::DEFAULT) {
+            mgb_assert(format == FT::DEFAULT || inp_format == format);
+            format = inp_format.type();
         }
     }
     return format;
@@ -308,13 +325,6 @@ ValueRefList concat_rule(
             format);
 }
 
-ValueRefList elemwise_rule(
-        const Elemwise& op, Span<ValueRef>& inputs, const bool& auto_convert,
-        const FormatTransformation& t) {
-    FT format = get_inputs_format(inputs, t);
-    return t.wrap_outputs(
-            imperative::apply(op, t.unwrap_inputs(inputs)), format);
-}
-
 ValueRefList identity_rule_helper(
         const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
     // mgb_assert(inputs.size() == 1);
@@ -336,24 +346,49 @@ ValueRefList batchnorm_rule(
     return identity_rule_helper(op, inputs, t);
 }
 
+ValueRefList checknonfinite_rule(
+        const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
+        const FormatTransformation& t) {
+    auto&& inputs_ = t.unwrap_inputs(inputs);
+    auto&& outputs_ = imperative::apply(op, inputs_);
+    return t.wrap_outputs(outputs_);
+}
+
 // clang-format off
+#define FOREACH_MULTI_INPS_NO_PARAM_OP(cb) \
+    cb(Elemwise)                           \
+    cb(CompiledOp)                         \
+    cb(SubgraphOp)
+
 #define FOREACH_IDENTITY_OP(cb) \
     cb(Copy)                    \
     cb(FastpathCopy)            \
     cb(TypeCvt)                 \
-    cb(Dropout)
+    cb(Dropout)                 \
+    cb(Identity)
 
 #define FOREACH_FORMAT_OP(cb) \
     cb(AdaptivePooling)       \
-    cb(WarpAffine)
+    cb(WarpAffine)            \
+    cb(Resize)
 
 #define FOREACH_FORMAT_POLICY_OP(cb) \
-    cb(Pooling)
+    cb(Pooling)                      \
+    cb(Convolution)
 // clang-format on
 
+// multi inputs op without params
+#define CREATE_MULTI_INPS_NO_PARAM_OP_RULE(Op)                               \
+    ValueRefList Op##_rule(                                                  \
+            const Op& _op, Span<ValueRef>& inputs, const bool& auto_convert, \
+            const FormatTransformation& t) {                                 \
+        FT format = get_inputs_format(inputs, t);                            \
+        return t.wrap_outputs(                                               \
+                imperative::apply(_op, t.unwrap_inputs(inputs)), format);    \
+    }
+FOREACH_MULTI_INPS_NO_PARAM_OP(CREATE_MULTI_INPS_NO_PARAM_OP_RULE)
+#undef CREATE_MULTI_INPS_NO_PARAM_OP_RULE
+
 // identity op
 #define CREATE_IDENTITY_OP_RULE(Op) \
     ValueRefList Op##_rule(         \
@@ -409,8 +444,9 @@ struct FormatRuleRegistry {
         register_format_rule(setsubtensor_rule<SetSubtensor>);
        register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
         register_format_rule(concat_rule);
-        register_format_rule(elemwise_rule);
         register_format_rule(batchnorm_rule);
+        register_format_rule(checknonfinite_rule);
+        FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
         FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_POLICY_OP(REGISTER_OP_RULE)
@@ -455,27 +491,73 @@ ValueRefList FormatTransformation::apply_transformation(
             return imperative::apply(op, unwrap_inputs(inputs));
         }
     } else if (op.is<GetFormat>()) {
-        bool is_formatted_tensor = inputs.item().is(m_value_type);
-        if (is_formatted_tensor) {
-            return {FormatValue::make(inputs[0].cast(m_value_type).format())};
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            return {FormatValue::make(inp_ref->format())};
         } else {
             mgb_log_warn(
-                    "Not FormattedTensorValue input for GetFormat op: %s",
-                    inputs[0].to_string().c_str());
+                    "Not FormattedTensorValue input for GetFormat op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
             return {FormatValue::make(FT::DEFAULT)};
         }
     } else if (op.is<Operator::IdentityLike>()) {
-        bool is_formatted_tensor = inputs.item().is(m_value_type);
-        if (is_formatted_tensor) {
-            auto&& format = inputs[0].cast(m_value_type).format();
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            auto&& format = inp_ref->format();
             return wrap_outputs(
                     imperative::apply(op, unwrap_inputs(inputs)), format.type());
         } else {
             mgb_log_warn(
-                    "Not FormattedTensorValue input for IdentityLike op: %s",
-                    inputs[0].to_string().c_str());
+                    "Not FormattedTensorValue input for IdentityLike op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
             return imperative::apply(op, inputs);
         }
+    } else if (op.is<AttachGrad>()) {
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        if (inp_ref) {
+            auto format = inp_ref->format();
+            GenericFunction callback =
+                    (GenericFunction&)inputs[1].cast<FunctionValue>();
+            GenericFunction new_callback =
+                    [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
+                auto wrapped_inputs = SmallVector<ValueRef>{
+                        this->value_type().make(inputs_.item(), format.type())};
+                auto ret = callback(wrapped_inputs);
+                return ret;
+            };
+            auto&& outputs = imperative::apply(
+                    op, inp_ref->value(), FunctionValue::make(new_callback));
+            return wrap_outputs(outputs, format.type());
+        } else {
+            mgb_log_warn(
+                    "Not FormattedTensorValue input for AttachGrad op: %s, %s",
+                    op.to_string().c_str(), inputs[0].to_string().c_str());
+            return imperative::apply(op, inputs);
+        }
+    } else if (auto* set_grad = op.as<SetGrad>()) {
+        size_t nr_inputs = set_grad->nr_inputs();
+        size_t nr_outputs = inputs.size() - nr_inputs;
+        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
+        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
+        // run original apply.
+        // grads needn't to unwrap and wrap, which will be unwrapped in GradTrans
+        auto&& outputs = imperative::apply(op, unwrap_inputs(inputs));
+        // handle output's formats
+        auto wrapped_outputs = ValueRefList(nr_outputs);
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
+                wrapped_outputs[i] =
+                        m_value_type.make(outputs[i], output_ref->format().type());
+            } else {
+                mgb_log_warn(
+                        "Not FormattedTensorValue outputs for SetGrad op: %s, %s",
+                        op.to_string().c_str(), inputs_[i].to_string().c_str());
+                wrapped_outputs[i] = m_value_type.make(outputs[i], FT::DEFAULT);
+            }
+        }
+        return wrapped_outputs;
     } else {
         return imperative::apply(op, unwrap_inputs(inputs));
     }
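The 4D and 5D permutation patterns above can be sanity-checked with plain numpy; each index list is the axis order of a transpose (array shapes below are arbitrary examples):

import numpy as np

# NHWC -> NCHW for ordinary 4D tensors: pattern {0, 3, 1, 2}
x = np.empty((1, 3, 4, 2))                              # (N, H, W, C)
assert x.transpose(0, 3, 1, 2).shape == (1, 2, 3, 4)    # (N, C, H, W)

# fast path for 5D group-conv weights: pattern {0, 1, 4, 2, 3}
# keeps the two leading axes and moves the trailing channel axis to position 2
w = np.empty((2, 2, 4, 4, 2))
assert w.transpose(0, 1, 4, 2, 3).shape == (2, 2, 2, 4, 4)

convert_nhwc2nchw_shape applies the same permutations to bare shapes: out[1]=shape[3], out[2]=shape[1], out[3]=shape[2] for 4D, and out[2]=shape[4], out[3]=shape[2], out[4]=shape[3] for 5D.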
imperative/src/include/megbrain/imperative/transformations/format.h
@@ -47,7 +47,10 @@ public:
             const Operator& op, Span<ValueRef> inputs) override;
 
     ValueRef unwrap(ValueRef value) override {
-        mgb_assert(!value.is(m_value_type));
+        // mgb_assert(!value.is(m_value_type));
+        if (auto format_val = value.as_ref(m_value_type)) {
+            return format_val->value();
+        }
         return value;
     }
imperative/src/include/megbrain/imperative/transformations/grad.h
@@ -377,6 +377,8 @@ public:
     SetGrad(GenericFunction grad_fn, size_t nr_inputs)
             : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
 
     std::shared_ptr<GradKey> key() const { return m_key; }
 
     GenericFunction grad_fn() const { return m_grad_fn; }
+
+    size_t nr_inputs() const { return m_nr_inputs; }
登录