MegEngine 天元 / MegEngine
Commit 4c7905f3
Authored Jul 06, 2023 by Megvii Engine Team

feat(imperative): add some xla op rules

GitOrigin-RevId: 0650c75dc1e4ec9af8ae7d9ed3eca60e4681e04a
Parent: 0d2b4db9
Showing 17 changed files with 495 additions and 95 deletions (+495 -95).
imperative/python/megengine/jit/partial_tracing.py  +3 -1
imperative/python/megengine/jit/tracing.py  +1 -1
imperative/python/megengine/xla/ir_utils.py  +3 -2
imperative/python/megengine/xla/lower.py  +1 -1
imperative/python/megengine/xla/rules/elemwise.py  +248 -16
imperative/python/megengine/xla/rules/math.py  +8 -1
imperative/python/megengine/xla/rules/reduction.py  +7 -3
imperative/python/megengine/xla/rules/tensor.py  +5 -0
imperative/python/megengine/xla/rules/trivial.py  +6 -0
imperative/python/megengine/xla/rules/utils.py  +12 -4
imperative/python/src/grad_override.cpp  +1 -1
imperative/python/src/tensor.cpp  +2 -1
imperative/python/test/unit/jit/test_tracing.py  +4 -1
imperative/python/test/unit/xla/functional/test_xla_elemwise.py  +76 -63
imperative/python/test/unit/xla/functional/test_xla_nn.py  +67 -0
imperative/python/test/unit/xla/module/test_elemwise.py  +49 -0
src/plugin/impl/opr_footprint.cpp  +2 -0
imperative/python/megengine/jit/partial_tracing.py

@@ -75,7 +75,9 @@ def _process_fwd_bwd_trace_result(fwd, bwd, inp_grad_map, out_grad_map):
    def check_external(trace_obj):
        for var in trace_obj.vars:
            if var.kind == "external" and not var.inp_mark:
-                raise RuntimeError("have unknown input in trace result")
+                raise RuntimeError(
+                    "have unknown input in trace result, maybe you can set `capture_as_const=True` when trace"
+                )
    check_external(fwd)
    check_external(bwd)
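The hint added to this error message refers to `capture_as_const`, the `megengine.jit.trace` option that records tensors captured from the enclosing scope as constants of the traced graph, so the trace is left with no unknown external inputs. A minimal sketch of how the flag is passed (illustrative only: the `scale` closure tensor is a made-up example, and whether an un-captured closure actually triggers the error above depends on the tracing mode in use):

import megengine.functional as F
from megengine import Tensor, jit

scale = Tensor(2.0)  # lives outside the traced function, i.e. an "external" tensor

@jit.trace(capture_as_const=True)  # bake `scale` into the trace as a constant
def fwd(x):
    return F.relu(x) * scale

print(fwd(Tensor([1.0, -1.0])))  # [2. 0.]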
imperative/python/megengine/jit/tracing.py

@@ -579,7 +579,7 @@ class trace:
        if not self._trace.compiled():
            outlist, self.outdef = tree_flatten(outputs)
            for i, out in enumerate(outlist):
-                assert isinstance(out, RawTensor)
+                assert isinstance(out, RawTensor), f"get out of type {type(out)}"
                outlist[i] = get_marked_output_tensor(self.output_num, out)
                del out
                self.out_list.append(self.output_num)
imperative/python/megengine/xla/ir_utils.py

@@ -101,10 +101,11 @@ class DropoutMaskCanonicalizer(Pass):
            if not isinstance(eqn.op, mops.Dropout):
                continue
-            outputs = list(eqn.outputs)
+            inputs, outputs = list(eqn.inputs), list(eqn.outputs)
            mask_var = tr.vars[outputs[1]]
+            inp_shape = tr.vars[inputs[0]].shape
            new_mask_var = AbstractVar(
-                mask_var.id, (int(np.prod(mask_var.shape)) * 8,), mask_var.dtype
+                mask_var.id, (int(np.prod(inp_shape)),), mask_var.dtype
            )
            tr.vars[mask_var.id] = new_mask_var
imperative/python/megengine/xla/lower.py

@@ -142,7 +142,7 @@ def lowering_ops(
            vars_out=[trace_result.vars[oup] for oup in eqn.outputs],
            param=eqn.param,
        )
-        rule = get_rule(eqn.op)
+        rule = get_rule(eqn.op, use_fake_rule_for_debug=False)
        in_nodes = read(eqn.inputs)
        hinps = [
imperative/python/megengine/xla/rules/elemwise.py

@@ -18,9 +18,9 @@ def _infer_elemwise_oshape(inp_shapes):
    if len(rhs_shape) == 0:
        return lhs_shape
-    if np.prod(lhs_shape) == 1 and len(rhs_shape) != 0:
+    if np.prod(lhs_shape) == 1 and len(lhs_shape) == 1 and len(rhs_shape) != 0:
        return rhs_shape
-    if np.prod(rhs_shape) == 1 and len(rhs_shape) != 0:
+    if np.prod(rhs_shape) == 1 and len(rhs_shape) == 1 and len(rhs_shape) != 0:
        return lhs_shape

    oshape = []
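The extra `len(shape) == 1` test matters because `np.prod` of any all-ones shape is 1, so the previous check could not distinguish a true scalar operand from, say, a `(1, 1, 12, 12)` input whose rank still has to take part in broadcasting. A quick plain-numpy illustration of that ambiguity (not part of the lowering itself):

import numpy as np

for shape in [(), (1,), (1, 1), (1, 1, 12, 12)]:
    # np.prod(()) is 1.0 and every all-ones shape multiplies to 1,
    # so "prod == 1" alone cannot identify a scalar-like operand.
    print(shape, np.prod(shape), len(shape))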
@@ -62,6 +62,24 @@ def _infer_elemwise_odtype(inp_dtypes):
    return oup_dtype

def bitcast(inp, oshape, odtype):
    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
    return HLOTensor(
        hlo.BitcastConvertOp(
            ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor
        ).result
    )

def typecvt(inp, odtype):
    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
    return HLOTensor(
        hlo.ConvertOp(
            ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor
        ).result
    )

def _compare(lhs, rhs, mode, comparison_type=None):
    """
    mod: can be
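Although both helpers "change the dtype", they are different operations: `typecvt` converts values (HLO ConvertOp), while `bitcast` reinterprets the raw bits under a new type and shape (HLO BitcastConvertOp). A rough numpy analogy of the distinction, for illustration only (this is not the code path MegEngine takes):

import numpy as np

x = np.array([1.0, 2.0], dtype=np.float32)

# Value conversion: 1.0 -> 1, 2.0 -> 2 (analogue of typecvt / hlo.ConvertOp).
print(x.astype(np.int32))   # [1 2]

# Bit reinterpretation: the same 8 bytes viewed as int32
# (analogue of bitcast / hlo.BitcastConvertOp).
print(x.view(np.int32))     # [1065353216 1073741824]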
@@ -126,19 +144,36 @@ def _elemwise_binary(hlo_op, a, b):
    return _elemwise(hlo_op, [a, b])

def _elemwise_ternary(hlo_op, a, b, c):
    return _elemwise(hlo_op, [a, b, c])

neg = partial(_elemwise_unary, hlo.NegOp)
abs = partial(_elemwise_unary, hlo.AbsOp)
sin = partial(_elemwise_unary, hlo.SineOp)
cos = partial(_elemwise_unary, hlo.CosineOp)
tanh = partial(_elemwise_unary, hlo.TanhOp)
exp = partial(_elemwise_unary, hlo.ExpOp)
sqrt = partial(_elemwise_unary, hlo.SqrtOp)
log = partial(_elemwise_unary, hlo.LogOp)
log1p = partial(_elemwise_unary, hlo.Log1pOp)
expm1 = partial(_elemwise_unary, hlo.Expm1Op)
floor = partial(_elemwise_unary, hlo.FloorOp)
ceil = partial(_elemwise_unary, hlo.CeilOp)
round = partial(_elemwise_unary, hlo.RoundOp)

add = partial(_elemwise_binary, hlo.AddOp)
sub = partial(_elemwise_binary, hlo.SubtractOp)
mul = partial(_elemwise_binary, hlo.MulOp)
div = partial(_elemwise_binary, hlo.DivOp)
pow = partial(_elemwise_binary, hlo.PowOp)
maximum = partial(_elemwise_binary, hlo.MaxOp)
minimum = partial(_elemwise_binary, hlo.MinOp)
atan2 = partial(_elemwise_binary, hlo.Atan2Op)
left_shift = partial(_elemwise_binary, hlo.ShiftLeftOp)
right_shift = partial(_elemwise_binary, hlo.ShiftRightArithmeticOp)

clip = partial(_elemwise_ternary, hlo.ClampOp)

equal = partial(_compare, mode="EQ")
not_equal = partial(_compare, mode="NE")
@@ -147,31 +182,99 @@ greater_equal = partial(_compare, mode="GE")
less = partial(_compare, mode="LT")
less_equal = partial(_compare, mode="LE")

logical_and = partial(_elemwise_binary, hlo.AndOp)
logical_or = partial(_elemwise_binary, hlo.OrOp)
logical_not = partial(_elemwise_unary, hlo.NotOp)
logical_xor = partial(_elemwise_binary, hlo.XorOp)

def floor_div(x, y):
    return floor(div(x, y))

def mod(x, y):
    assert False, "xla not support"

def cond_leq_move(x, y, z):
    mask = (x <= y).astype(x.dtype)
    return mask * z

def cond_lt_move(x, y, z):
    mask = (x < y).astype(x.dtype)
    return mask * z

def log_add_exp(x, y):
    min_val = minimum(x, y)
    max_val = maximum(x, y)
    return max_val + log1p(exp(min_val - max_val))

def square(x):
    return mul(x, x)

def abs_grad(x, dy):
    return (x / abs(x)) * dy

def tan(x):
    return sin(x) / cos(x)

def tan_grad(x, dy):
    return (1.0 + tan(x) ** 2.0) * dy

def sinh(x):
    return (exp(x) - exp(-x)) / 2.0

def cosh(x):
    return (exp(x) + exp(-x)) / 2.0

def tanh_grad(x, dy):
    return (1.0 - tanh(x) ** 2.0) * dy

-def bitcast(inp, oshape, odtype):
-    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
-    return HLOTensor(
-        hlo.BitcastConvertOp(
-            ir_utils.make_ir_type_according_meta(oshape, odtype), inp.tensor
-        ).result
-    )

def atan(x):
    return atan2(x, 1.0)

-def typecvt(inp, odtype):
-    odtype = np.dtype(odtype) if isinstance(odtype, str) else odtype
-    return HLOTensor(
-        hlo.ConvertOp(
-            ir_utils.make_ir_type_according_meta(inp.shape, odtype), inp.tensor
-        ).result
-    )

def asin(x):
    return atan(x / sqrt(1.0 - x ** 2.0))

def acos(x):
    assert False, "xla not support"
    # return atan(sqrt(1.0 - x ** 2.0) / x)

def asinh(x):
    return log(x + sqrt(x ** 2.0 + 1.0))

def acosh(x):
    return log(x + sqrt(x ** 2.0 - 1.0))

def atanh(x):
    return log((1.0 + x) / (1.0 - x)) / 2.0

def asinh_grad(x, dy):
    return dy / sqrt(x ** 2.0 + 1.0)

def acosh_grad(x, dy):
    return dy / sqrt(x ** 2.0 - 1.0)

def atanh_grad(x, dy):
    return dy / (1.0 - x ** 2.0)

def gelu(inp, approximate: bool = True):
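`log_add_exp` above uses the standard log-sum-exp rearrangement log(e^x + e^y) = max(x, y) + log1p(exp(min(x, y) - max(x, y))), which keeps exp from overflowing for large inputs. A small numpy check of the identity (illustration only, independent of the HLO helpers):

import numpy as np

def log_add_exp(x, y):
    min_val, max_val = np.minimum(x, y), np.maximum(x, y)
    return max_val + np.log1p(np.exp(min_val - max_val))

x, y = 3.0, -1.0
print(log_add_exp(x, y), np.log(np.exp(x) + np.exp(y)))  # both ~3.0181
print(log_add_exp(1000.0, 999.0))  # ~1000.3133, while the naive form would overflow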
@@ -257,6 +360,86 @@ def relu_grad(x, dy):
    return dy * mask

def sigmoid(inp):
    return 1.0 / (1.0 + exp(-inp))

def sigmoid_grad(y, dy):
    return y * (1.0 - y) * dy

def hsigmoid(x):
    from .tensor import where
    return where(x <= -3.0, 0.0, where(x >= 3.0, 1.0, (x + 3.0) / 6.0))

def hsigmoid_grad(x, dy):
    from .tensor import where
    return where(x <= -3.0, 0.0, where(x >= 3.0, 0.0, dy / 6.0))

def relu6(x):
    return clip(x, 0.0, 6.0)

def relu6_grad(x, dy):
    from .tensor import where
    return where(x <= 0.0, 0.0, where(x >= 6.0, 0.0, dy))

def hswish(x):
    return x * minimum(maximum(x + 3.0, 0.0), 6.0) * (1.0 / 6.0)

def hswish_grad(x, dy):
    from .tensor import where
    return where(x < -3.0, 0.0, where(x > 3.0, dy, (2.0 * x + 3.0) / 6.0 * dy))

def logsigmoid(x):
    from .tensor import where
    return -log1p(exp(-abs(x))) + where(x >= 0.0, 0.0, x)

def softplus(x):
    return log1p(exp(-abs(x))) + relu(x)

def softplus_grad(x, dy):
    from .tensor import where
    exp_abs = exp(-abs(x))
    logg = -dy * exp_abs / (1.0 + exp_abs)
    grad0 = where(x > 0.0, logg, -logg)
    relux = relu(x)
    grad1 = where(relux > 0.0, dy, 0.0)
    return grad0 + grad1

def prelu(inp, alpha):
    mask = (inp > 0.0).astype(inp.dtype)
    return inp * mask + alpha * (1.0 - mask) * inp

def prelu_grad(x, dy, alpha):
    mask = (x > 0.0).astype(x.dtype)
    return dy * mask + alpha * (1.0 - mask) * dy

def silu(inp):
    return inp / (1.0 + exp(-inp))

def silu_grad(x, dy):
    xsig = sigmoid(x)
    return dy * xsig * (1.0 + x * (1.0 - xsig))

# Elemwise.Mode is unhashable, so we convert it to str
mge_elemwise_to_xla = {
    str(mops.Elemwise.Mode.ADD): add,
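The `*_grad` helpers above are hand-written closed forms; `silu_grad`, for instance, encodes d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))). A quick finite-difference sanity check of that formula in numpy (a verification sketch, not part of the XLA rules):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

def silu_grad(x, dy):
    s = sigmoid(x)
    return dy * s * (1.0 + x * (1.0 - s))

x = np.linspace(-3.0, 3.0, 7)
eps = 1e-5
numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps)
print(np.max(np.abs(silu_grad(x, 1.0) - numeric)))  # ~1e-10, the formula matches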
@@ -264,22 +447,71 @@ mge_elemwise_to_xla = {
    str(mops.Elemwise.Mode.SUB): sub,
    str(mops.Elemwise.Mode.EXP): exp,
    str(mops.Elemwise.Mode.LOG): log,
    str(mops.Elemwise.Mode.LOG1P): log1p,
    str(mops.Elemwise.Mode.LOG_SUM_EXP): log_add_exp,
    str(mops.Elemwise.Mode.MAX): maximum,
    str(mops.Elemwise.Mode.MIN): minimum,
    str(mops.Elemwise.Mode.COND_LEQ_MOV): cond_leq_move,
    str(mops.Elemwise.Mode.COND_LT_MOV): cond_lt_move,
    str(mops.Elemwise.Mode.FLOOR): floor,
    str(mops.Elemwise.Mode.CEIL): ceil,
    str(mops.Elemwise.Mode.ROUND): round,
    str(mops.Elemwise.Mode.CLIP): clip,
    str(mops.Elemwise.Mode.GELU): gelu,
    str(mops.Elemwise.Mode.GELU_GRAD): gelu_grad,
    str(mops.Elemwise.Mode.TRUE_DIV): div,
    str(mops.Elemwise.Mode.NEGATE): neg,
    str(mops.Elemwise.Mode.FLOOR_DIV): floor_div,
    str(mops.Elemwise.Mode.MOD): mod,
    str(mops.Elemwise.Mode.ABS): abs,
    str(mops.Elemwise.Mode.ABS_GRAD): abs_grad,
    str(mops.Elemwise.Mode.SIN): sin,
    str(mops.Elemwise.Mode.COS): cos,
    str(mops.Elemwise.Mode.TAN): tan,
    str(mops.Elemwise.Mode.SINH): sinh,
    str(mops.Elemwise.Mode.COSH): cosh,
    str(mops.Elemwise.Mode.TANH): tanh,
    str(mops.Elemwise.Mode.ASIN): asin,
    str(mops.Elemwise.Mode.ACOS): acos,
    str(mops.Elemwise.Mode.ASINH): asinh,
    str(mops.Elemwise.Mode.ACOSH): acosh,
    str(mops.Elemwise.Mode.ATANH): atanh,
    str(mops.Elemwise.Mode.ATAN2): atan2,
    str(mops.Elemwise.Mode.TANH_GRAD): tanh_grad,
    str(mops.Elemwise.Mode.ASINH_GRAD): asinh_grad,
    str(mops.Elemwise.Mode.ACOSH_GRAD): acosh_grad,
    str(mops.Elemwise.Mode.ATANH_GRAD): atanh_grad,
    str(mops.Elemwise.Mode.SQRT): sqrt,
    str(mops.Elemwise.Mode.SQUARE): square,
    str(mops.Elemwise.Mode.POW): pow,
    str(mops.Elemwise.Mode.EXPM1): expm1,
    str(mops.Elemwise.Mode.RELU): relu,
    str(mops.Elemwise.Mode.EQ): equal,
    str(mops.Elemwise.Mode.NEQ): not_equal,
    str(mops.Elemwise.Mode.LT): less,
    str(mops.Elemwise.Mode.LEQ): less_equal,
    str(mops.Elemwise.Mode.AND): logical_and,
    str(mops.Elemwise.Mode.OR): logical_or,
    str(mops.Elemwise.Mode.NOT): logical_not,
    str(mops.Elemwise.Mode.XOR): logical_xor,
    str(mops.Elemwise.Mode.SHL): left_shift,
    str(mops.Elemwise.Mode.SHR): right_shift,
    str(mops.Elemwise.Mode.SWITCH_GT0): relu_grad,
    str(mops.Elemwise.Mode.SIGMOID): sigmoid,
    str(mops.Elemwise.Mode.SIGMOID_GRAD): sigmoid_grad,
    str(mops.Elemwise.Mode.PRELU): prelu,
    str(mops.Elemwise.Mode.PRELU_GRAD): prelu_grad,
    str(mops.Elemwise.Mode.SILU): silu,
    str(mops.Elemwise.Mode.SILU_GRAD): silu_grad,
    str(mops.Elemwise.Mode.HSIGMOID): hsigmoid,
    str(mops.Elemwise.Mode.HSIGMOID_GRAD): hsigmoid_grad,
    str(mops.Elemwise.Mode.H_SWISH): hswish,
    str(mops.Elemwise.Mode.H_SWISH_GRAD): hswish_grad,
    str(mops.Elemwise.Mode.RELU6): relu6,
    str(mops.Elemwise.Mode.RELU6_GRAD): relu6_grad,
    str(mops.Elemwise.Mode.LOGSIGMOID): logsigmoid,
    str(mops.Elemwise.Mode.SOFTPLUS): softplus,
    str(mops.Elemwise.Mode.SOFTPLUS_GRAD): softplus_grad,
}
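As the comment above notes, `Elemwise.Mode` values are not hashable, so the table is keyed by `str(mode)` and a lookup presumably converts the mode the same way before dispatching. A tiny self-contained stand-in for that pattern (the class and op names here are hypothetical; only the str-keying trick is taken from the source):

class UnhashableMode:
    # Stand-in for an enum-like value that cannot be used as a dict key.
    __hash__ = None

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name

ADD = UnhashableMode("ADD")
rules = {str(ADD): lambda a, b: a + b}  # key by str(mode) instead of the mode itself

def lower_elemwise(mode, *args):
    return rules[str(mode)](*args)

print(lower_elemwise(ADD, 2, 3))  # 5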
imperative/python/megengine/xla/rules/math.py

from typing import Sequence, Union

import numpy as np

from ...core._imperative_rt import ops as mops
from .. import ir_utils
-from ..lib.mlir.dialects import hlo
+from ..ir_utils import i64_attr
+from ..lib.mlir.dialects import chlo, hlo
from .hlotensor import HLOTensor
from .utils import _can_broadcast_to, _shape_equal, register_lower_rule

@@ -236,3 +239,7 @@ def batched_matmul_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
            precision_config=ir_utils.precision_attr(lhs.dtype, rhs.dtype),
        ).result
    ).transpose(permutation)

def topk(inp, k, descending=True, kth_only=False, no_sort=False):
    return [HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results]
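The new `topk` helper just wraps `chlo.TopKOp`, which yields the k largest entries along the last axis together with their indices; note that the `descending`, `kth_only` and `no_sort` arguments are accepted but not used by this lowering yet. For reference, the intended semantics roughly correspond to the following numpy computation (an illustration, not the MegEngine implementation):

import numpy as np

def topk_reference(x, k):
    # k largest values per row, in descending order, plus their indices
    idx = np.argsort(-x, axis=-1)[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx

x = np.array([[3.0, 1.0, 4.0, 1.0, 5.0]])
values, indices = topk_reference(x, 3)
print(values)   # [[5. 4. 3.]]
print(indices)  # [[4 2 0]]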
imperative/python/megengine/xla/rules/reduction.py

@@ -51,11 +51,15 @@ def _get_bitwise_or_identity(dtype) -> np.ndarray:
    return np.array(0, dtype)

-def _infer_reduce_shape(ishape, axes, keepdims=False):
+def _normalize_reduce_axes(ishape, axes):
    axes = list(range(len(ishape))) if axes is None else axes
    axes = [axes] if isinstance(axes, int) else axes
    axes = [axis if axis >= 0 else axis + len(ishape) for axis in axes]
+    return axes
+
+def _infer_reduce_shape(ishape, axes, keepdims=False):
+    axes = _normalize_reduce_axes(ishape, axes)
    reduced_shape = []
    for axis, length in enumerate(ishape):

@@ -89,8 +93,7 @@ def _reduce(
        return HLOTensor(reduce_op.result)

-    axes = [axes] if isinstance(axes, int) else axes
-    axes = [axis if axis >= 0 else axis + inp.ndim for axis in axes]
+    axes = _normalize_reduce_axes(inp.shape, axes)
    maykeepdim_shape = _infer_reduce_shape(inp.shape, axes, keepdims)
    _check_shape(maykeepdim_shape, oshape)

@@ -110,6 +113,7 @@ any = partial(_reduce, hlo.OrOp, _get_bitwise_or_identity)
def mean(inp, axes=None, keepdims=False):
+    axes = _normalize_reduce_axes(inp.shape, axes)
    inp_sum = sum(inp, axes, keepdims)
    inp_shape = inp.shape
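Factoring out `_normalize_reduce_axes` gives one place that turns `axes=None` into "all axes", wraps a bare int into a list, and maps negative axes onto their positive positions, so `_infer_reduce_shape`, `_reduce` and `mean` all agree. A standalone copy of the logic with a few examples (it mirrors the function above, outside MegEngine):

def normalize_reduce_axes(ishape, axes):
    axes = list(range(len(ishape))) if axes is None else axes
    axes = [axes] if isinstance(axes, int) else axes
    return [axis if axis >= 0 else axis + len(ishape) for axis in axes]

print(normalize_reduce_axes((2, 3, 4), None))  # [0, 1, 2]
print(normalize_reduce_axes((2, 3, 4), 1))     # [1]
print(normalize_reduce_axes((2, 3, 4), [-1]))  # [2]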
imperative/python/megengine/xla/rules/tensor.py

@@ -226,6 +226,11 @@ def pad(inp, pad_value, padding):
    )

def where(mask, x, y):
    mask = mask.astype("float32")
    return mask * x + (1.0 - mask) * y

@register_lower_rule(mops.Reshape)
def reshape_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    assert len(args) == 2
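This `where` builds the select out of elementwise arithmetic, which is why the boolean mask is first cast to float32: `mask * x + (1 - mask) * y` picks `x` wherever the mask is 1 and `y` wherever it is 0. The same trick in plain numpy (illustration only):

import numpy as np

def where_via_mask(mask, x, y):
    mask = mask.astype(np.float32)  # bool -> {0.0, 1.0}
    return mask * x + (1.0 - mask) * y

x = np.array([1.0, 2.0, 3.0])
y = np.array([-1.0, -2.0, -3.0])
mask = x > 1.5
print(where_via_mask(mask, x, y))  # [-1.  2.  3.]
print(np.where(mask, x, y))        # same result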
imperative/python/megengine/xla/rules/trivial.py

@@ -5,6 +5,7 @@ import numpy as np
from ...core._imperative_rt import ops as mops
from ..lib.mlir import ir
from .hlotensor import HLOTensor
+from .tensor import fill
from .utils import _check_shape, register_lower_rule

@@ -51,3 +52,8 @@ def io_mark_var_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
def rename_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
    assert len(args) == 1
    return args

+@register_lower_rule("fake_op_rule_for_debug")
+def fake_op_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
+    return [fill(0.0, out.shape, out.dtype) for out in ctx.vars_out]
imperative/python/megengine/xla/rules/utils.py

+import warnings
+
import numpy as np

from ..lib.mlir import ir

@@ -19,10 +21,16 @@ def register_lower_rule(*ops):
    return decorator

-def get_rule(op):
-    if isinstance(op, str):
-        return lower_rule[op]
-    return lower_rule[type(op)]
+def get_rule(op, use_fake_rule_for_debug=False):
+    op_key = op if isinstance(op, str) else type(op)
+    if use_fake_rule_for_debug:
+        if op_key in lower_rule:
+            return lower_rule[op_key]
+        else:
+            warnings.warn(f"op: {op_key} not register, use fake op rule")
+            return lower_rule["fake_op_rule_for_debug"]
+    else:
+        return lower_rule[op_key]

def _log_mge_opr_attrs(mopr):
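With this change `get_rule` gains a debug mode: when `use_fake_rule_for_debug=True`, an operator with no registered lowering falls back to the `"fake_op_rule_for_debug"` rule registered in trivial.py (which fills the outputs with zeros) and a warning is emitted instead of a KeyError. A self-contained sketch of that registry-with-fallback pattern (the unregistered op name and the fake rule's return value here are made up):

import warnings

lower_rule = {}

def register_lower_rule(*ops):
    def decorator(rule):
        for op in ops:
            lower_rule[op] = rule
        return rule
    return decorator

@register_lower_rule("fake_op_rule_for_debug")
def fake_rule(*args):
    return "zero-filled placeholder outputs"

def get_rule(op, use_fake_rule_for_debug=False):
    op_key = op if isinstance(op, str) else type(op)
    if use_fake_rule_for_debug and op_key not in lower_rule:
        warnings.warn(f"op: {op_key} not register, use fake op rule")
        return lower_rule["fake_op_rule_for_debug"]
    return lower_rule[op_key]

print(get_rule("some_unsupported_op", use_fake_rule_for_debug=True)())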
imperative/python/src/grad_override.cpp

@@ -81,7 +81,7 @@ ValueRef make_empty_tensor(
    storage.ensure_size(dtype->size());
    std::memset(storage.ptr(), 0, dtype->size());
    auto t = imperative::apply(
-            CreateTensor(CreateTensor::Unique, *device, *dtype, ValueShape()),
+            CreateTensor(CreateTensor::Const, *device, *dtype, ValueShape()),
            HostStorage::make(storage))[0];
    auto res = broadcast_to(t, shape);
    return res;
imperative/python/src/tensor.cpp

@@ -1321,7 +1321,8 @@ void init_tensor(py::module m) {
                } else if (self.check_external) {
                    throw std::runtime_error(
                            "have some unknown input tensors in trace "
-                            "result");
+                            "result, maybe you can set "
+                            "`capture_as_const=True` when trace");
                }
            }
        }
imperative/python/test/unit/jit/test_tracing.py

@@ -848,7 +848,10 @@ def test_trace_without_error():
        c = tensor([3.0])
        fwd(a, b, c)
    except Exception as e:
-        assert str(e) == "have some unknown input tensors in trace result"
+        assert (
+            str(e)
+            == "have some unknown input tensors in trace result, maybe you can set `capture_as_const=True` when trace"
+        )
    else:
        assert False
imperative/python/test/unit/xla/functional/test_xla_elemwise.py

@@ -18,94 +18,107 @@ def test_elemwise():
    np.random.seed(123)
    mge.random.seed(123)

-    def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5):
+    def tester(felemwise, *inp_shapes, backward=True, dtype=None, atol=1e-5, **kwargs):
        dtype = dtype or np.float32
-        inps = [tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype) for inp_shape in inp_shapes]
-        doup = tensor(0.1 * np.random.randn(*felemwise(*inps).shape), dtype=dtype)
+        if dtype in [np.int16, np.int32, np.uint16, np.uint32]:
+            inps = [
+                tensor(np.random.randint(0, 10, size=inp_shape), dtype=dtype)
+                for inp_shape in inp_shapes
+            ]
+        else:
+            inps = [
+                tensor(0.1 * np.random.randn(*inp_shape), dtype=dtype)
+                for inp_shape in inp_shapes
+            ]
+        doup = tensor(0.1 * np.random.randn(*felemwise(*inps, **kwargs).shape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inps, doup):
-            gm.attach(inps)
-            with gm:
-                oup = felemwise(*inps)
-                if backward:
-                    gm.backward(oup, doup)
-                    return [oup, *[inp.grad for inp in inps]]
-                else:
-                    return [oup]
+            if backward:
+                gm.attach(inps)
+                with gm:
+                    oup = felemwise(*inps, **kwargs)
+                    gm.backward(oup, doup)
+                    return [oup, *[inp.grad for inp in inps]]
+            else:
+                oup = felemwise(*inps, **kwargs)
+                return [oup]

        mge_rsts = func(inps, doup)
        xla_rsts = func(inps, doup)
-        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
+        for _, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol)

    tester(F.neg, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.abs, (2, 32, 16), dtype=np.float32, atol=1e-5)
    tester(F.tanh, (4, 16, 3, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.sin, (1, 16, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.cos, (4, 16, 3), dtype=np.float32, atol=1e-5)
    tester(F.tan, (4, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.sinh, (4, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.cosh, (3, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.tanh, (4, 6, 3, 1), dtype=np.float32, atol=5e-4)
    tester(F.asin, (4, 1, 3, 1), dtype=np.float32, atol=1e-5)
    # tester(F.acos, (4, 16, 3, 1), dtype=np.float32, atol=1e-5) # xla compute error
    tester(F.atan, (4, 16, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.asinh, (4, 1, 3, 1), dtype=np.float32, atol=1e-5)
    tester(F.acosh, (4, 1), dtype=np.float32, atol=1e-5)
    tester(F.atanh, (1,), dtype=np.float32, atol=1e-5)
    tester(F.exp, (2, 8), dtype=np.float32, atol=1e-5)
    tester(F.sqrt, (32,), dtype=np.float32, atol=1e-5)
    tester(F.square, (32,), dtype=np.float32, atol=1e-5)
    tester(F.log, (8, 8, 16), dtype=np.float32, atol=1e-5)
    tester(F.log1p, (8, 1, 16), dtype=np.float32, atol=1e-5)
    tester(F.expm1, (6, 8, 2), dtype=np.float32, atol=1e-5)
    tester(F.floor, (4, 16, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.ceil, (4, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.round, (1, 4, 1), backward=False, dtype=np.float32, atol=1e-5)
    tester(F.clip, (4, 16, 1), dtype=np.float32, atol=1e-5, lower=-1.0, upper=1.0)
    tester(F.relu, (1,), dtype=np.float32, atol=1e-5)
    tester(F.gelu, (4, 16, 12, 12), dtype=np.float32, atol=2e-5)
    tester(F.sigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
    tester(F.hsigmoid, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
    tester(F.hswish, (4, 16, 16, 12), dtype=np.float32, atol=1e-5)
    tester(F.relu6, (12, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.leaky_relu, (1, 16, 1), dtype=np.float32, atol=1e-5)
    tester(F.leaky_relu, (12, 16, 1), dtype=np.float32, atol=1e-5, negative_slope=0.5)
    tester(F.silu, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.logsigmoid, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.softplus, (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.add, (4, 16, 12, 12), (4, 16, 12, 12), dtype=np.float32, atol=1e-5)
    tester(F.sub, (4, 16, 12, 12), (4, 16, 1, 1), dtype=np.float32, atol=1e-5)
    tester(F.mul, (4, 16, 12, 12), (1, 1, 12, 12), dtype=np.float32, atol=1e-5)
-    tester(F.div, (4, 16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.float32, atol=1e-5)
-    tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=1e-5)
+    tester(F.div, (4, 16, 1, 1), (4, 16, 12, 12), atol=5e-4)
+    tester(F.floor_div, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, atol=5e-5)
+    # tester(F.mod, (8, 1, 4), (8, 1, 1), backward=False, dtype=np.int32, atol=1e-5) # xla not support
+    tester(F.pow, (4, 1, 12, 12), (1, 16, 12, 12), dtype=np.float32, atol=5e-5)
    tester(F.prelu, (4, 16, 12, 12), (1,), dtype=np.float32, atol=1e-5)
    tester(F.prelu, (16, 5, 12), (1, 5, 1), dtype=np.float32, atol=1e-5)
    tester(F.logaddexp, (16, 5, 12), (1, 5, 12), dtype=np.float32, atol=1e-5)
    tester(F.maximum, (1, 5, 1), (1, 5, 12), dtype=np.float32, atol=1e-5)
    tester(F.minimum, (1, 5, 12), (16, 5, 12), dtype=np.float32, atol=1e-5)
    tester(F.left_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32)
    tester(F.right_shift, (4, 16, 12, 12), (1, 1, 12, 12), backward=False, dtype=np.int32)
-    tester(F.equal, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.float32, atol=1e-5)
-    tester(F.not_equal, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, dtype=np.float32, atol=1e-5)
-    tester(F.greater, (4, 16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.float32, atol=1e-5)
-    tester(F.greater_equal, (16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.float32, atol=1e-5)
-    tester(F.less, (4, 16, 12, 1), (4, 16, 12, 12), backward=False, dtype=np.float32, atol=1e-5)
-    tester(F.less_equal, (1, 1, 12, 12), (4, 16, 12, 12), backward=False, dtype=np.float32, atol=1e-5)
+    tester(F.equal, (4, 16, 12, 12), (1, 1), backward=False)
+    tester(F.not_equal, (4, 16, 12, 12), (4, 16, 1, 1), backward=False)
+    tester(F.greater, (4, 16, 1, 1), (4, 16, 12, 12), backward=False)
+    tester(F.greater_equal, (16, 1, 1), (4, 16, 12, 12), backward=False)
+    tester(F.less, (4, 16, 12, 1), (4, 16, 12, 12), backward=False)
+    tester(F.less_equal, (1, 1, 12, 12), (4, 16, 12, 12), backward=False)
    # bool is not support in dlpack now
    # tester(F.logical_and, (4, 16, 12, 12), (1, 1), backward=False, dtype=np.bool8)
    # tester(F.logical_or, (4, 16, 12, 12), (4, 16, 1, 1), backward=False, dtype=np.bool8)
    # tester(
    #     F.logical_xor, (4, 16, 1, 1), (4, 16, 12, 12), backward=False, dtype=np.bool8
    # )
    # tester(F.logical_not, (16, 1, 1), backward=False, dtype=np.bool8)
imperative/python/test/unit/xla/functional/test_xla_nn.py

@@ -258,3 +258,70 @@ def test_softmax():
    tester((32, 16, 5), 0)
    tester((1, 16, 5), -1)
    tester((14, 1, 13, 5), 1)

@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_loss():
    def tester(loss_fn, pred_shape, label_shape, label_type="default", atol=1e-5, dtype=None, **kwargs):
        dtype = dtype or np.float32
        pred = tensor(np.random.randn(*pred_shape), dtype=dtype)
        if label_type == "default":
            label = tensor(np.random.randn(*label_shape), dtype=dtype)
        elif label_type == "classes":
            label = tensor(np.random.randint(0, 10, size=label_shape), dtype=dtype)
        dout = tensor(np.random.randn(1,), dtype=dtype)

        gm = autodiff.GradManager()

        @jit.xla_trace(without_host=True)
        def func(pred, label, dout):
            gm.attach([pred])
            with gm:
                out = loss_fn(pred, label, **kwargs)
                gm.backward(out, dout)
            return out, pred.grad

        mge_rsts = func(pred, label, dout)
        xla_rsts = func(pred, label, dout)
        for idx, (mge_rst, xla_rst) in enumerate(zip(mge_rsts, xla_rsts)):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol)

    from megengine.functional import loss

    tester(loss.l1_loss, (32, 16, 8, 8), (32, 16, 8, 8))
    tester(loss.l1_loss, (1, 16), (1, 16))
    tester(loss.square_loss, (32, 16, 8, 8), (32, 16, 8, 8))
    tester(loss.square_loss, (16, 1), (16, 1))
    tester(
        loss.cross_entropy, (16, 32), (16,), label_type="classes", axis=1, with_logits=True, label_smooth=0.0,
    )
    tester(
        loss.cross_entropy, (16, 32), (32,), label_type="classes", axis=0, with_logits=False, label_smooth=0.5,
    )
    tester(loss.binary_cross_entropy, (16, 32, 4, 8), (16, 32, 4, 8), with_logits=True)
    tester(loss.binary_cross_entropy, (1, 32, 1), (1, 32, 1), with_logits=False)
    tester(loss.hinge_loss, (32, 16, 8, 8), (32, 16, 8, 8), norm="L1")
    tester(loss.hinge_loss, (1, 16, 1, 1), (1, 16, 1, 1), norm="L2")
imperative/python/test/unit/xla/module/test_elemwise.py (new file, 0 → 100644)

import platform

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.tensor as tensor
from megengine import is_cuda_available, jit
from megengine.autodiff import GradManager
from megengine.optimizer import Adam

@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_elemwise_activation():
    def tester(TestMod, ishape, dtype=None, atol=1e-5, **kwargs):
        dtype = dtype or np.float32
        inp = tensor(0.1 * np.random.randn(*ishape), dtype=dtype)
        doup = tensor(0.1 * np.random.randn(*ishape), dtype=dtype)

        gm = GradManager()
        mod = TestMod(**kwargs)

        @jit.xla_trace(without_host=True)
        def func(mod, inp, doup):
            gm.attach(inp)
            with gm:
                oup = mod(inp)
                gm.backward(oup, doup)
            return oup, inp.grad

        mge_rsts = func(mod, inp, doup)
        xla_rsts = func(mod, inp, doup)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=atol)

    tester(M.Sigmoid, (2, 3, 4, 5))
    tester(M.ReLU, (2, 3,))
    tester(M.LeakyReLU, (4, 5))
    tester(M.LeakyReLU, (4, 5), negative_slope=0.3)
    tester(M.PReLU, (8, 6, 5))
    tester(M.PReLU, (8, 6, 5, 7), num_parameters=6, init=0.1)
    tester(M.PReLU, (1,))
    tester(M.SiLU, (4, 8, 3, 2))
    tester(M.SiLU, (1, 1,))
    tester(M.GELU, (1, 1, 2))
src/plugin/impl/opr_footprint.cpp

@@ -564,6 +564,7 @@ REGISTE_PARAM_JSON_FUNC(LayerNormBackward)
REGISTE_PARAM_JSON_FUNC(AdaptivePoolingBackward)
REGISTE_PARAM_JSON_FUNC(DropoutBackward)
REGISTE_PARAM_JSON_FUNC(SoftmaxBackward)
+REGISTE_PARAM_JSON_FUNC(ArgsortBackward)

std::shared_ptr<json::Value> dimshuffle_param2json(const opr::Dimshuffle::Param& param) {

@@ -862,6 +863,7 @@ void OprFootprint::init_all_footprints() {
    add_single_param_json<opr::AdaptivePoolingBackward>();
    add_single_param_json<opr::DropoutBackward>();
    add_single_param_json<opr::SoftmaxBackward>();
+    add_single_param_json<opr::ArgsortBackward>();
#endif
}