Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cfc41648
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cfc41648
编写于
6月 20, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge): fix grad of maximum(x, x)
GitOrigin-RevId: e0e2efb71bbe507bd5b4dab539b5b9cfe79d1187
上级
bbafe699
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
58 addition
and
3 deletion
+58
-3
imperative/python/test/unit/functional/test_elemwise.py
imperative/python/test/unit/functional/test_elemwise.py
+23
-0
src/jit/impl/ast_c.cpp
src/jit/impl/ast_c.cpp
+2
-0
src/jit/impl/halide/ast_hl.cpp
src/jit/impl/halide/ast_hl.cpp
+2
-0
src/jit/impl/mlir/ir/each_mode.cpp
src/jit/impl/mlir/ir/each_mode.cpp
+9
-0
src/jit/impl/mlir/ir/each_mode.h
src/jit/impl/mlir/ir/each_mode.h
+1
-0
src/jit/test/codegen.cpp
src/jit/test/codegen.cpp
+1
-0
src/jit/test/fusion.cpp
src/jit/test/fusion.cpp
+1
-0
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+14
-3
src/opr/test/basic_arith/elemwise.cpp
src/opr/test/basic_arith/elemwise.cpp
+2
-0
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
+1
-0
src/opr/test/nn_int.cpp
src/opr/test/nn_int.cpp
+2
-0
未找到文件。
imperative/python/test/unit/functional/test_elemwise.py
浏览文件 @
cfc41648
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
megengine.autodiff
as
ad
import
megengine.functional
as
F
import
megengine.functional
as
F
import
megengine.functional.elemwise
as
elemwise
import
megengine.functional.elemwise
as
elemwise
from
megengine
import
tensor
from
megengine
import
tensor
...
@@ -293,3 +294,25 @@ def test_empty_tensor(is_trace):
...
@@ -293,3 +294,25 @@ def test_empty_tensor(is_trace):
run_test
(
op
,
[
inps
[
1
],
inps
[
1
]],
(
inps
[
1
]
+
inps
[
1
]).
shape
,
False
)
run_test
(
op
,
[
inps
[
1
],
inps
[
1
]],
(
inps
[
1
]
+
inps
[
1
]).
shape
,
False
)
run_test
(
op
,
[
inps
[
0
],
inps
[
2
]],
(
inps
[
0
]
+
inps
[
2
]).
shape
,
False
)
run_test
(
op
,
[
inps
[
0
],
inps
[
2
]],
(
inps
[
0
]
+
inps
[
2
]).
shape
,
False
)
run_test
(
op
,
[
inps
[
1
],
inps
[
2
]],
(
inps
[
1
]
+
inps
[
2
]).
shape
,
False
)
run_test
(
op
,
[
inps
[
1
],
inps
[
2
]],
(
inps
[
1
]
+
inps
[
2
]).
shape
,
False
)
@
pytest
.
mark
.
parametrize
(
"is_trace"
,
[
True
,
False
])
def
test_maximum_grad_consistency
(
is_trace
):
def
f
(
x
):
with
ad
.
GradManager
()
as
gm
:
gm
.
attach
(
x
)
gm
.
backward
(
F
.
maximum
(
x
,
x
))
dx
=
x
.
grad
x
.
grad
=
None
return
dx
def
run
(
f
):
x
=
F
.
arange
(
10
)
for
i
in
range
(
3
):
np
.
testing
.
assert_equal
(
f
(
x
).
numpy
(),
np
.
ones
(
10
))
if
is_trace
:
for
symbolic
in
[
False
,
True
]:
run
(
trace
(
symbolic
=
symbolic
)(
f
))
else
:
run
(
f
)
src/jit/impl/ast_c.cpp
浏览文件 @
cfc41648
...
@@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
...
@@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
// misc
// misc
ENTRY
(
COND_LEQ_MOV
,
ENTRY
(
COND_LEQ_MOV
,
ASTPtr
::
make
<
BinaryAST
>
(
"<="
,
inps
[
0
],
inps
[
1
])
*
inps
[
2
]),
ASTPtr
::
make
<
BinaryAST
>
(
"<="
,
inps
[
0
],
inps
[
1
])
*
inps
[
2
]),
ENTRY
(
COND_LT_MOV
,
ASTPtr
::
make
<
BinaryAST
>
(
"<"
,
inps
[
0
],
inps
[
1
])
*
inps
[
2
]),
ENTRY
(
FUSE_MUL_ADD3
,
inps
[
0
]
*
inps
[
1
]
+
inps
[
2
]),
ENTRY
(
FUSE_MUL_ADD3
,
inps
[
0
]
*
inps
[
1
]
+
inps
[
2
]),
ENTRY
(
FUSE_MUL_ADD4
,
inps
[
0
]
*
inps
[
1
]
+
inps
[
2
]
*
inps
[
3
]),
ENTRY
(
FUSE_MUL_ADD4
,
inps
[
0
]
*
inps
[
1
]
+
inps
[
2
]
*
inps
[
3
]),
ENTRY
(
FUSE_ADD_RELU
,
make_call
(
"fmaxf"
,
{
inps
[
0
]
+
inps
[
1
],
0
})),
ENTRY
(
FUSE_ADD_RELU
,
make_call
(
"fmaxf"
,
{
inps
[
0
]
+
inps
[
1
],
0
})),
...
...
src/jit/impl/halide/ast_hl.cpp
浏览文件 @
cfc41648
...
@@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode(
...
@@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode(
// ternary
// ternary
case
Mode
::
COND_LEQ_MOV
:
case
Mode
::
COND_LEQ_MOV
:
return
Halide
::
select
(
inp
(
0
)
<=
inp
(
1
),
inp
(
2
),
cv
(
0
));
return
Halide
::
select
(
inp
(
0
)
<=
inp
(
1
),
inp
(
2
),
cv
(
0
));
case
Mode
::
COND_LT_MOV
:
return
Halide
::
select
(
inp
(
0
)
<
inp
(
1
),
inp
(
2
),
cv
(
0
));
case
Mode
::
FUSE_MUL_ADD3
:
case
Mode
::
FUSE_MUL_ADD3
:
return
inp
(
0
)
*
inp
(
1
)
+
inp
(
2
);
return
inp
(
0
)
*
inp
(
1
)
+
inp
(
2
);
case
Mode
::
FUSE_MUL_ADD4
:
case
Mode
::
FUSE_MUL_ADD4
:
...
...
src/jit/impl/mlir/ir/each_mode.cpp
浏览文件 @
cfc41648
...
@@ -388,6 +388,15 @@ mlir::Value lower_mode<Mode::COND_LEQ_MOV>(
...
@@ -388,6 +388,15 @@ mlir::Value lower_mode<Mode::COND_LEQ_MOV>(
helper
.
le
(
operands
[
0
],
operands
[
1
]),
operands
[
2
],
helper
.
const_f32
(
0.
f
));
helper
.
le
(
operands
[
0
],
operands
[
1
]),
operands
[
2
],
helper
.
const_f32
(
0.
f
));
}
}
//! COND_LT_MOV: x < y ? z : ctype(0)
template
<
>
mlir
::
Value
lower_mode
<
Mode
::
COND_LT_MOV
>
(
mlir
::
OpBuilder
&
builder
,
mlir
::
Location
loc
,
ValueRange
operands
)
{
ValueBuilderHelper
helper
(
builder
,
loc
);
return
helper
.
select
(
helper
.
lt
(
operands
[
0
],
operands
[
1
]),
operands
[
2
],
helper
.
const_f32
(
0.
f
));
}
//! FUSE_MUL_ADD3: x * y + z
//! FUSE_MUL_ADD3: x * y + z
template
<
>
template
<
>
mlir
::
Value
lower_mode
<
Mode
::
FUSE_MUL_ADD3
>
(
mlir
::
Value
lower_mode
<
Mode
::
FUSE_MUL_ADD3
>
(
...
...
src/jit/impl/mlir/ir/each_mode.h
浏览文件 @
cfc41648
...
@@ -60,6 +60,7 @@
...
@@ -60,6 +60,7 @@
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
cb(CondLeqMovOp, COND_LEQ_MOV) \
cb(CondLeqMovOp, COND_LEQ_MOV) \
cb(CondLtMovOp, COND_LT_MOV) \
cb(FuseMulAdd3Op, FUSE_MUL_ADD3)
cb(FuseMulAdd3Op, FUSE_MUL_ADD3)
// clang-format on
// clang-format on
...
...
src/jit/test/codegen.cpp
浏览文件 @
cfc41648
...
@@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
...
@@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
// clang-format off
// clang-format off
#define FOREACH_TERNARY_MODE(cb) \
#define FOREACH_TERNARY_MODE(cb) \
cb(COND_LEQ_MOV) \
cb(COND_LEQ_MOV) \
cb(COND_LT_MOV) \
cb(FUSE_MUL_ADD3) \
cb(FUSE_MUL_ADD3) \
// clang-format on
// clang-format on
template
<
typename
tag
>
template
<
typename
tag
>
...
...
src/jit/test/fusion.cpp
浏览文件 @
cfc41648
...
@@ -452,6 +452,7 @@ void run<all_oprs>(Backend backend, CompNode cn) {
...
@@ -452,6 +452,7 @@ void run<all_oprs>(Backend backend, CompNode cn) {
CHECK_ELEM2
(
ATAN2
,
true
,
gt0
);
CHECK_ELEM2
(
ATAN2
,
true
,
gt0
);
CHECK_ELEM3
(
COND_LEQ_MOV
,
false
,
none
);
CHECK_ELEM3
(
COND_LEQ_MOV
,
false
,
none
);
CHECK_ELEM3
(
COND_LT_MOV
,
false
,
none
);
CHECK_ELEM3
(
FUSE_MUL_ADD3
,
true
,
none
);
CHECK_ELEM3
(
FUSE_MUL_ADD3
,
true
,
none
);
CHECK_ELEM4
(
FUSE_MUL_ADD4
,
true
,
none
);
CHECK_ELEM4
(
FUSE_MUL_ADD4
,
true
,
none
);
...
...
src/opr/impl/basic_arith.cpp
浏览文件 @
cfc41648
...
@@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
...
@@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
case
Mode
::
FLOOR_DIV
:
case
Mode
::
FLOOR_DIV
:
return
nullptr
;
return
nullptr
;
case
Mode
::
MAX
:
case
Mode
::
MAX
:
RET
(
EL3
(
COND_LEQ_MOV
,
i
[
!
wrt_idx
],
i
[
wrt_idx
],
og
));
if
(
wrt_idx
)
{
RET
(
EL3
(
COND_LT_MOV
,
i
[
0
],
i
[
1
],
og
));
}
else
{
RET
(
EL3
(
COND_LEQ_MOV
,
i
[
1
],
i
[
0
],
og
));
}
case
Mode
::
MIN
:
case
Mode
::
MIN
:
RET
(
EL3
(
COND_LEQ_MOV
,
i
[
wrt_idx
],
i
[
!
wrt_idx
],
og
));
if
(
wrt_idx
)
{
RET
(
EL3
(
COND_LT_MOV
,
i
[
1
],
i
[
0
],
og
));
}
else
{
RET
(
EL3
(
COND_LEQ_MOV
,
i
[
0
],
i
[
1
],
og
));
}
case
Mode
::
MOD
:
case
Mode
::
MOD
:
if
(
wrt_idx
==
0
)
{
if
(
wrt_idx
==
0
)
{
RET
(
og
);
RET
(
og
);
...
@@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
...
@@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
if
(
wrt_idx
<=
1
)
if
(
wrt_idx
<=
1
)
return
nullptr
;
return
nullptr
;
RET
(
EL3
(
COND_LEQ_MOV
,
i0
,
i1
,
og
));
RET
(
EL3
(
COND_LEQ_MOV
,
i0
,
i1
,
og
));
case
Mode
::
COND_LT_MOV
:
if
(
wrt_idx
<=
1
)
return
nullptr
;
RET
(
EL3
(
COND_LT_MOV
,
i0
,
i1
,
og
));
// fuse oprs
// fuse oprs
case
Mode
::
FUSE_MUL_ADD3
:
case
Mode
::
FUSE_MUL_ADD3
:
if
(
wrt_idx
<
2
)
{
if
(
wrt_idx
<
2
)
{
...
...
src/opr/test/basic_arith/elemwise.cpp
浏览文件 @
cfc41648
...
@@ -571,6 +571,8 @@ struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {};
...
@@ -571,6 +571,8 @@ struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {};
/* ======================= ternary config ======================= */
/* ======================= ternary config ======================= */
template
<
>
template
<
>
struct
CheckerConfig
<
COND_LEQ_MOV
>
:
public
BinaryInputMinGap
<
false
>
{};
struct
CheckerConfig
<
COND_LEQ_MOV
>
:
public
BinaryInputMinGap
<
false
>
{};
template
<
>
struct
CheckerConfig
<
COND_LT_MOV
>
:
public
BinaryInputMinGap
<
false
>
{};
/* ======================= test runner ======================= */
/* ======================= test runner ======================= */
namespace
detail
{
namespace
detail
{
...
...
src/opr/test/basic_arith/elemwise_ternary_trait_def.inl
浏览文件 @
cfc41648
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
#define _ALLOW_FLOAT true
#define _ALLOW_FLOAT true
#define _ALLOW_INT true
#define _ALLOW_INT true
DEF_TRAIT
(
COND_LEQ_MOV
,
x
<=
y
?
z
:
0
)
DEF_TRAIT
(
COND_LEQ_MOV
,
x
<=
y
?
z
:
0
)
DEF_TRAIT
(
COND_LT_MOV
,
x
<
y
?
z
:
0
)
DEF_TRAIT
(
FUSE_MUL_ADD3
,
x
*
y
+
z
)
DEF_TRAIT
(
FUSE_MUL_ADD3
,
x
*
y
+
z
)
#undef _ALLOW_INT
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
#undef _ALLOW_FLOAT
...
...
src/opr/test/nn_int.cpp
浏览文件 @
cfc41648
...
@@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) {
...
@@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) {
switch
(
mode
)
{
switch
(
mode
)
{
MAKE_TERNARY
(
FUSE_MUL_ADD3
);
MAKE_TERNARY
(
FUSE_MUL_ADD3
);
MAKE_TERNARY
(
COND_LEQ_MOV
);
MAKE_TERNARY
(
COND_LEQ_MOV
);
MAKE_TERNARY
(
COND_LT_MOV
);
default:
default:
mgb_throw
(
InternalError
,
"Unknown ElemwiseMultiType Mode
\n
"
);
mgb_throw
(
InternalError
,
"Unknown ElemwiseMultiType Mode
\n
"
);
break
;
break
;
...
@@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) {
...
@@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) {
switch
(
mode
)
{
switch
(
mode
)
{
MAKE_TERNARY
(
FUSE_MUL_ADD3
);
MAKE_TERNARY
(
FUSE_MUL_ADD3
);
MAKE_TERNARY
(
COND_LEQ_MOV
);
MAKE_TERNARY
(
COND_LEQ_MOV
);
MAKE_TERNARY
(
COND_LT_MOV
);
default:
default:
mgb_throw
(
InternalError
,
"Unknown ElemwiseMultiType Mode
\n
"
);
mgb_throw
(
InternalError
,
"Unknown ElemwiseMultiType Mode
\n
"
);
break
;
break
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录