Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3c3fc6f3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3c3fc6f3
编写于
2月 23, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): move python code of elemwise/reduce/conv2d/bn to c++
GitOrigin-RevId: 01b532439243aa2e7d40f150fcaa26fded0e4f27
上级
84466261
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
500 addition
and
121 deletion
+500
-121
dnn/src/common/batch_normalization.cpp
dnn/src/common/batch_normalization.cpp
+10
-0
imperative/python/megengine/amp/autocast.py
imperative/python/megengine/amp/autocast.py
+10
-7
imperative/python/megengine/core/tensor/amp.py
imperative/python/megengine/core/tensor/amp.py
+16
-8
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+0
-40
imperative/python/megengine/core/tensor/utils.py
imperative/python/megengine/core/tensor/utils.py
+8
-11
imperative/python/megengine/functional/elemwise.py
imperative/python/megengine/functional/elemwise.py
+6
-1
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+16
-32
imperative/python/megengine/functional/tensor_cache.py
imperative/python/megengine/functional/tensor_cache.py
+34
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+114
-14
imperative/python/src/transformation.h
imperative/python/src/transformation.h
+2
-1
imperative/python/test/unit/amp/test_autocast.py
imperative/python/test/unit/amp/test_autocast.py
+7
-7
imperative/src/impl/transformations/dtype_promote.cpp
imperative/src/impl/transformations/dtype_promote.cpp
+251
-0
imperative/src/include/megbrain/imperative/transformations/dtype_promote.h
...clude/megbrain/imperative/transformations/dtype_promote.h
+26
-0
未找到文件。
dnn/src/common/batch_normalization.cpp
浏览文件 @
3c3fc6f3
...
@@ -28,6 +28,16 @@ void BNForward::check_exec(
...
@@ -28,6 +28,16 @@ void BNForward::check_exec(
const
TensorLayout
&
variance
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
variance
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
,
size_t
reserve_in_bytes
)
{
size_t
workspace_in_bytes
,
size_t
reserve_in_bytes
)
{
// moving some python assert to dnn to decrease the assert overhead
megdnn_assert
(
src
.
ndim
==
4
,
"ndim of the input tensor for batch_norm should be 4, but you give %zu"
,
src
.
ndim
);
megdnn_assert
(
bn_scale
.
ndim
==
4
,
"expect 4, get %zu
\n
"
,
bn_scale
.
ndim
);
megdnn_assert
(
bn_bias
.
ndim
==
4
,
"expect 4, get %zu
\n
"
,
bn_bias
.
ndim
);
megdnn_assert_eq_layout
(
bn_scale
,
bn_bias
);
megdnn_assert_eq_layout
(
batch_mean
,
batch_inv_variance
);
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
src
);
megdnn_assert_eq_layout
(
src
,
dst
);
megdnn_assert_eq_layout
(
src
,
dst
);
megdnn_assert_eq_layout
(
bn_scale
,
bn_bias
);
megdnn_assert_eq_layout
(
bn_scale
,
bn_bias
);
...
...
imperative/python/megengine/amp/autocast.py
浏览文件 @
3c3fc6f3
...
@@ -58,16 +58,19 @@ class autocast:
...
@@ -58,16 +58,19 @@ class autocast:
self
.
_origin_low
=
None
self
.
_origin_low
=
None
def
__enter__
(
self
):
def
__enter__
(
self
):
self
.
_origin_enabled
,
amp
.
_enabled
=
amp
.
_enabled
,
self
.
enabled
self
.
_origin_enabled
=
amp
.
_enabled
self
.
_origin_high
=
amp
.
_high_prec_dtype
self
.
_origin_high
=
amp
.
_get_amp_high_prec_dtype
()
amp
.
_high_prec_dtype
=
self
.
high_prec_dtype
self
.
_origin_low
=
amp
.
_get_amp_low_prec_dtype
()
self
.
_origin_low
=
amp
.
_low_prec_dtype
amp
.
_enabled
=
self
.
enabled
amp
.
_low_prec_dtype
=
self
.
low_prec_dtype
amp
.
_set_amp_dtype_autocast
(
self
.
enabled
)
amp
.
_set_amp_high_prec_dtype
(
self
.
high_prec_dtype
)
amp
.
_set_amp_low_prec_dtype
(
self
.
low_prec_dtype
)
def
__exit__
(
self
,
*
args
):
def
__exit__
(
self
,
*
args
):
amp
.
_enabled
=
self
.
_origin_enabled
amp
.
_enabled
=
self
.
_origin_enabled
amp
.
_high_prec_dtype
=
self
.
_origin_high
amp
.
_set_amp_dtype_autocast
(
self
.
_origin_enabled
)
amp
.
_low_prec_dtype
=
self
.
_origin_low
amp
.
_set_amp_high_prec_dtype
(
self
.
_origin_high
)
amp
.
_set_amp_low_prec_dtype
(
self
.
_origin_low
)
def
__call__
(
self
,
func
):
def
__call__
(
self
,
func
):
@
functools
.
wraps
(
func
)
@
functools
.
wraps
(
func
)
...
...
imperative/python/megengine/core/tensor/amp.py
浏览文件 @
3c3fc6f3
...
@@ -5,9 +5,18 @@
...
@@ -5,9 +5,18 @@
# Unless required by applicable law or agreed to in writing,
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
.._imperative_rt.core2
import
(
_get_amp_dtype_autocast
,
_get_amp_high_prec_dtype
,
_get_amp_low_prec_dtype
,
_set_amp_dtype_autocast
,
_set_amp_high_prec_dtype
,
_set_amp_low_prec_dtype
,
)
_enabled
=
False
_enabled
=
False
_high_prec_dtype
=
"float32"
_set_amp_dtype_autocast
(
_enabled
)
_low_prec_dtype
=
"float16"
@
property
@
property
...
@@ -28,6 +37,7 @@ def enabled(mod):
...
@@ -28,6 +37,7 @@ def enabled(mod):
def
enabled
(
mod
,
enabled
:
bool
):
def
enabled
(
mod
,
enabled
:
bool
):
global
_enabled
global
_enabled
_enabled
=
enabled
_enabled
=
enabled
_set_amp_dtype_autocast
(
_enabled
)
@
property
@
property
...
@@ -42,13 +52,12 @@ def high_prec_dtype(mod):
...
@@ -42,13 +52,12 @@ def high_prec_dtype(mod):
import megengine as mge
import megengine as mge
mge.amp.high_prec_dtype = "float32"
mge.amp.high_prec_dtype = "float32"
"""
"""
return
_
high_prec_dtype
return
_
get_amp_high_prec_dtype
()
@
high_prec_dtype
.
setter
@
high_prec_dtype
.
setter
def
high_prec_dtype
(
mod
,
dtype
:
str
):
def
high_prec_dtype
(
mod
,
dtype
:
str
):
global
_high_prec_dtype
_set_amp_high_prec_dtype
(
dtype
)
_high_prec_dtype
=
dtype
@
property
@
property
...
@@ -63,10 +72,9 @@ def low_prec_dtype(mod):
...
@@ -63,10 +72,9 @@ def low_prec_dtype(mod):
import megengine as mge
import megengine as mge
mge.amp.low_prec_dtype = "float16"
mge.amp.low_prec_dtype = "float16"
"""
"""
return
_
low_prec_dtype
return
_
get_amp_low_prec_dtype
()
@
low_prec_dtype
.
setter
@
low_prec_dtype
.
setter
def
low_prec_dtype
(
mod
,
dtype
:
str
):
def
low_prec_dtype
(
mod
,
dtype
:
str
):
global
_low_prec_dtype
_set_amp_low_prec_dtype
(
dtype
)
_low_prec_dtype
=
dtype
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
3c3fc6f3
...
@@ -25,7 +25,6 @@ from .utils import (
...
@@ -25,7 +25,6 @@ from .utils import (
astensor1d
,
astensor1d
,
astype
,
astype
,
cast_tensors
,
cast_tensors
,
convert_inputs
,
make_shape_tuple
,
make_shape_tuple
,
subgraph
,
subgraph
,
)
)
...
@@ -40,38 +39,6 @@ def _elwise_apply(args, mode):
...
@@ -40,38 +39,6 @@ def _elwise_apply(args, mode):
def
_elwise
(
*
args
,
mode
):
def
_elwise
(
*
args
,
mode
):
args
=
convert_inputs
(
*
args
)
if
(
mode
in
(
_ElwMod
.
TRUE_DIV
,
_ElwMod
.
EXP
,
_ElwMod
.
POW
,
_ElwMod
.
LOG
,
_ElwMod
.
EXPM1
,
_ElwMod
.
LOG1P
,
_ElwMod
.
ACOS
,
_ElwMod
.
ASIN
,
_ElwMod
.
ATAN2
,
_ElwMod
.
COS
,
_ElwMod
.
SIN
,
_ElwMod
.
LOG_SUM_EXP
,
)
and
(
amp
.
_enabled
or
np
.
all
([
np
.
issubdtype
(
arg
.
dtype
,
np
.
integer
)
for
arg
in
args
])
)
or
mode
in
(
_ElwMod
.
TANH
,)
and
np
.
all
([
np
.
issubdtype
(
arg
.
dtype
,
np
.
integer
)
for
arg
in
args
])
):
# autocast to FP32 to maintain precision
# or to avoid op's not supporting all int args
args
=
cast_tensors
(
*
args
,
promote
=
True
)
if
mode
in
(
_ElwMod
.
CEIL
,
_ElwMod
.
FLOOR
,
_ElwMod
.
ROUND
,)
and
np
.
issubdtype
(
args
[
0
].
dtype
,
np
.
integer
):
return
args
[
0
]
return
_elwise_apply
(
args
,
mode
)
return
_elwise_apply
(
args
,
mode
)
...
@@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
...
@@ -504,10 +471,6 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
def
_reduce
(
mode
):
def
_reduce
(
mode
):
def
f
(
self
,
axis
=
None
,
keepdims
:
bool
=
False
):
def
f
(
self
,
axis
=
None
,
keepdims
:
bool
=
False
):
data
=
self
data
=
self
if
mode
==
"mean"
:
data
=
data
.
astype
(
"float32"
)
elif
self
.
dtype
==
np
.
bool_
:
data
=
data
.
astype
(
"int32"
)
if
axis
is
None
:
if
axis
is
None
:
assert
not
keepdims
,
"can not set axis=None and keepdims=True"
assert
not
keepdims
,
"can not set axis=None and keepdims=True"
result
=
_reduce_to_scalar
(
builtin
.
Reduce
(
mode
=
mode
),
data
)
result
=
_reduce_to_scalar
(
builtin
.
Reduce
(
mode
=
mode
),
data
)
...
@@ -526,9 +489,6 @@ def _reduce(mode):
...
@@ -526,9 +489,6 @@ def _reduce(mode):
if
not
keepdims
:
if
not
keepdims
:
result
=
_remove_axis
(
result
,
axis
)
result
=
_remove_axis
(
result
,
axis
)
if
self
.
dtype
==
np
.
bool_
:
if
mode
in
[
"min"
,
"max"
]:
result
=
result
.
astype
(
"bool"
)
return
result
return
result
return
f
return
f
...
...
imperative/python/megengine/core/tensor/utils.py
浏览文件 @
3c3fc6f3
...
@@ -16,6 +16,8 @@ from .._imperative_rt import make_const
...
@@ -16,6 +16,8 @@ from .._imperative_rt import make_const
from
.._imperative_rt.core2
import
(
from
.._imperative_rt.core2
import
(
SymbolVar
,
SymbolVar
,
Tensor
,
Tensor
,
_get_convert_inputs
,
_set_convert_inputs
,
apply
,
apply
,
dtype_promotion
,
dtype_promotion
,
get_device
,
get_device
,
...
@@ -27,15 +29,13 @@ from .._wrap import as_device
...
@@ -27,15 +29,13 @@ from .._wrap import as_device
from
..autodiff.grad
import
Function
from
..autodiff.grad
import
Function
from
..ops
import
builtin
from
..ops
import
builtin
from
..ops.special
import
Const
from
..ops.special
import
Const
from
.amp
import
_
high_prec_dtype
,
_low_prec_dtype
from
.amp
import
_
get_amp_high_prec_dtype
,
_get_amp
_low_prec_dtype
from
.dtype
import
is_dtype_equal
,
is_quantize
from
.dtype
import
is_dtype_equal
,
is_quantize
_enable_convert_inputs
=
True
def
get_convert_inputs
():
def
get_convert_inputs
():
r
"""get the curerent state of `_enable_convert_inputs`"""
r
"""get the curerent state of `_enable_convert_inputs`"""
return
_
enable_convert_inputs
return
_
get_convert_inputs
()
def
set_convert_inputs
(
flag
):
def
set_convert_inputs
(
flag
):
...
@@ -44,10 +44,7 @@ def set_convert_inputs(flag):
...
@@ -44,10 +44,7 @@ def set_convert_inputs(flag):
`_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
`_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
internal use only, and should be removed when the tensor-like system is refactored.
internal use only, and should be removed when the tensor-like system is refactored.
"""
"""
global
_enable_convert_inputs
return
_set_convert_inputs
(
flag
)
backup
=
_enable_convert_inputs
_enable_convert_inputs
=
flag
return
backup
def
concatenate
(
inputs
,
axis
=
0
,
*
,
device
=
None
):
def
concatenate
(
inputs
,
axis
=
0
,
*
,
device
=
None
):
...
@@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None):
...
@@ -75,7 +72,7 @@ def convert_single_value(v, *, dtype=None, device=None):
def
convert_inputs
(
*
args
,
device
=
None
):
def
convert_inputs
(
*
args
,
device
=
None
):
if
not
_
enable_convert_inputs
:
if
not
_
get_convert_inputs
()
:
return
args
return
args
dtype
=
dtype_promotion
(
args
)
dtype
=
dtype_promotion
(
args
)
...
@@ -109,9 +106,9 @@ def convert_inputs(*args, device=None):
...
@@ -109,9 +106,9 @@ def convert_inputs(*args, device=None):
def
cast_tensors
(
*
args
,
promote
=
False
):
def
cast_tensors
(
*
args
,
promote
=
False
):
if
promote
:
if
promote
:
dtype
=
_
high_prec_dtype
dtype
=
_
get_amp_high_prec_dtype
()
else
:
else
:
dtype
=
_
low_prec_dtype
dtype
=
_
get_amp_low_prec_dtype
()
return
tuple
(
arg
.
astype
(
dtype
)
if
arg
is
not
None
else
None
for
arg
in
args
)
return
tuple
(
arg
.
astype
(
dtype
)
if
arg
is
not
None
else
None
for
arg
in
args
)
...
...
imperative/python/megengine/functional/elemwise.py
浏览文件 @
3c3fc6f3
...
@@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise
...
@@ -16,6 +16,7 @@ from ..core.tensor.array_method import _elwise
from
..core.tensor.utils
import
convert_inputs
from
..core.tensor.utils
import
convert_inputs
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
..utils.deprecation
import
deprecated_func
from
..utils.deprecation
import
deprecated_func
from
.tensor_cache
import
get_scalar_one
__all__
=
[
__all__
=
[
"abs"
,
"abs"
,
...
@@ -359,7 +360,11 @@ def asin(x):
...
@@ -359,7 +360,11 @@ def asin(x):
def
atan
(
x
):
def
atan
(
x
):
r
"""Element-wise `inverse tangent`."""
r
"""Element-wise `inverse tangent`."""
return
_elwise
(
x
,
1
,
mode
=
Elemwise
.
Mode
.
ATAN2
)
return
_elwise
(
x
,
get_scalar_one
(
"float32"
,
x
.
device
if
isinstance
(
x
,
Tensor
)
else
None
),
mode
=
Elemwise
.
Mode
.
ATAN2
,
)
def
atan2
(
y
,
x
):
def
atan2
(
y
,
x
):
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
3c3fc6f3
...
@@ -253,15 +253,6 @@ def conv2d(
...
@@ -253,15 +253,6 @@ def conv2d(
conv_mode
.
lower
()
==
"cross_correlation"
conv_mode
.
lower
()
==
"cross_correlation"
or
conv_mode
.
name
==
"CROSS_CORRELATION"
or
conv_mode
.
name
==
"CROSS_CORRELATION"
)
)
if
amp
.
_enabled
:
compute_mode
=
"float32"
inp
,
weight
,
bias
=
cast_tensors
(
inp
,
weight
,
bias
)
else
:
dtype
=
dtype_promotion
(
inp
,
weight
)
if
inp
.
dtype
!=
dtype
:
inp
=
inp
.
astype
(
dtype
)
if
weight
.
dtype
!=
dtype
:
weight
=
weight
.
astype
(
dtype
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
stride_h
,
stride_w
=
expand_hw
(
stride
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
pad_h
,
pad_w
=
expand_hw
(
padding
)
...
@@ -1328,29 +1319,32 @@ def batch_norm(
...
@@ -1328,29 +1319,32 @@ def batch_norm(
inplace: whether to update ``running_mean`` and ``running_var``
inplace: whether to update ``running_mean`` and ``running_var``
inplace or return new tensors. Default: True
inplace or return new tensors. Default: True
"""
"""
if
inp
.
ndim
!=
4
:
raise
NotImplementedError
(
"batch_norm for ndim != 4"
)
if
param_dim
==
"dim_1c11"
:
C
=
inp
.
shape
[
1
]
pshape
=
(
1
,
C
,
1
,
1
)
elif
param_dim
==
"dim_111c"
:
C
=
inp
.
shape
[
3
]
pshape
=
(
1
,
1
,
1
,
C
)
else
:
raise
ValueError
(
"Invalid param_dim {}"
.
format
(
param_dim
))
def
make_full_if_none
(
x
,
value
):
def
make_full_if_none
(
x
,
value
):
x_ndim
=
None
if
x
is
None
else
x
.
ndim
# in general case, x will be returned here directly
if
x_ndim
is
not
None
and
x_ndim
!=
1
:
return
x
if
param_dim
==
"dim_1c11"
:
C
=
inp
.
shape
[
1
]
pshape
=
(
1
,
C
,
1
,
1
)
elif
param_dim
==
"dim_111c"
:
C
=
inp
.
shape
[
3
]
pshape
=
(
1
,
1
,
1
,
C
)
else
:
raise
ValueError
(
"Invalid param_dim {}"
.
format
(
param_dim
))
if
x
is
None
:
if
x
is
None
:
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)()
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)()
shape
=
astensor1d
(
pshape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
shape
=
astensor1d
(
pshape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
shape
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
shape
)
return
result
return
result
elif
x
.
ndim
==
1
:
else
:
assert
x_ndim
==
1
shape
=
astensor1d
(
pshape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
shape
=
astensor1d
(
pshape
,
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
builtin
.
Reshape
(),
x
,
shape
)
(
result
,)
=
apply
(
builtin
.
Reshape
(),
x
,
shape
)
return
result
return
result
return
x
has_mean
=
running_mean
is
not
None
has_mean
=
running_mean
is
not
None
has_var
=
running_var
is
not
None
has_var
=
running_var
is
not
None
...
@@ -1359,16 +1353,6 @@ def batch_norm(
...
@@ -1359,16 +1353,6 @@ def batch_norm(
assert
has_mean
,
"running_mean must be provided in inference mode"
assert
has_mean
,
"running_mean must be provided in inference mode"
assert
has_var
,
"running_var must be provided in inference mode"
assert
has_var
,
"running_var must be provided in inference mode"
if
has_mean
and
running_mean
.
ndim
!=
4
:
raise
ValueError
if
has_var
and
running_var
.
ndim
!=
4
:
raise
ValueError
if
amp
.
_enabled
:
inp
=
inp
.
astype
(
"float16"
)
weight
,
bias
,
running_mean
,
running_var
=
cast_tensors
(
weight
,
bias
,
running_mean
,
running_var
,
promote
=
True
)
weight
=
make_full_if_none
(
weight
,
1
)
weight
=
make_full_if_none
(
weight
,
1
)
bias
=
make_full_if_none
(
bias
,
0
)
bias
=
make_full_if_none
(
bias
,
0
)
...
...
imperative/python/megengine/functional/tensor_cache.py
0 → 100644
浏览文件 @
3c3fc6f3
from
..core.ops.special
import
Const
from
..jit.tracing
import
is_tracing
small_tensor_cache
=
{}
def
_get_scalar_tensor_with_value
(
value
,
dtype
=
None
,
device
=
None
):
global
small_tensor_cache
if
is_tracing
():
(
ret
,)
=
Const
(
value
,
dtype
=
dtype
,
device
=
device
)()
else
:
cache_key
=
(
value
,
dtype
,
device
)
if
cache_key
not
in
small_tensor_cache
:
(
ret
,)
=
Const
(
value
,
dtype
=
dtype
,
device
=
device
)()
small_tensor_cache
[
cache_key
]
=
ret
else
:
ret
=
small_tensor_cache
[
cache_key
]
return
ret
def
get_scalar_zero
(
dtype
=
None
,
device
=
None
):
return
_get_scalar_tensor_with_value
(
0
,
dtype
,
device
)
def
get_scalar_zero_point_five
(
dtype
=
None
,
device
=
None
):
return
_get_scalar_tensor_with_value
(
0.5
,
dtype
,
device
)
def
get_scalar_one
(
dtype
=
None
,
device
=
None
):
return
_get_scalar_tensor_with_value
(
1
,
dtype
,
device
)
def
get_scalar_two
(
dtype
=
None
,
device
=
None
):
return
_get_scalar_tensor_with_value
(
2
,
dtype
,
device
)
imperative/python/src/tensor.cpp
浏览文件 @
3c3fc6f3
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/scalar.h"
...
@@ -59,16 +60,19 @@ struct SymbolVarContext {
...
@@ -59,16 +60,19 @@ struct SymbolVarContext {
TransformationContext
context
;
TransformationContext
context
;
std
::
shared_ptr
<
SymbolTransformation
>
symbol_tsf
;
std
::
shared_ptr
<
SymbolTransformation
>
symbol_tsf
;
std
::
shared_ptr
<
ScalarTransformation
>
scalar_tsf
;
std
::
shared_ptr
<
ScalarTransformation
>
scalar_tsf
;
std
::
shared_ptr
<
DTypePromoteTransformation
>
dtype_promote_tsf
;
SymbolVarContext
(
cg
::
ComputingGraph
*
graph
)
{
SymbolVarContext
(
cg
::
ComputingGraph
*
graph
)
{
symbol_tsf
=
std
::
make_shared
<
SymbolTransformation
>
(
graph
);
symbol_tsf
=
std
::
make_shared
<
SymbolTransformation
>
(
graph
);
scalar_tsf
=
std
::
make_shared
<
ScalarTransformation
>
();
scalar_tsf
=
std
::
make_shared
<
ScalarTransformation
>
();
dtype_promote_tsf
=
std
::
make_shared
<
DTypePromoteTransformation
>
();
Transformation
::
swap_context
(
context
);
Transformation
::
swap_context
(
context
);
}
}
void
init
()
{
void
init
()
{
symbol_tsf
->
register_at
(
Transformation
::
top
());
symbol_tsf
->
register_at
(
Transformation
::
top
());
scalar_tsf
->
register_at
(
Transformation
::
top
());
scalar_tsf
->
register_at
(
Transformation
::
top
());
dtype_promote_tsf
->
register_at
(
Transformation
::
top
());
}
}
ValueRef
symvar2val
(
py
::
handle
py_symbol_var
)
{
ValueRef
symvar2val
(
py
::
handle
py_symbol_var
)
{
...
@@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d)
...
@@ -110,6 +114,9 @@ REGISTE_APPLY_FUNC(cpp_astensor1d)
#undef REGISTE_APPLY_FUNC
#undef REGISTE_APPLY_FUNC
PyArray_Descr
*
_dtype_promotion
(
PyObject
*
const
*
args
,
size_t
nargs
);
CompNode
_get_device
(
PyObject
*
const
*
args
,
size_t
nargs
);
PyObject
*
py_apply
(
PyObject
*
py_apply
(
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
/* , PyObject* kwnames */
)
{
PyObject
*
self
,
PyObject
*
const
*
args
,
size_t
nargs
/* , PyObject* kwnames */
)
{
try
{
try
{
...
@@ -133,19 +140,59 @@ PyObject* py_apply(
...
@@ -133,19 +140,59 @@ PyObject* py_apply(
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
auto
op
=
py
::
handle
(
py_op
).
cast
<
std
::
shared_ptr
<
OpDef
>>
();
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
SmallVector
<
ValueRef
,
8
>
tensors
(
nargs
);
bool
is_symbol_var
=
(
!
TensorWrapper
::
try_cast
(
args
[
0
]))
&&
SmallVector
<
bool
,
8
>
is_symbol_var
(
nargs
,
false
);
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
0
]));
ComputingGraph
*
cg
=
nullptr
;
if
(
is_symbol_var
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
if
((
!
TensorWrapper
::
try_cast
(
args
[
i
]))
&&
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
i
])))
{
is_symbol_var
[
i
]
=
true
;
ComputingGraph
*
cur_cg
=
py
::
handle
(
args
[
i
]).
cast
<
PySymbolVar
*>
()
->
m_node
->
owner_graph
();
if
(
cg
==
nullptr
)
{
cg
=
cur_cg
;
}
else
{
mgb_assert
(
cg
==
cur_cg
);
}
}
}
mgb
::
CompNode
target_cn
;
mgb
::
DType
target_dtype
;
auto
convert_pyinput_to_tensor
=
[
&
](
size_t
i
)
->
ValueRef
{
if
(
!
target_dtype
.
valid
())
{
target_dtype
=
npy
::
dtype_np2mgb_descr
(
_dtype_promotion
(
args
,
nargs
));
target_cn
=
_get_device
(
args
,
nargs
);
}
HostTensorND
ht
(
target_cn
);
ht
=
npy
::
np2tensor
(
args
[
i
],
npy
::
Meth
::
copy_into
(
&
ht
),
target_dtype
);
if
(
PyArray_Check
(
args
[
i
]))
{
// non scaler
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
target_cn
,
ht
.
layout
()),
HostStorage
::
make
(
ht
.
storage
()))[
0
];
}
else
{
// scaler
return
imperative
::
apply
(
CreateTensor
(
CreateTensor
::
Const
,
target_cn
,
target_dtype
,
{}),
HostStorage
::
make
(
ht
.
storage
()))[
0
];
}
};
if
(
cg
!=
nullptr
)
{
// swap to a special context to reuse scalar handle
// swap to a special context to reuse scalar handle
SymbolVarContext
context
(
size_t
symbol_var_idx
=
8
;
py
::
handle
(
args
[
0
]).
cast
<
PySymbolVar
*>
()
->
m_node
->
owner_graph
()
);
SymbolVarContext
context
(
cg
);
context
.
init
();
context
.
init
();
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
tensors
[
i
]
=
context
.
symvar2val
(
args
[
i
]);
if
(
is_symbol_var
[
i
])
{
symbol_var_idx
=
i
;
tensors
[
i
]
=
context
.
symvar2val
(
args
[
i
]);
}
else
{
tensors
[
i
]
=
convert_pyinput_to_tensor
(
i
);
}
}
}
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
auto
outputs
=
imperative
::
apply
(
*
op
,
tensors
);
auto
ret
=
pybind11
::
tuple
(
outputs
.
size
());
auto
ret
=
pybind11
::
tuple
(
outputs
.
size
());
auto
typeobj
=
py
::
handle
(
args
[
0
]).
get_type
();
auto
typeobj
=
py
::
handle
(
args
[
symbol_var_idx
]).
get_type
();
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
ret
[
i
]
=
context
.
val2symvar
(
typeobj
,
outputs
[
i
]);
ret
[
i
]
=
context
.
val2symvar
(
typeobj
,
outputs
[
i
]);
}
}
...
@@ -156,13 +203,7 @@ PyObject* py_apply(
...
@@ -156,13 +203,7 @@ PyObject* py_apply(
if
(
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
args
[
i
]))
{
if
(
TensorWrapper
*
tw
=
TensorWrapper
::
try_cast
(
args
[
i
]))
{
tensors
[
i
]
=
tw
->
m_tensor
->
data
();
tensors
[
i
]
=
tw
->
m_tensor
->
data
();
}
else
{
}
else
{
PyErr_SetString
(
tensors
[
i
]
=
convert_pyinput_to_tensor
(
i
);
PyExc_TypeError
,
ssprintf
(
"op %s expect type Tensor as inputs, got %s actually"
,
op
->
make_name
().
c_str
(),
Py_TYPE
(
args
[
i
])
->
tp_name
)
.
c_str
());
return
nullptr
;
}
}
}
}
...
@@ -616,6 +657,8 @@ void init_tensor(py::module m) {
...
@@ -616,6 +657,8 @@ void init_tensor(py::module m) {
std
::
shared_ptr
<
Channel
>
(
channel
,
[](
Channel
*
)
{})));
std
::
shared_ptr
<
Channel
>
(
channel
,
[](
Channel
*
)
{})));
transformations
.
register_at
<
Segment
::
Scalar
>
(
transformations
.
register_at
<
Segment
::
Scalar
>
(
std
::
make_shared
<
ScalarTransformation
>
());
std
::
make_shared
<
ScalarTransformation
>
());
transformations
.
register_at
<
Segment
::
DTypePromote
>
(
std
::
make_shared
<
DTypePromoteTransformation
>
());
static
py
::
exception
<
interpreter
::
AsyncError
>
py_async_error
(
static
py
::
exception
<
interpreter
::
AsyncError
>
py_async_error
(
m
,
"AsyncError"
,
PyExc_RuntimeError
);
m
,
"AsyncError"
,
PyExc_RuntimeError
);
...
@@ -1137,6 +1180,63 @@ void init_tensor(py::module m) {
...
@@ -1137,6 +1180,63 @@ void init_tensor(py::module m) {
m
.
def
(
"reset_stats"
,
[]
{
imperative
::
Stats
::
reset
();
});
m
.
def
(
"reset_stats"
,
[]
{
imperative
::
Stats
::
reset
();
});
m
.
def
(
"_get_convert_inputs"
,
[]()
->
bool
{
return
DTypePromoteCfg
::
convert_input_enabled
;
});
m
.
def
(
"_set_convert_inputs"
,
[](
bool
flag
)
->
bool
{
bool
ret
=
DTypePromoteCfg
::
convert_input_enabled
;
DTypePromoteCfg
::
convert_input_enabled
=
flag
;
return
ret
;
});
m
.
def
(
"_get_amp_dtype_autocast"
,
[]()
->
bool
{
return
DTypePromoteCfg
::
amp_dtype_autocast_enabled
;
});
m
.
def
(
"_set_amp_dtype_autocast"
,
[](
bool
flag
)
->
bool
{
bool
ret
=
DTypePromoteCfg
::
amp_dtype_autocast_enabled
;
DTypePromoteCfg
::
amp_dtype_autocast_enabled
=
flag
;
return
ret
;
});
static
auto
get_amp_prec_dtype
=
[](
bool
is_high
)
->
std
::
string
{
DType
&
target
=
is_high
?
DTypePromoteCfg
::
amp_high_prec_dtype
:
DTypePromoteCfg
::
amp_low_prec_dtype
;
mgb_assert
(
target
.
category
()
==
DTypeCategory
::
FLOAT
);
std
::
string
ret
=
target
.
name
();
transform
(
ret
.
begin
(),
ret
.
end
(),
ret
.
begin
(),
::
tolower
);
return
ret
;
};
static
auto
set_amp_prec_dtype
=
[](
bool
is_high
,
std
::
string
dtype_name
)
->
std
::
string
{
DType
&
target
=
is_high
?
DTypePromoteCfg
::
amp_high_prec_dtype
:
DTypePromoteCfg
::
amp_low_prec_dtype
;
std
::
string
ret
=
target
.
name
();
if
(
dtype_name
==
"float32"
)
{
target
=
dtype
::
Float32
();
}
else
if
(
dtype_name
==
"float16"
)
{
target
=
dtype
::
Float16
();
}
else
if
(
dtype_name
==
"bfloat16"
)
{
target
=
dtype
::
BFloat16
();
}
else
{
mgb_assert
(
false
,
"casted type of amp should be float, but you give %s
\n
"
,
dtype_name
.
c_str
());
}
transform
(
ret
.
begin
(),
ret
.
end
(),
ret
.
begin
(),
::
tolower
);
return
ret
;
};
m
.
def
(
"_get_amp_high_prec_dtype"
,
[]()
->
std
::
string
{
return
get_amp_prec_dtype
(
true
);
});
m
.
def
(
"_set_amp_high_prec_dtype"
,
[](
std
::
string
dtype_name
)
->
std
::
string
{
return
set_amp_prec_dtype
(
true
,
dtype_name
);
});
m
.
def
(
"_get_amp_low_prec_dtype"
,
[]()
->
std
::
string
{
return
get_amp_prec_dtype
(
false
);
});
m
.
def
(
"_set_amp_low_prec_dtype"
,
[](
std
::
string
dtype_name
)
->
std
::
string
{
return
set_amp_prec_dtype
(
false
,
dtype_name
);
});
py
::
register_exception
<
TraceError
>
(
m
,
"TraceError"
);
py
::
register_exception
<
TraceError
>
(
m
,
"TraceError"
);
}
}
...
...
imperative/python/src/transformation.h
浏览文件 @
3c3fc6f3
...
@@ -26,12 +26,13 @@ struct TransformationManager {
...
@@ -26,12 +26,13 @@ struct TransformationManager {
enum
Segment
{
enum
Segment
{
ModuleTrace
,
ModuleTrace
,
Grad
,
Grad
,
DTypePromote
,
Scalar
,
Scalar
,
Trace
,
Trace
,
Eval
,
Eval
,
};
};
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
5
>
segments
;
std
::
array
<
std
::
vector
<
std
::
shared_ptr
<
Transformation
>>
,
6
>
segments
;
template
<
Segment
segment
>
template
<
Segment
segment
>
void
register_at
(
std
::
shared_ptr
<
Transformation
>
transformation
)
{
void
register_at
(
std
::
shared_ptr
<
Transformation
>
transformation
)
{
...
...
imperative/python/test/unit/amp/test_autocast.py
浏览文件 @
3c3fc6f3
...
@@ -14,20 +14,20 @@ def test_grad_scaler():
...
@@ -14,20 +14,20 @@ def test_grad_scaler():
assert
amp
.
enabled
==
enabled
assert
amp
.
enabled
==
enabled
assert
origin_amp
.
_enabled
==
enabled
assert
origin_amp
.
_enabled
==
enabled
assert
amp
.
low_prec_dtype
==
low
assert
amp
.
low_prec_dtype
==
low
assert
origin_amp
.
_
low_prec_dtype
==
low
assert
origin_amp
.
_
get_amp_low_prec_dtype
()
==
low
assert
amp
.
high_prec_dtype
==
high
assert
amp
.
high_prec_dtype
==
high
assert
origin_amp
.
_
high_prec_dtype
==
high
assert
origin_amp
.
_
get_amp_high_prec_dtype
()
==
high
origin_enabled
=
amp
.
enabled
origin_enabled
=
amp
.
enabled
origin_high
=
amp
.
high_prec_dtype
origin_high
=
amp
.
high_prec_dtype
origin_low
=
amp
.
low_prec_dtype
origin_low
=
amp
.
low_prec_dtype
with
amp
.
autocast
(
low_prec_dtype
=
"
low"
,
high_prec_dtype
=
"high
"
):
with
amp
.
autocast
(
low_prec_dtype
=
"
float16"
,
high_prec_dtype
=
"float32
"
):
check
(
True
,
"
low"
,
"high
"
)
check
(
True
,
"
float16"
,
"float32
"
)
check
(
origin_enabled
,
origin_low
,
origin_high
)
check
(
origin_enabled
,
origin_low
,
origin_high
)
amp
.
enabled
=
True
amp
.
enabled
=
True
amp
.
high_prec_dtype
=
"
high
"
amp
.
high_prec_dtype
=
"
float32
"
amp
.
low_prec_dtype
=
"
low
"
amp
.
low_prec_dtype
=
"
float16
"
check
(
True
,
"
low"
,
"high
"
)
check
(
True
,
"
float16"
,
"float32
"
)
amp
.
enabled
=
origin_enabled
amp
.
enabled
=
origin_enabled
amp
.
high_prec_dtype
=
origin_high
amp
.
high_prec_dtype
=
origin_high
amp
.
low_prec_dtype
=
origin_low
amp
.
low_prec_dtype
=
origin_low
...
...
imperative/src/impl/transformations/dtype_promote.cpp
0 → 100644
浏览文件 @
3c3fc6f3
#include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/ops/autogen.h"
namespace
mgb
::
imperative
{
bool
DTypePromoteCfg
::
convert_input_enabled
=
true
;
bool
DTypePromoteCfg
::
amp_dtype_autocast_enabled
=
false
;
DType
DTypePromoteCfg
::
amp_high_prec_dtype
=
dtype
::
Float32
();
DType
DTypePromoteCfg
::
amp_low_prec_dtype
=
dtype
::
Float16
();
namespace
{
// TODO: ScalarRule and DTypePromoteRule should be unified
using
DTypePromoteRule
=
std
::
function
<
ValueRefList
(
const
OpDef
&
,
Span
<
ValueRef
>
)
>
;
static
std
::
unordered_map
<
Typeinfo
*
,
DTypePromoteRule
>
dtype_promotion_rules
;
template
<
typename
T
>
void
register_dtype_promote_rule
(
const
DTypePromoteRule
&
rule
)
{
dtype_promotion_rules
[
T
::
typeinfo
()]
=
[
rule
](
const
OpDef
&
def
,
Span
<
ValueRef
>
inputs
)
{
return
rule
(
def
.
cast_final_safe
<
T
>
(),
inputs
);
};
}
bool
is_quantized_dtype
(
const
DType
&
dtype
)
{
return
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
;
}
bool
is_all_integer
(
const
SmallVector
<
DType
>&
dtypes
)
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
if
(
dtypes
[
i
].
category
()
!=
DTypeCategory
::
INT
)
{
return
false
;
}
}
return
true
;
}
SmallVector
<
DType
>
get_value_dtypes
(
const
Span
<
ValueRef
>
inputs
)
{
SmallVector
<
DType
>
dtypes
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
dtypes
[
i
]
=
*
(
inputs
[
i
].
dtype
());
}
return
dtypes
;
}
mgb
::
DType
get_promoted_dtype
(
const
SmallVector
<
DType
>&
dtypes
)
{
if
(
dtypes
.
size
()
==
0
)
{
mgb_assert
(
false
,
"there is no input for operator, dtype promote failed"
);
}
mgb
::
DType
ret
=
dtypes
[
0
];
for
(
size_t
i
=
1
;
i
<
dtypes
.
size
();
++
i
)
{
ret
=
mgb
::
dtype_promotion
(
ret
,
dtypes
[
i
]);
}
return
ret
;
}
ValueRefList
elemwise_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
&&
elem_op
=
op
.
cast_final_safe
<
Elemwise
>
();
SmallVector
<
DType
>
dtypes
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
dtypes
[
i
]
=
*
(
inputs
[
i
].
dtype
());
}
ValueRefList
converted
(
inputs
.
size
());
mgb
::
DType
target_dtype
=
get_promoted_dtype
(
dtypes
);
// TODO: we can save the dtypes of inputs here and perform TypeCvt at the end of
// this function, rather than perform TypeCvt eagerly. But for the compatibility, we
// implement this function with the similar process as the python version and
// perform TypeCvt here, so we maybe do TypeCvt several times in these function
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
!
is_quantized_dtype
(
dtypes
[
i
])
&&
dtypes
[
i
]
!=
target_dtype
&&
DTypePromoteCfg
::
convert_input_enabled
)
{
converted
[
i
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
target_dtype
)),
inputs
[
i
])[
0
];
dtypes
[
i
]
=
target_dtype
;
}
else
{
converted
[
i
]
=
inputs
[
i
];
}
}
static
std
::
unordered_set
<
Elemwise
::
Mode
>
cast_case1
=
{
Elemwise
::
Mode
::
TRUE_DIV
,
Elemwise
::
Mode
::
EXP
,
Elemwise
::
Mode
::
POW
,
Elemwise
::
Mode
::
LOG
,
Elemwise
::
Mode
::
EXPM1
,
Elemwise
::
Mode
::
LOG1P
,
Elemwise
::
Mode
::
ACOS
,
Elemwise
::
Mode
::
ASIN
,
Elemwise
::
Mode
::
ATAN2
,
Elemwise
::
Mode
::
COS
,
Elemwise
::
Mode
::
SIN
,
Elemwise
::
Mode
::
LOG_SUM_EXP
,
};
static
std
::
unordered_set
<
Elemwise
::
Mode
>
cast_case2
=
{
Elemwise
::
Mode
::
TANH
,
};
auto
cast_to_high_prec
=
[
&
]()
{
for
(
size_t
i
=
0
;
i
<
dtypes
.
size
();
++
i
)
{
if
(
dtypes
[
i
]
!=
DTypePromoteCfg
::
amp_high_prec_dtype
)
{
converted
[
i
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
DTypePromoteCfg
::
amp_high_prec_dtype
)),
converted
[
i
])[
0
];
dtypes
[
i
]
=
DTypePromoteCfg
::
amp_high_prec_dtype
;
}
}
};
if
(
cast_case1
.
find
(
elem_op
.
mode
)
!=
cast_case1
.
end
())
{
if
(
DTypePromoteCfg
::
amp_dtype_autocast_enabled
||
is_all_integer
(
dtypes
))
{
cast_to_high_prec
();
}
}
if
(
cast_case2
.
find
(
elem_op
.
mode
)
!=
cast_case2
.
end
())
{
if
(
is_all_integer
(
dtypes
))
{
cast_to_high_prec
();
}
}
static
std
::
unordered_set
<
Elemwise
::
Mode
>
cast_case3
=
{
Elemwise
::
Mode
::
CEIL
,
Elemwise
::
Mode
::
FLOOR
,
Elemwise
::
Mode
::
ROUND
};
if
(
cast_case3
.
find
(
elem_op
.
mode
)
!=
cast_case3
.
end
())
{
if
(
dtypes
[
0
].
category
()
==
DTypeCategory
::
INT
)
{
return
converted
;
}
}
return
imperative
::
apply
(
op
,
converted
);
}
ValueRefList
reduce_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
&&
reduce_op
=
op
.
cast_final_safe
<
Reduce
>
();
DType
org_dtype
=
*
(
inputs
[
0
].
dtype
());
DType
target_dtype
=
org_dtype
;
ValueRefList
converted
(
inputs
.
begin
(),
inputs
.
end
());
if
(
reduce_op
.
mode
==
Reduce
::
Mode
::
MEAN
)
{
target_dtype
=
dtype
::
Float32
();
}
else
if
(
org_dtype
.
category
()
==
DTypeCategory
::
BOOL
)
{
target_dtype
=
dtype
::
Int32
();
}
if
(
target_dtype
!=
org_dtype
)
{
converted
[
0
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
target_dtype
)),
inputs
[
0
])[
0
];
}
ValueRefList
ret
=
imperative
::
apply
(
op
,
converted
);
if
(
org_dtype
.
category
()
==
DTypeCategory
::
BOOL
)
{
if
(
reduce_op
.
mode
==
Reduce
::
Mode
::
MIN
||
reduce_op
.
mode
==
Reduce
::
Mode
::
MAX
)
{
ret
[
0
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
dtype
::
Bool
())),
ret
[
0
])[
0
];
}
}
return
ret
;
}
ValueRefList
convolution_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
)
{
auto
&&
conv_op
=
const_cast
<
Convolution
&>
(
op
.
cast_final_safe
<
Convolution
>
());
SmallVector
<
DType
>
dtypes
=
get_value_dtypes
(
inputs
);
mgb
::
DType
target_dtype
;
if
(
DTypePromoteCfg
::
amp_dtype_autocast_enabled
)
{
conv_op
.
compute_mode
=
Convolution
::
ComputeMode
::
FLOAT32
;
target_dtype
=
DTypePromoteCfg
::
amp_low_prec_dtype
;
}
else
{
target_dtype
=
get_promoted_dtype
(
dtypes
);
}
ValueRefList
converted
(
inputs
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
dtypes
[
i
]
!=
target_dtype
)
{
converted
[
i
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
target_dtype
)),
inputs
[
i
])[
0
];
}
else
{
converted
[
i
]
=
inputs
[
i
];
}
}
return
imperative
::
apply
(
op
,
converted
);
}
ValueRefList
batch_norm_rule
(
const
OpDef
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
DTypePromoteCfg
::
amp_dtype_autocast_enabled
)
{
mgb_assert
(
inputs
.
size
()
>
0
);
ValueRefList
converted
(
inputs
.
size
());
converted
[
0
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
dtype
::
Float16
())),
inputs
[
0
])[
0
];
for
(
size_t
i
=
1
;
i
<
inputs
.
size
();
++
i
)
{
DType
idtype
=
*
(
inputs
[
i
].
dtype
());
if
(
idtype
!=
DTypePromoteCfg
::
amp_high_prec_dtype
)
{
converted
[
i
]
=
imperative
::
apply
(
ApplyOp
(
*
TypeCvt
::
make
(
DTypePromoteCfg
::
amp_high_prec_dtype
)),
inputs
[
i
])[
0
];
}
else
{
converted
[
i
]
=
inputs
[
i
];
}
}
return
imperative
::
apply
(
op
,
converted
);
}
return
imperative
::
apply
(
op
,
inputs
);
}
struct
DTypePromoteRuleRegistry
{
DTypePromoteRuleRegistry
()
{
register_dtype_promote_rule
<
Elemwise
>
(
elemwise_rule
);
register_dtype_promote_rule
<
Reduce
>
(
reduce_rule
);
register_dtype_promote_rule
<
Convolution
>
(
convolution_rule
);
register_dtype_promote_rule
<
BatchNorm
>
(
batch_norm_rule
);
}
}
register_helper
;
}
// namespace
ValueRefList
DTypePromoteTransformation
::
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
{
if
(
auto
apply_op
=
op
.
as
<
ApplyOp
>
())
{
auto
iter
=
dtype_promotion_rules
.
find
(
apply_op
->
op
().
dyn_typeinfo
());
if
(
iter
!=
dtype_promotion_rules
.
end
())
{
return
iter
->
second
(
apply_op
->
op
(),
inputs
);
}
else
{
return
imperative
::
apply
(
op
,
inputs
);
}
}
return
imperative
::
apply
(
op
,
inputs
);
}
ValueRef
DTypePromoteTransformation
::
unwrap
(
ValueRef
value
)
{
return
value
;
}
std
::
string
DTypePromoteTransformation
::
name
()
const
{
return
"DTypePromoteTransformation"
;
}
void
DTypePromoteTransformation
::
on_register
()
{
// printf("DTypePromoteTransformation has been registered\n");
}
void
DTypePromoteTransformation
::
on_unregister
()
noexcept
{
// printf("DTypePromoteTransformation has been unregistered\n");
}
}
// namespace mgb::imperative
\ No newline at end of file
imperative/src/include/megbrain/imperative/transformations/dtype_promote.h
0 → 100644
浏览文件 @
3c3fc6f3
#pragma once
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/value.h"
namespace
mgb
::
imperative
{
class
DTypePromoteTransformation
final
:
public
Transformation
{
private:
public:
ValueRefList
apply_transformation
(
const
Operator
&
op
,
Span
<
ValueRef
>
inputs
)
override
;
ValueRef
unwrap
(
ValueRef
value
)
override
;
std
::
string
name
()
const
override
;
void
on_register
()
override
;
void
on_unregister
()
noexcept
override
;
};
struct
DTypePromoteCfg
{
static
bool
convert_input_enabled
;
static
bool
amp_dtype_autocast_enabled
;
static
DType
amp_high_prec_dtype
;
static
DType
amp_low_prec_dtype
;
};
}
// namespace mgb::imperative
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录