Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fccb2510
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fccb2510
编写于
10月 27, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(mge): add opdef for elemwise and batchnorm
GitOrigin-RevId: d51fad98678098651e83ebdd2a3c1ff6196cb7a2
上级
c008cf37
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
510 addition
and
136 deletion
+510
-136
imperative/python/megengine/core/autodiff/builtin_op_utils.py
...rative/python/megengine/core/autodiff/builtin_op_utils.py
+7
-14
imperative/python/megengine/core/autodiff/grad.py
imperative/python/megengine/core/autodiff/grad.py
+1
-1
imperative/python/megengine/core/tensor/tensor_wrapper.py
imperative/python/megengine/core/tensor/tensor_wrapper.py
+42
-40
imperative/python/megengine/functional/elemwise.py
imperative/python/megengine/functional/elemwise.py
+41
-40
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+26
-16
imperative/python/megengine/module/batchnorm.py
imperative/python/megengine/module/batchnorm.py
+9
-7
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+89
-0
imperative/python/test/unit/core/test_autodiff.py
imperative/python/test/unit/core/test_autodiff.py
+6
-6
imperative/python/test/unit/core/test_imperative_rt.py
imperative/python/test/unit/core/test_imperative_rt.py
+7
-5
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+8
-7
imperative/src/impl/ops/batch_norm.cpp
imperative/src/impl/ops/batch_norm.cpp
+84
-0
imperative/src/impl/ops/elemwise.cpp
imperative/src/impl/ops/elemwise.cpp
+78
-0
imperative/src/include/megbrain/imperative/ops/batch_norm.h
imperative/src/include/megbrain/imperative/ops/batch_norm.h
+70
-0
imperative/src/include/megbrain/imperative/ops/elemwise.h
imperative/src/include/megbrain/imperative/ops/elemwise.h
+42
-0
未找到文件。
imperative/python/megengine/core/autodiff/builtin_op_utils.py
浏览文件 @
fccb2510
...
...
@@ -30,7 +30,6 @@ from ..tensor.core import apply
from
..tensor.function
import
Function
from
..tensor.tensor_wrapper
import
TensorWrapper
_elemwise_add_param
=
Elemwise
(
mode
=
"add"
).
to_c
().
param
_reduce_sum_param
=
Reduce
(
mode
=
"SUM"
).
to_c
().
param
[
0
]
...
...
@@ -44,12 +43,12 @@ def _(op: OpDef, inputs, outputs, input_requires_grad):
if
isinstance
(
op
,
OprAttr
):
grad_fn
=
_oprAttr_grad_fn
.
get
(
op
.
type
,
None
)
if
grad_fn
is
None
:
if
op
.
type
==
Elemwise
.
name
and
op
.
param
==
_elemwise_add_param
:
grad_fn
=
elemwise_add_grad_fn
elif
op
.
type
==
Reduce
.
name
and
op
.
param
[
0
]
==
_reduce_sum_param
:
if
op
.
type
==
Reduce
.
name
and
op
.
param
[
0
]
==
_reduce_sum_param
:
grad_fn
=
reduce_sum_grad_fn
else
:
grad_fn
=
default_grad_fn
elif
isinstance
(
op
,
Elemwise
)
and
op
.
mode
==
Elemwise
.
Mode
.
ADD
:
grad_fn
=
elemwise_add_grad_fn
else
:
grad_fn
=
default_grad_fn
return
grad_fn
(
op
,
inputs
,
outputs
,
input_requires_grad
)
...
...
@@ -158,11 +157,8 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
params
=
inputs
[
1
:]
def
make_grad
(
grad_op
,
dy
):
grad
=
(
TensorWrapper
(
0
,
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)
.
_broadcast
(
TensorWrapper
(
input_shape
))
.
__wrapped__
)
(
_z
,)
=
Const
(
0
,
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)(
dy
)
(
grad
,)
=
apply
(
Broadcast
(),
_z
,
input_shape
)
(
dx
,)
=
apply
(
grad_op
,
grad
,
dy
,
*
params
)
return
dx
...
...
@@ -184,11 +180,8 @@ def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
params
=
inputs
[
1
:]
def
make_grad
(
grad_op
,
dy
):
grad
=
(
TensorWrapper
(
0
,
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)
.
_broadcast
(
TensorWrapper
(
input_shape
))
.
__wrapped__
)
(
_z
,)
=
Const
(
0
,
dtype
=
dy
.
dtype
,
device
=
dy
.
device
)(
dy
)
(
grad
,)
=
apply
(
Broadcast
(),
_z
,
input_shape
)
(
dx
,)
=
apply
(
grad_op
,
grad
,
dy
,
*
params
)
return
dx
...
...
imperative/python/megengine/core/autodiff/grad.py
浏览文件 @
fccb2510
...
...
@@ -47,7 +47,7 @@ def get_grad_managers():
def
add
(
a
,
b
):
(
c
,)
=
apply
(
Elemwise
(
mode
=
"add"
),
a
,
b
)
(
c
,)
=
apply
(
Elemwise
(
Elemwise
.
Mode
.
ADD
),
a
,
b
)
return
c
...
...
imperative/python/megengine/core/tensor/tensor_wrapper.py
浏览文件 @
fccb2510
...
...
@@ -13,7 +13,7 @@ import numpy as np
from
.._trace_option
import
use_symbolic_shape
from
..ops
import
builtin
from
..ops.builtin
import
GetVarShape
from
..ops.builtin
import
Elemwise
,
GetVarShape
from
..ops.special
import
Const
from
.
import
utils
from
.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
...
...
@@ -23,10 +23,12 @@ from .raw_tensor import RawTensor, as_raw_tensor
from
.tensor
import
Tensor
from
.utils
import
make_shape_tuple
as
_make_shape_tuple
_ElwMod
=
Elemwise
.
Mode
def
_elwise
(
*
args
,
mode
):
op
=
builtin
.
Elemwise
(
mode
=
mode
)
if
mode
in
(
"TRUE_DIV"
,
"POW"
):
op
=
builtin
.
Elemwise
(
mode
)
if
mode
in
(
_ElwMod
.
TRUE_DIV
,
_ElwMod
.
POW
):
args
=
tuple
(
map
(
lambda
x
:
x
.
astype
(
"float32"
)
...
...
@@ -272,53 +274,53 @@ class ArrayMethodMixin(abc.ABC):
__hash__
=
None
# due to __eq__ diviates from python convention
__lt__
=
lambda
self
,
value
:
_elwise
(
self
,
value
,
mode
=
"LT"
).
astype
(
"bool"
)
__le__
=
lambda
self
,
value
:
_elwise
(
self
,
value
,
mode
=
"LEQ"
).
astype
(
"bool"
)
__gt__
=
lambda
self
,
value
:
_elwise
(
value
,
self
,
mode
=
"LT"
).
astype
(
"bool"
)
__ge__
=
lambda
self
,
value
:
_elwise
(
value
,
self
,
mode
=
"LEQ"
).
astype
(
"bool"
)
__eq__
=
lambda
self
,
value
:
_elwise
(
self
,
value
,
mode
=
"EQ"
).
astype
(
"bool"
)
__lt__
=
lambda
self
,
value
:
_elwise
(
self
,
value
,
mode
=
_ElwMod
.
LT
).
astype
(
"bool"
)
__le__
=
lambda
self
,
value
:
_elwise
(
self
,
value
,
mode
=
_ElwMod
.
LEQ
).
astype
(
"bool"
)
__gt__
=
lambda
self
,
value
:
_elwise
(
value
,
self
,
mode
=
_ElwMod
.
LT
).
astype
(
"bool"
)
__ge__
=
lambda
self
,
value
:
_elwise
(
value
,
self
,
mode
=
_ElwMod
.
LEQ
).
astype
(
"bool"
)
__eq__
=
lambda
self
,
value
:
_elwise
(
self
,
value
,
mode
=
_ElwMod
.
EQ
).
astype
(
"bool"
)
__ne__
=
lambda
self
,
value
:
_elwise
(
_elwise
(
self
,
value
,
mode
=
"EQ"
).
astype
(
"bool"
),
mode
=
"NOT"
_elwise
(
self
,
value
,
mode
=
_ElwMod
.
EQ
).
astype
(
"bool"
),
mode
=
_ElwMod
.
NOT
,
)
__neg__
=
_unary_elwise
(
"NEGATE"
)
__neg__
=
_unary_elwise
(
_ElwMod
.
NEGATE
)
__pos__
=
lambda
self
:
self
__abs__
=
_unary_elwise
(
"ABS"
)
__invert__
=
_logical_unary_elwise
(
"NOT"
)
__round__
=
_unary_elwise
(
"ROUND"
)
__abs__
=
_unary_elwise
(
_ElwMod
.
ABS
)
__invert__
=
_logical_unary_elwise
(
_ElwMod
.
NOT
)
__round__
=
_unary_elwise
(
_ElwMod
.
ROUND
)
__trunc__
=
_todo
__floor__
=
_unary_elwise
(
"FLOOR"
)
__ceil__
=
_unary_elwise
(
"CEIL"
)
__floor__
=
_unary_elwise
(
_ElwMod
.
FLOOR
)
__ceil__
=
_unary_elwise
(
_ElwMod
.
CEIL
)
__add__
=
_binary_elwise
(
"ADD"
)
__sub__
=
_binary_elwise
(
"SUB"
)
__mul__
=
_binary_elwise
(
"MUL"
)
__add__
=
_binary_elwise
(
_ElwMod
.
ADD
)
__sub__
=
_binary_elwise
(
_ElwMod
.
SUB
)
__mul__
=
_binary_elwise
(
_ElwMod
.
MUL
)
__matmul__
=
lambda
self
,
other
:
_matmul
(
self
,
other
)
__truediv__
=
_binary_elwise
(
"TRUE_DIV"
)
__floordiv__
=
_binary_elwise
(
"FLOOR_DIV"
)
__mod__
=
_binary_elwise
(
"MOD"
)
__truediv__
=
_binary_elwise
(
_ElwMod
.
TRUE_DIV
)
__floordiv__
=
_binary_elwise
(
_ElwMod
.
FLOOR_DIV
)
__mod__
=
_binary_elwise
(
_ElwMod
.
MOD
)
# __divmode__
__pow__
=
_binary_elwise
(
"POW"
)
__lshift__
=
_binary_elwise
(
"SHL"
)
__rshift__
=
_binary_elwise
(
"SHR"
)
__and__
=
_logical_binary_elwise
(
"AND"
)
__or__
=
_logical_binary_elwise
(
"OR"
)
__xor__
=
_logical_binary_elwise
(
"XOR"
)
__radd__
=
_binary_elwise
(
"ADD"
,
rev
=
1
)
__rsub__
=
_binary_elwise
(
"SUB"
,
rev
=
1
)
__rmul__
=
_binary_elwise
(
"MUL"
,
rev
=
1
)
__pow__
=
_binary_elwise
(
_ElwMod
.
POW
)
__lshift__
=
_binary_elwise
(
_ElwMod
.
SHL
)
__rshift__
=
_binary_elwise
(
_ElwMod
.
SHR
)
__and__
=
_logical_binary_elwise
(
_ElwMod
.
AND
)
__or__
=
_logical_binary_elwise
(
_ElwMod
.
OR
)
__xor__
=
_logical_binary_elwise
(
_ElwMod
.
XOR
)
__radd__
=
_binary_elwise
(
_ElwMod
.
ADD
,
rev
=
1
)
__rsub__
=
_binary_elwise
(
_ElwMod
.
SUB
,
rev
=
1
)
__rmul__
=
_binary_elwise
(
_ElwMod
.
MUL
,
rev
=
1
)
__rmatmul__
=
lambda
self
,
other
:
_matmul
(
other
,
self
)
__rtruediv__
=
_binary_elwise
(
"TRUE_DIV"
,
rev
=
1
)
__rfloordiv__
=
_binary_elwise
(
"FLOOR_DIV"
,
rev
=
1
)
__rmod__
=
_binary_elwise
(
"MOD"
,
rev
=
1
)
__rtruediv__
=
_binary_elwise
(
_ElwMod
.
TRUE_DIV
,
rev
=
1
)
__rfloordiv__
=
_binary_elwise
(
_ElwMod
.
FLOOR_DIV
,
rev
=
1
)
__rmod__
=
_binary_elwise
(
_ElwMod
.
MOD
,
rev
=
1
)
# __rdivmode__
__rpow__
=
_binary_elwise
(
"POW"
,
rev
=
1
)
__rlshift__
=
_binary_elwise
(
"SHL"
,
rev
=
1
)
__rrshift__
=
_binary_elwise
(
"SHR"
,
rev
=
1
)
__rand__
=
_logical_binary_elwise
(
"AND"
,
rev
=
1
)
__ror__
=
_logical_binary_elwise
(
"OR"
,
rev
=
1
)
__rxor__
=
_logical_binary_elwise
(
"XOR"
,
rev
=
1
)
__rpow__
=
_binary_elwise
(
_ElwMod
.
POW
,
rev
=
1
)
__rlshift__
=
_binary_elwise
(
_ElwMod
.
SHL
,
rev
=
1
)
__rrshift__
=
_binary_elwise
(
_ElwMod
.
SHR
,
rev
=
1
)
__rand__
=
_logical_binary_elwise
(
_ElwMod
.
AND
,
rev
=
1
)
__ror__
=
_logical_binary_elwise
(
_ElwMod
.
OR
,
rev
=
1
)
__rxor__
=
_logical_binary_elwise
(
_ElwMod
.
XOR
,
rev
=
1
)
__iadd__
=
_inplace
(
__add__
)
__isub__
=
_inplace
(
__sub__
)
...
...
imperative/python/megengine/functional/elemwise.py
浏览文件 @
fccb2510
...
...
@@ -10,6 +10,7 @@
import
functools
from
..core.ops
import
builtin
from
..core.ops.builtin
import
Elemwise
from
..core.tensor
import
megbrain_graph
,
utils
from
..core.tensor.core
import
apply
from
..device
import
get_default_device
...
...
@@ -72,7 +73,7 @@ __all__ = [
def
_elwise
(
*
args
,
mode
):
op
=
builtin
.
Elemwise
(
mode
=
mode
)
op
=
builtin
.
Elemwise
(
mode
)
tensor_args
=
list
(
filter
(
lambda
x
:
isinstance
(
x
,
(
Tensor
,
megbrain_graph
.
VarNode
)),
args
)
)
...
...
@@ -128,67 +129,67 @@ def add(x, y):
[ 6. 8. 10.]]
"""
return
_elwise
(
x
,
y
,
mode
=
"add"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
ADD
)
def
sub
(
x
,
y
):
"""Element-wise `subtraction`."""
return
_elwise
(
x
,
y
,
mode
=
"sub"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
SUB
)
def
mul
(
x
,
y
):
"""Element-wise `multiplication`."""
return
_elwise
(
x
,
y
,
mode
=
"mul"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
MUL
)
def
div
(
x
,
y
):
"""Element-wise `(x / y)`."""
return
_elwise
(
x
,
y
,
mode
=
"true_div"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
TRUE_DIV
)
def
floor_div
(
x
,
y
):
"""Element-wise `floor(x / y)`."""
return
_elwise
(
x
,
y
,
mode
=
"floor_divide"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
FLOOR_DIVIDE
)
def
neg
(
x
):
"""Element-wise `negation`."""
return
_elwise
(
x
,
mode
=
"negate"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
NEGATE
)
def
pow
(
x
,
y
):
"""Element-wise `power`."""
return
_elwise
(
x
,
y
,
mode
=
"pow"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
POW
)
def
mod
(
x
,
y
):
"""Element-wise `remainder of division`."""
return
_elwise
(
x
,
y
,
mode
=
"mod"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
MOD
)
def
abs
(
x
):
"""Element-wise `absolute value`."""
return
_elwise
(
x
,
mode
=
"abs"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
ABS
)
def
exp
(
x
):
"""Element-wise `exponential`."""
return
_elwise
(
x
,
mode
=
"exp"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
EXP
)
def
expm1
(
x
):
"""Element-wise `exp(x)-1`."""
return
_elwise
(
x
,
mode
=
"expm1"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
EXPM1
)
def
log
(
x
):
"""Element-wise `logarithm (base e)`."""
return
_elwise
(
x
,
mode
=
"log"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
LOG
)
def
log1p
(
x
):
"""Element-wise `log(x+1) (base e)`."""
return
_elwise
(
x
,
mode
=
"log1p"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
LOG1P
)
def
sqrt
(
x
:
Tensor
)
->
Tensor
:
...
...
@@ -253,27 +254,27 @@ def square(x: Tensor) -> Tensor:
def
round
(
x
):
"""Element-wise `rounding to int`."""
return
_elwise
(
x
,
mode
=
"round"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
ROUND
)
def
ceil
(
x
):
"""Element-wise `ceiling`."""
return
_elwise
(
x
,
mode
=
"ceil"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
CEIL
)
def
floor
(
x
):
"""Element-wise `floor`."""
return
_elwise
(
x
,
mode
=
"floor"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
FLOOR
)
def
maximum
(
x
,
y
):
"""Element-wise `maximum of array elements`."""
return
_elwise
(
x
,
y
,
mode
=
"max"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
MAX
)
def
minimum
(
x
,
y
):
"""Element-wise `minimum of array elements`."""
return
_elwise
(
x
,
y
,
mode
=
"min"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
MIN
)
# trigonometric functions
...
...
@@ -305,12 +306,12 @@ def cos(x):
[-0.99 -0.6536 0.2837]]
"""
return
_elwise
(
x
,
mode
=
"cos"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
COS
)
def
sin
(
x
):
"""Element-wise `sine`."""
return
_elwise
(
x
,
mode
=
"sin"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
SIN
)
def
tan
(
x
):
...
...
@@ -320,22 +321,22 @@ def tan(x):
def
acos
(
x
):
"""Element-wise `inverse cosine`."""
return
_elwise
(
x
,
mode
=
"acos"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
ACOS
)
def
asin
(
x
):
"""Element-wise `inverse sine`."""
return
_elwise
(
x
,
mode
=
"asin"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
ASIN
)
def
atan
(
x
):
"""Element-wise `inverse tangent`."""
return
_elwise
(
x
,
1
,
mode
=
"atan2"
)
return
_elwise
(
x
,
1
,
mode
=
Elemwise
.
Mode
.
ATAN2
)
def
atan2
(
y
,
x
):
"""Element-wise `2-argument arctangent`."""
return
_elwise
(
y
,
x
,
mode
=
"atan2"
)
return
_elwise
(
y
,
x
,
mode
=
Elemwise
.
Mode
.
ATAN2
)
def
cosh
(
x
):
...
...
@@ -351,7 +352,7 @@ def sinh(x):
def
tanh
(
x
):
r
"""Element-wise `hyperbolic tangent`."""
return
_elwise
(
x
,
mode
=
"tanh"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
TANH
)
def
asinh
(
x
):
...
...
@@ -399,12 +400,12 @@ def left_shift(x, y):
[12 16 20]]
"""
return
_elwise
(
x
,
y
,
mode
=
"shl"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
SHL
)
def
right_shift
(
x
,
y
):
"""Element-wise `bitwise binary: x >> y`."""
return
_elwise
(
x
,
y
,
mode
=
"shr"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
SHR
)
# logical functions
...
...
@@ -412,22 +413,22 @@ def right_shift(x, y):
def
logical_and
(
x
,
y
):
"""Element-wise `logical and: x && y`."""
return
_elwise
(
x
,
y
,
mode
=
"AND"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
AND
)
def
logical_not
(
x
):
"""Element-wise `logical not: ~x`."""
return
_elwise
(
x
,
mode
=
"NOT"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
NOT
)
def
logical_or
(
x
,
y
):
"""Element-wise `logical or: x || y`."""
return
_elwise
(
x
,
y
,
mode
=
"OR"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
OR
)
def
logical_xor
(
x
,
y
):
"""Element-wise `logical xor: x ^ y`."""
return
_elwise
(
x
,
y
,
mode
=
"XOR"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
XOR
)
# comparison functions
...
...
@@ -461,7 +462,7 @@ def equal(x, y):
[1. 1. 1.]]
"""
return
_elwise
(
x
,
y
,
mode
=
"eq"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
EQ
)
def
not_equal
(
x
,
y
):
...
...
@@ -471,22 +472,22 @@ def not_equal(x, y):
def
less
(
x
,
y
):
"""Element-wise `(x < y)`."""
return
_elwise
(
x
,
y
,
mode
=
"lt"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
LT
)
def
less_equal
(
x
,
y
):
"""Element-wise `(x <= y)`."""
return
_elwise
(
x
,
y
,
mode
=
"leq"
)
return
_elwise
(
x
,
y
,
mode
=
Elemwise
.
Mode
.
LEQ
)
def
greater
(
x
,
y
):
"""Element-wise `(x > y)`."""
return
_elwise
(
y
,
x
,
mode
=
"lt"
)
return
_elwise
(
y
,
x
,
mode
=
Elemwise
.
Mode
.
LT
)
def
greater_equal
(
x
,
y
):
"""Element-wise `(x >= y)`."""
return
_elwise
(
y
,
x
,
mode
=
"leq"
)
return
_elwise
(
y
,
x
,
mode
=
Elemwise
.
Mode
.
LEQ
)
# other functions
...
...
@@ -515,7 +516,7 @@ def hswish(x):
[0. 0.6667 1.6667 3. 4. ]
"""
return
_elwise
(
x
,
mode
=
"h_swish"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
H_SWISH
)
def
hsigmoid
(
x
):
...
...
@@ -525,7 +526,7 @@ def hsigmoid(x):
def
relu
(
x
):
"""Element-wise `max(x, 0)`."""
return
_elwise
(
x
,
mode
=
"relu"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
RELU
)
def
relu6
(
x
):
...
...
@@ -535,7 +536,7 @@ def relu6(x):
def
sigmoid
(
x
):
"""Element-wise `1 / ( 1 + exp( -x ) )`."""
return
_elwise
(
x
,
mode
=
"sigmoid"
)
return
_elwise
(
x
,
mode
=
Elemwise
.
Mode
.
SIGMOID
)
def
clip
(
x
:
Tensor
,
lower
=
None
,
upper
=
None
)
->
Tensor
:
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
fccb2510
...
...
@@ -12,6 +12,7 @@ from typing import Optional, Sequence, Tuple, Union
from
..core._imperative_rt
import
CompNode
from
..core.ops
import
builtin
from
..core.ops._internal
import
param_defs
as
P
from
..core.ops.builtin
import
BatchNorm
from
..core.ops.special
import
Const
from
..core.tensor
import
megbrain_graph
,
utils
from
..core.tensor.core
import
TensorBase
,
TensorWrapperBase
,
apply
...
...
@@ -643,19 +644,22 @@ def batch_norm(
if
inp
.
ndim
!=
4
:
raise
NotImplementedError
(
"batch_norm for ndim != 4"
)
def
full_value
(
value
):
C
=
inp
.
shape
[
1
]
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)(
inp
)
return
broadcast_to
(
x
,
[
1
,
C
,
1
,
1
])
def
expand_or_full
(
x
,
value
):
if
x
is
None
:
return
full_value
(
value
)
return
expand_dims
(
x
,
[
0
,
2
,
3
])
C
=
inp
.
shape
[
1
]
def
make_full_if_none
(
x
,
value
):
if
x
is
None
:
return
full
(
shape
=
(
1
,
inp
.
shape
[
1
],
1
,
1
),
value
=
value
)
(
x
,)
=
Const
(
value
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)(
inp
)
shape
=
utils
.
astensor1d
(
(
1
,
C
,
1
,
1
),
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
builtin
.
Broadcast
(),
x
,
shape
)
return
result
elif
x
.
ndim
==
1
:
shape
=
utils
.
astensor1d
(
(
1
,
C
,
1
,
1
),
inp
,
dtype
=
"int32"
,
device
=
inp
.
device
)
(
result
,)
=
apply
(
builtin
.
Reshape
(),
x
,
shape
)
return
result
return
x
has_mean
=
running_mean
is
not
None
...
...
@@ -674,19 +678,25 @@ def batch_norm(
inp
,
weight
,
bias
,
running_mean
,
running_var
)
weight
=
expand_or_full
(
weight
,
1
)
bias
=
expand_or_full
(
bias
,
0
)
weight
=
make_full_if_none
(
weight
,
1
)
bias
=
make_full_if_none
(
bias
,
0
)
if
not
training
:
op
=
builtin
.
BatchNorm
(
fwd_mode
=
"INFERENCE"
,
epsilon
=
eps
,
param_dim
=
"DIM_1C11"
)
op
=
builtin
.
BatchNorm
(
BatchNorm
.
ParamDim
.
DIM_1C11
,
BatchNorm
.
FwdMode
.
INFERENCE
,
eps
,
1.0
,
1.0
,
0.0
)
ret
=
apply
(
op
,
inp
,
weight
,
bias
,
running_mean
,
running_var
)[
-
1
]
return
ret
else
:
op
=
builtin
.
BatchNorm
(
avg_factor
=
1
-
momentum
,
epsilon
=
eps
,
param_dim
=
"DIM_1C11"
BatchNorm
.
ParamDim
.
DIM_1C11
,
BatchNorm
.
FwdMode
.
TRAINING
,
eps
,
1.0
-
momentum
,
1.0
,
0.0
,
)
if
has_mean
or
has_var
:
running_mean
=
make_full_if_none
(
running_mean
,
0
)
running_var
=
make_full_if_none
(
running_var
,
1
)
...
...
@@ -708,7 +718,7 @@ def batch_norm(
else
:
return
inp
,
new_mean
,
new_var
else
:
_
,
_
,
inp
,
=
apply
(
op
,
inp
,
weight
,
bias
)
(
_
,
_
,
inp
,)
=
apply
(
op
,
inp
,
weight
,
bias
)
return
inp
...
...
imperative/python/megengine/module/batchnorm.py
浏览文件 @
fccb2510
...
...
@@ -72,14 +72,15 @@ class _BatchNorm(Module):
self
.
track_running_stats
==
False
),
"track_running_stats can not be initilized to False and changed to True later"
_ndims
=
len
(
inp
.
shape
)
inp_shape
=
inp
.
shape
_ndims
=
len
(
inp_shape
)
if
_ndims
!=
4
:
origin_shape
=
inp
.
shape
origin_shape
=
inp
_
shape
if
_ndims
==
2
:
n
,
c
=
inp
.
shape
[
0
],
inp
.
shape
[
1
]
n
,
c
=
inp
_shape
[
0
],
inp_
shape
[
1
]
new_shape
=
(
n
,
c
,
1
,
1
)
elif
_ndims
==
3
:
n
,
c
,
h
=
inp
.
shape
[
0
],
inp
.
shape
[
1
],
inp
.
shape
[
2
]
n
,
c
,
h
=
inp
_shape
[
0
],
inp_shape
[
1
],
inp_
shape
[
2
]
new_shape
=
(
n
,
c
,
h
,
1
)
inp
=
inp
.
reshape
(
new_shape
)
...
...
@@ -150,17 +151,18 @@ class SyncBatchNorm(_BatchNorm):
def
forward
(
self
,
inp
):
self
.
_check_input_ndim
(
inp
)
_ndims
=
len
(
inp
.
shape
)
inp_shape
=
inp
.
shape
_ndims
=
len
(
inp_shape
)
if
_ndims
!=
4
:
new_shape
=
Tensor
([
1
,
1
,
1
,
1
],
device
=
inp
.
device
)
origin_shape
=
inp
.
shape
origin_shape
=
inp
_
shape
if
_ndims
==
2
:
new_shape
[:
2
]
=
origin_shape
[:
2
]
elif
_ndims
==
3
:
new_shape
[:
3
]
=
origin_shape
[:
3
]
else
:
raise
ValueError
(
"expected 2D, 3D or 4D input (got {}D input)"
.
format
(
len
(
inp
.
shape
))
"expected 2D, 3D or 4D input (got {}D input)"
.
format
(
len
(
inp
_
shape
))
)
inp
=
inp
.
reshape
(
new_shape
)
...
...
imperative/python/src/ops.cpp
浏览文件 @
fccb2510
...
...
@@ -19,6 +19,8 @@
#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain/imperative/ops/cond_take.h"
#include "megbrain/imperative/ops/nms.h"
#include "megbrain/imperative/ops/elemwise.h"
#include "megbrain/imperative/ops/batch_norm.h"
namespace
py
=
pybind11
;
...
...
@@ -117,4 +119,91 @@ void init_ops(py::module m) {
.
def
(
py
::
init
<
float
,
uint32_t
>
())
.
def_readwrite
(
"iou_thresh"
,
&
NMSKeep
::
iou_thresh
)
.
def_readwrite
(
"max_output"
,
&
NMSKeep
::
max_output
);
py
::
class_
<
Elemwise
,
std
::
shared_ptr
<
Elemwise
>
,
OpDef
>
elemwise
(
m
,
"Elemwise"
);
elemwise
.
def
(
py
::
init
<
Elemwise
::
Mode
>
())
.
def_readwrite
(
"mode"
,
&
Elemwise
::
mode
);
#define V(m) .value(#m, Elemwise::Mode::m)
py
::
enum_
<
Elemwise
::
Mode
>
(
elemwise
,
"Mode"
)
V
(
RELU
)
V
(
ABS
)
V
(
ACOS
)
V
(
ASIN
)
V
(
CEIL
)
V
(
COS
)
V
(
EXP
)
V
(
EXPM1
)
V
(
FLOOR
)
V
(
LOG
)
V
(
LOG1P
)
V
(
NEGATE
)
V
(
SIGMOID
)
V
(
SIN
)
V
(
TANH
)
V
(
ABS_GRAD
)
V
(
ADD
)
V
(
FLOOR_DIV
)
V
(
MAX
)
V
(
MIN
)
V
(
MOD
)
V
(
MUL
)
V
(
POW
)
V
(
SIGMOID_GRAD
)
V
(
SUB
)
V
(
SWITCH_GT0
)
V
(
TANH_GRAD
)
V
(
TRUE_DIV
)
V
(
LOG_SUM_EXP
)
V
(
LT
)
V
(
LEQ
)
V
(
EQ
)
V
(
SHL
)
V
(
SHR
)
V
(
COND_LEQ_MOV
)
V
(
FUSE_MUL_ADD3
)
V
(
FUSE_MUL_ADD4
)
V
(
FUSE_ADD_RELU
)
V
(
FUSE_ADD_SIGMOID
)
V
(
FUSE_ADD_TANH
)
V
(
FAST_TANH
)
V
(
FAST_TANH_GRAD
)
V
(
ROUND
)
V
(
RMULH
)
V
(
ATAN2
)
V
(
ERF
)
V
(
ERFINV
)
V
(
ERFC
)
V
(
ERFCINV
)
V
(
H_SWISH
)
V
(
H_SWISH_GRAD
)
V
(
FUSE_ADD_H_SWISH
)
V
(
NOT
)
V
(
AND
)
V
(
OR
)
V
(
XOR
);
#undef V
py
::
class_
<
BatchNorm
,
std
::
shared_ptr
<
BatchNorm
>
,
OpDef
>
batchnorm
(
m
,
"BatchNorm"
);
batchnorm
.
def
(
py
::
init
<
const
BatchNorm
::
Param
::
ParamDim
&
,
const
BatchNorm
::
Param
::
FwdMode
&
,
double
,
double
,
float
,
float
>
())
.
def_readwrite
(
"param_dim"
,
&
BatchNorm
::
param_dim
)
.
def_readwrite
(
"fwd_mode"
,
&
BatchNorm
::
fwd_mode
)
.
def_readwrite
(
"epsilon"
,
&
BatchNorm
::
epsilon
)
.
def_readwrite
(
"avg_factor"
,
&
BatchNorm
::
avg_factor
)
.
def_readwrite
(
"scale"
,
&
BatchNorm
::
scale
)
.
def_readwrite
(
"bias"
,
&
BatchNorm
::
bias
);
#define V(m) .value(#m, BatchNorm::Param::ParamDim::m)
py
::
enum_
<
BatchNorm
::
Param
::
ParamDim
>
(
batchnorm
,
"ParamDim"
)
V
(
DIM_11HW
)
V
(
DIM_1CHW
)
V
(
DIM_1C11
);
#undef V
#define V(m) .value(#m, BatchNorm::Param::FwdMode::m)
py
::
enum_
<
BatchNorm
::
Param
::
FwdMode
>
(
batchnorm
,
"FwdMode"
)
V
(
TRAINING
)
V
(
INFERENCE
);
#undef V
}
imperative/python/test/unit/core/test_autodiff.py
浏览文件 @
fccb2510
...
...
@@ -27,7 +27,7 @@ from megengine.functional.distributed import remote_recv, remote_send
def
_elwise
(
mode
):
op
=
Elemwise
(
mode
=
mode
)
op
=
Elemwise
(
mode
)
def
f
(
*
args
):
(
result
,)
=
apply
(
op
,
*
args
)
...
...
@@ -36,10 +36,10 @@ def _elwise(mode):
return
f
add
=
_elwise
(
"add"
)
mul
=
_elwise
(
"mul"
)
cos
=
_elwise
(
"cos"
)
relu
=
_elwise
(
"relu"
)
add
=
_elwise
(
Elemwise
.
Mode
.
ADD
)
mul
=
_elwise
(
Elemwise
.
Mode
.
MUL
)
cos
=
_elwise
(
Elemwise
.
Mode
.
COS
)
relu
=
_elwise
(
Elemwise
.
Mode
.
RELU
)
def
as_tensor
(
x
):
...
...
@@ -255,7 +255,7 @@ def test_elemwise_relu():
def
test_elemwise_relu_backward_fn
():
op
=
Elemwise
(
mode
=
"relu"
).
to_c
(
)
op
=
Elemwise
(
Elemwise
.
Mode
.
RELU
)
attr
=
TensorAttr
()
attr
.
dtype
=
"float32"
attr
.
comp_node
=
"xpux"
...
...
imperative/python/test/unit/core/test_imperative_rt.py
浏览文件 @
fccb2510
...
...
@@ -17,7 +17,7 @@ def elemwise(*args, mode):
from
megengine.core.ops.builtin
import
Elemwise
from
megengine.core._imperative_rt.imperative
import
apply_op
return
apply_op
(
Elemwise
(
mode
=
mode
).
to_c
(
),
args
)
return
apply_op
(
Elemwise
(
mode
),
args
)
def
test_basic_interface
():
...
...
@@ -37,13 +37,15 @@ def test_basic_interface():
def
test_opr_attr
():
from
megengine.core.ops.builtin
import
Elemwise
assert
Elemwise
(
mode
=
"add"
)
==
Elemwise
(
mode
=
"add"
)
assert
Elemwise
(
Elemwise
.
Mode
.
ADD
)
==
Elemwise
(
Elemwise
.
Mode
.
ADD
)
def
test_simple_arith
():
from
megengine.core.ops.builtin
import
Elemwise
x
=
np
.
random
.
rand
(
10
).
astype
(
"float32"
)
xx
=
megengine
.
core
.
_imperative_rt
.
put
(
x
)
(
yy
,)
=
elemwise
(
xx
,
xx
,
mode
=
"mul"
)
(
yy
,)
=
elemwise
(
xx
,
xx
,
mode
=
Elemwise
.
Mode
.
MUL
)
np
.
testing
.
assert_allclose
(
x
*
x
,
megengine
.
core
.
_imperative_rt
.
get_value
(
yy
))
megengine
.
core
.
_imperative_rt
.
delete
(
xx
)
megengine
.
core
.
_imperative_rt
.
delete
(
yy
)
...
...
@@ -64,7 +66,7 @@ def test_raw_tensor():
x
=
np
.
random
.
rand
(
10
).
astype
(
"float32"
)
xx
=
as_raw_tensor
(
x
)
(
yy
,)
=
apply
(
Elemwise
(
mode
=
"mul"
),
xx
,
xx
)
(
yy
,)
=
apply
(
Elemwise
(
Elemwise
.
Mode
.
MUL
),
xx
,
xx
)
np
.
testing
.
assert_allclose
(
x
*
x
,
yy
.
numpy
())
(
yy
,)
=
apply
(
Elemwise
(
mode
=
"mul"
),
xx
,
xx
)
(
yy
,)
=
apply
(
Elemwise
(
Elemwise
.
Mode
.
MUL
),
xx
,
xx
)
np
.
testing
.
assert_allclose
(
x
*
x
,
yy
.
numpy
())
imperative/python/test/unit/test_tracing.py
浏览文件 @
fccb2510
...
...
@@ -17,6 +17,7 @@ import megengine.functional as F
from
megengine
import
cgtools
,
tensor
from
megengine.core._trace_option
import
set_symbolic_shape
from
megengine.core.ops
import
builtin
as
ops
from
megengine.core.ops.builtin
import
Elemwise
from
megengine.core.tensor.core
import
apply
from
megengine.core.tensor.raw_tensor
import
as_raw_tensor
from
megengine.functional
import
exp
,
log
...
...
@@ -28,7 +29,7 @@ def test_trace():
@
trace
(
symbolic
=
symbolic
)
def
f
(
x
):
op
=
ops
.
Elemwise
(
mode
=
"negate"
)
op
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
NEGATE
)
(
y
,)
=
apply
(
op
,
x
)
return
y
...
...
@@ -44,7 +45,7 @@ def test_exclude_from_trace():
@
trace
(
symbolic
=
symbolic
)
def
f
(
x
):
neg
=
ops
.
Elemwise
(
mode
=
"negate"
)
neg
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
NEGATE
)
(
x
,)
=
apply
(
neg
,
x
)
with
exclude_from_trace
():
if
i
%
2
:
...
...
@@ -65,7 +66,7 @@ def test_print_in_trace():
@
trace
(
symbolic
=
symbolic
)
def
f
(
x
):
nonlocal
buf
neg
=
ops
.
Elemwise
(
mode
=
"negate"
)
neg
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
NEGATE
)
(
x
,)
=
apply
(
neg
,
x
)
buf
=
x
.
numpy
()
(
x
,)
=
apply
(
neg
,
x
)
...
...
@@ -85,7 +86,7 @@ def test_print_in_trace():
def
test_dump
():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
f
(
a
,
b
):
op
=
ops
.
Elemwise
(
mode
=
"add"
)
op
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
ADD
)
(
y
,)
=
apply
(
op
,
a
,
b
)
return
y
...
...
@@ -111,7 +112,7 @@ def test_capture_dump():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
f
(
x
):
op
=
ops
.
Elemwise
(
mode
=
"mul"
)
op
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
MUL
)
(
y
,)
=
apply
(
op
,
x
,
a
)
return
y
...
...
@@ -133,7 +134,7 @@ def test_dump_volatile():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
f
(
x
):
op
=
ops
.
Elemwise
(
mode
=
"mul"
)
op
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
MUL
)
(
y
,)
=
apply
(
op
,
x
,
p
)
return
y
...
...
@@ -159,7 +160,7 @@ def test_trace_profiler():
@
trace
(
symbolic
=
symbolic
,
profiling
=
True
)
def
f
(
x
):
op
=
ops
.
Elemwise
(
mode
=
"negate"
)
op
=
ops
.
Elemwise
(
Elemwise
.
Mode
.
NEGATE
)
(
y
,)
=
apply
(
op
,
x
)
return
y
...
...
imperative/src/impl/ops/batch_norm.cpp
0 → 100644
浏览文件 @
fccb2510
/**
* \file imperative/src/impl/ops/batch_norm.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/batch_norm.h"
#include "../op_trait.h"
namespace
mgb
{
namespace
imperative
{
namespace
{
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
auto
*
node
=
&
node_
->
cast_final_safe
<
opr
::
BatchNorm
>
();
auto
&&
param
=
node
->
param
();
return
BatchNorm
::
make
(
param
.
param_dim
,
param
.
fwd_mode
,
param
.
epsilon
,
param
.
avg_factor
,
param
.
scale
,
param
.
bias
);
}
cg
::
OperatorNodeBase
*
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
bn_opr
=
def
.
cast_final_safe
<
BatchNorm
>
();
size_t
nr_inp
=
inputs
.
size
();
mgb_assert
(
nr_inp
==
3
||
nr_inp
==
5
,
"BatchNorm expects 3 or 5 inputs; got %lu actually"
,
nr_inp
);
if
(
nr_inp
==
3
)
{
return
opr
::
BatchNorm
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
{
bn_opr
.
param_dim
,
bn_opr
.
fwd_mode
,
bn_opr
.
epsilon
,
bn_opr
.
avg_factor
,
bn_opr
.
scale
,
bn_opr
.
bias
})[
0
]
.
node
()
->
owner_opr
();
}
else
{
return
opr
::
BatchNorm
::
make
(
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
3
],
inputs
[
4
],
{
bn_opr
.
param_dim
,
bn_opr
.
fwd_mode
,
bn_opr
.
epsilon
,
bn_opr
.
avg_factor
,
bn_opr
.
scale
,
bn_opr
.
bias
})[
0
]
.
node
()
->
owner_opr
();
}
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
BatchNorm
>
();
size_t
nr_inp
=
inputs
.
size
();
mgb_assert
(
nr_inp
==
3
||
nr_inp
==
5
,
"BatchNorm expects 3 or 5 inputs; got %lu actually"
,
nr_inp
);
// need running mean/variance
bool
need_stat
=
(
nr_inp
==
5
)
&&
op_def
.
fwd_mode
==
BatchNorm
::
Param
::
FwdMode
::
TRAINING
;
size_t
nr_out
=
need_stat
?
5
:
3
;
SmallVector
<
LogicalTensorDesc
>
out_shapes
(
nr_out
);
auto
&&
i0
=
inputs
[
0
];
auto
&&
i1
=
inputs
[
1
];
size_t
i
=
0
;
if
(
!
need_stat
)
{
out_shapes
[
0
]
=
out_shapes
[
1
]
=
{
TensorLayout
({
0
},
i0
.
layout
.
dtype
,
i0
.
layout
.
format
),
i0
.
comp_node
};
i
=
2
;
}
for
(;
i
<
nr_out
-
1
;
++
i
)
{
out_shapes
[
i
]
=
{
i1
.
layout
,
i1
.
comp_node
};
}
out_shapes
[
nr_out
-
1
]
=
{
i0
.
layout
,
i0
.
comp_node
};
return
out_shapes
;
}
OP_TRAIT_REG
(
BatchNorm
,
BatchNorm
,
opr
::
BatchNorm
)
.
make_from_op_node
(
make_from_op_node
)
.
apply_on_var_node
(
apply_on_var_node
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
fallback
();
}
// anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BatchNorm
);
}
// namespace imperative
}
// namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
imperative/src/impl/ops/elemwise.cpp
0 → 100644
浏览文件 @
fccb2510
/**
* \file imperative/src/impl/ops/elemwise.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/elemwise.h"
#include "../op_trait.h"
namespace
mgb
{
namespace
imperative
{
namespace
{
std
::
shared_ptr
<
OpDef
>
make_from_op_node
(
cg
::
OperatorNodeBase
*
node_
)
{
auto
*
node
=
&
node_
->
cast_final_safe
<
opr
::
Elemwise
>
();
return
Elemwise
::
make
(
node
->
param
().
mode
);
}
cg
::
OperatorNodeBase
*
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
elemwise_opr
=
def
.
cast_final_safe
<
Elemwise
>
();
return
opr
::
Elemwise
::
make
(
inputs
,
elemwise_opr
.
mode
).
node
()
->
owner_opr
();
}
SmallVector
<
LogicalTensorDesc
>
infer_output_attrs_fallible
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
auto
&&
op_def
=
def
.
cast_final_safe
<
Elemwise
>
();
auto
trait
=
Elemwise
::
ModeTrait
::
from_mode
(
op_def
.
mode
);
mgb_assert
(
inputs
.
size
()
==
trait
.
arity
,
"%s expects %u inputs; got %zu actually"
,
trait
.
name
,
trait
.
arity
,
inputs
.
size
());
TensorShapeArray
inp_shapes
;
DType
out_dt
;
CompNode
out_cn
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&&
t
=
inputs
[
i
];
if
(
!
i
)
{
out_cn
=
t
.
comp_node
;
out_dt
=
t
.
layout
.
dtype
;
}
else
{
mgb_assert
(
t
.
comp_node
==
out_cn
);
mgb_assert
(
t
.
layout
.
dtype
==
out_dt
);
}
if
(
t
.
layout
.
ndim
>
0
)
{
inp_shapes
.
push_back
(
t
.
layout
);
}
else
{
TensorLayout
out_layout
;
out_layout
.
ndim
=
0
;
out_layout
.
dtype
=
out_dt
;
return
{{
out_layout
,
out_cn
}};
}
}
auto
&&
out_shape
=
opr
::
Elemwise
::
get_output_var_shape
(
op_def
.
mode
,
inp_shapes
);
return
{{
TensorLayout
(
out_shape
,
out_dt
,
inputs
[
0
].
layout
.
format
),
out_cn
}};
}
OP_TRAIT_REG
(
Elemwise
,
Elemwise
,
opr
::
Elemwise
)
.
make_from_op_node
(
make_from_op_node
)
.
apply_on_var_node
(
apply_on_var_node
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
fallback
();
}
// anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
Elemwise
);
}
// namespace imperative
}
// namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
imperative/src/include/megbrain/imperative/ops/batch_norm.h
0 → 100644
浏览文件 @
fccb2510
/**
* \file imperative/src/include/megbrain/imperative/ops/batch_norm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace
mgb
::
imperative
{
class
BatchNorm
:
public
OpDefImplBase
<
BatchNorm
>
{
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
public:
using
Param
=
opr
::
BatchNorm
::
Param
;
Param
::
ParamDim
param_dim
;
Param
::
FwdMode
fwd_mode
;
double
epsilon
;
double
avg_factor
;
float
scale
;
float
bias
;
BatchNorm
()
=
default
;
BatchNorm
(
const
Param
::
ParamDim
&
param_dim_
,
const
Param
::
FwdMode
&
fwd_mode_
,
double
epsilon_
,
double
avg_factor_
,
float
scale_
,
float
bias_
)
:
param_dim
(
param_dim_
),
fwd_mode
(
fwd_mode_
),
epsilon
(
epsilon_
),
avg_factor
(
avg_factor_
),
scale
(
scale_
),
bias
(
bias_
)
{}
size_t
hash
()
const
override
{
XXHash
xxhash
{};
auto
append
=
[
&
xxhash
](
auto
field
){
auto
hash_val
=
HashTrait
<
decltype
(
field
)
>::
eval
(
field
);
xxhash
.
update
(
reinterpret_cast
<
void
*>
(
&
hash_val
),
sizeof
(
hash_val
));
};
append
(
param_dim
);
append
(
fwd_mode
);
append
(
epsilon
);
append
(
avg_factor
);
append
(
scale
);
append
(
bias
);
return
xxhash
.
digest
();
}
bool
is_same_st
(
const
Hashable
&
rhs_
)
const
override
{
auto
&&
rhs
=
static_cast
<
const
BatchNorm
&>
(
rhs_
);
return
rhs
.
param_dim
==
param_dim
&&
rhs
.
fwd_mode
==
fwd_mode
&&
rhs
.
epsilon
==
epsilon
&&
rhs
.
avg_factor
==
avg_factor
&&
rhs
.
scale
==
scale
&&
rhs
.
bias
==
bias
;
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/ops/elemwise.h
0 → 100644
浏览文件 @
fccb2510
/**
* \file imperative/src/include/megbrain/imperative/ops/elemwise.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/op_def.h"
namespace
mgb
::
imperative
{
class
Elemwise
:
public
OpDefImplBase
<
Elemwise
>
{
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
public:
using
Mode
=
opr
::
Elemwise
::
Mode
;
using
ModeTrait
=
megdnn
::
Elemwise
::
ModeTrait
;
Mode
mode
;
Elemwise
()
=
default
;
Elemwise
(
const
Mode
&
mode_
)
:
mode
(
mode_
)
{}
size_t
hash
()
const
override
{
return
hash_pair_combine
(
mgb
::
hash
(
mode
),
reinterpret_cast
<
std
::
uintptr_t
>
(
dyn_typeinfo
()));
}
bool
is_same_st
(
const
Hashable
&
rhs_
)
const
override
{
auto
&&
rhs
=
static_cast
<
const
Elemwise
&>
(
rhs_
);
return
rhs
.
mode
==
mode
;
}
};
}
// namespace mgb::imperative
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录