提交 cfc41648 编写于 作者: M Megvii Engine Team

fix(mge): fix grad of maximum(x, x)

GitOrigin-RevId: e0e2efb71bbe507bd5b4dab539b5b9cfe79d1187
上级 bbafe699
......@@ -2,6 +2,7 @@
import numpy as np
import pytest
import megengine.autodiff as ad
import megengine.functional as F
import megengine.functional.elemwise as elemwise
from megengine import tensor
......@@ -293,3 +294,25 @@ def test_empty_tensor(is_trace):
run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False)
run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False)
run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False)
@pytest.mark.parametrize("is_trace", [True, False])
def test_maximum_grad_consistency(is_trace):
def f(x):
with ad.GradManager() as gm:
gm.attach(x)
gm.backward(F.maximum(x, x))
dx = x.grad
x.grad = None
return dx
def run(f):
x = F.arange(10)
for i in range(3):
np.testing.assert_equal(f(x).numpy(), np.ones(10))
if is_trace:
for symbolic in [False, True]:
run(trace(symbolic=symbolic)(f))
else:
run(f)
......@@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() {
// misc
ENTRY(COND_LEQ_MOV,
ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]),
ENTRY(COND_LT_MOV,
ASTPtr::make<BinaryAST>("<", inps[0], inps[1]) * inps[2]),
ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]),
ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]),
ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})),
......
......@@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode(
// ternary
case Mode::COND_LEQ_MOV:
return Halide::select(inp(0) <= inp(1), inp(2), cv(0));
case Mode::COND_LT_MOV:
return Halide::select(inp(0) < inp(1), inp(2), cv(0));
case Mode::FUSE_MUL_ADD3:
return inp(0) * inp(1) + inp(2);
case Mode::FUSE_MUL_ADD4:
......
......@@ -388,6 +388,15 @@ mlir::Value lower_mode<Mode::COND_LEQ_MOV>(
helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
}
//! COND_LT_MOV: x < y ? z : ctype(0)
template <>
mlir::Value lower_mode<Mode::COND_LT_MOV>(
mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
ValueBuilderHelper helper(builder, loc);
return helper.select(
helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
}
//! FUSE_MUL_ADD3: x * y + z
template <>
mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(
......
......@@ -60,6 +60,7 @@
#define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \
cb(CondLeqMovOp, COND_LEQ_MOV) \
cb(CondLtMovOp, COND_LT_MOV) \
cb(FuseMulAdd3Op, FUSE_MUL_ADD3)
// clang-format on
......
......@@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) {
// clang-format off
#define FOREACH_TERNARY_MODE(cb) \
cb(COND_LEQ_MOV) \
cb(COND_LT_MOV) \
cb(FUSE_MUL_ADD3) \
// clang-format on
template <typename tag>
......
......@@ -452,6 +452,7 @@ void run<all_oprs>(Backend backend, CompNode cn) {
CHECK_ELEM2(ATAN2, true, gt0);
CHECK_ELEM3(COND_LEQ_MOV, false, none);
CHECK_ELEM3(COND_LT_MOV, false, none);
CHECK_ELEM3(FUSE_MUL_ADD3, true, none);
CHECK_ELEM4(FUSE_MUL_ADD4, true, none);
......
......@@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
case Mode::FLOOR_DIV:
return nullptr;
case Mode::MAX:
RET(EL3(COND_LEQ_MOV, i[!wrt_idx], i[wrt_idx], og));
if (wrt_idx) {
RET(EL3(COND_LT_MOV, i[0], i[1], og));
} else {
RET(EL3(COND_LEQ_MOV, i[1], i[0], og));
}
case Mode::MIN:
RET(EL3(COND_LEQ_MOV, i[wrt_idx], i[!wrt_idx], og));
if (wrt_idx) {
RET(EL3(COND_LT_MOV, i[1], i[0], og));
} else {
RET(EL3(COND_LEQ_MOV, i[0], i[1], og));
}
case Mode::MOD:
if (wrt_idx == 0) {
RET(og);
......@@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
if (wrt_idx <= 1)
return nullptr;
RET(EL3(COND_LEQ_MOV, i0, i1, og));
case Mode::COND_LT_MOV:
if (wrt_idx <= 1)
return nullptr;
RET(EL3(COND_LT_MOV, i0, i1, og));
// fuse oprs
case Mode::FUSE_MUL_ADD3:
if (wrt_idx < 2) {
......
......@@ -571,6 +571,8 @@ struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {};
/* ======================= ternary config ======================= */
template <>
struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {};
template <>
struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {};
/* ======================= test runner ======================= */
namespace detail {
......
......@@ -13,6 +13,7 @@
#define _ALLOW_FLOAT true
#define _ALLOW_INT true
DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0)
DEF_TRAIT(COND_LT_MOV, x < y ? z : 0)
DEF_TRAIT(FUSE_MUL_ADD3, x* y + z)
#undef _ALLOW_INT
#undef _ALLOW_FLOAT
......
......@@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) {
switch (mode) {
MAKE_TERNARY(FUSE_MUL_ADD3);
MAKE_TERNARY(COND_LEQ_MOV);
MAKE_TERNARY(COND_LT_MOV);
default:
mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
break;
......@@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) {
switch (mode) {
MAKE_TERNARY(FUSE_MUL_ADD3);
MAKE_TERNARY(COND_LEQ_MOV);
MAKE_TERNARY(COND_LT_MOV);
default:
mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n");
break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册