From cfc41648a4b7759b8c2ccfacba265e4d247693cb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Jun 2022 16:21:09 +0800 Subject: [PATCH] fix(mge): fix grad of maximum(x, x) GitOrigin-RevId: e0e2efb71bbe507bd5b4dab539b5b9cfe79d1187 --- .../test/unit/functional/test_elemwise.py | 23 +++++++++++++++++++ src/jit/impl/ast_c.cpp | 2 ++ src/jit/impl/halide/ast_hl.cpp | 2 ++ src/jit/impl/mlir/ir/each_mode.cpp | 9 ++++++++ src/jit/impl/mlir/ir/each_mode.h | 1 + src/jit/test/codegen.cpp | 1 + src/jit/test/fusion.cpp | 1 + src/opr/impl/basic_arith.cpp | 17 +++++++++++--- src/opr/test/basic_arith/elemwise.cpp | 2 ++ .../elemwise_ternary_trait_def.inl | 1 + src/opr/test/nn_int.cpp | 2 ++ 11 files changed, 58 insertions(+), 3 deletions(-) diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 30ed9d80f..b02a8aa11 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import megengine.autodiff as ad import megengine.functional as F import megengine.functional.elemwise as elemwise from megengine import tensor @@ -293,3 +294,25 @@ def test_empty_tensor(is_trace): run_test(op, [inps[1], inps[1]], (inps[1] + inps[1]).shape, False) run_test(op, [inps[0], inps[2]], (inps[0] + inps[2]).shape, False) run_test(op, [inps[1], inps[2]], (inps[1] + inps[2]).shape, False) + + +@pytest.mark.parametrize("is_trace", [True, False]) +def test_maximum_grad_consistency(is_trace): + def f(x): + with ad.GradManager() as gm: + gm.attach(x) + gm.backward(F.maximum(x, x)) + dx = x.grad + x.grad = None + return dx + + def run(f): + x = F.arange(10) + for i in range(3): + np.testing.assert_equal(f(x).numpy(), np.ones(10)) + + if is_trace: + for symbolic in [False, True]: + run(trace(symbolic=symbolic)(f)) + else: + run(f) diff --git a/src/jit/impl/ast_c.cpp b/src/jit/impl/ast_c.cpp index 408a3c92b..25df37758 100644 --- a/src/jit/impl/ast_c.cpp +++ b/src/jit/impl/ast_c.cpp @@ -117,6 +117,8 @@ const ElemGeneratorMap& ast_c::elem_opr_generator() { // misc ENTRY(COND_LEQ_MOV, ASTPtr::make("<=", inps[0], inps[1]) * inps[2]), + ENTRY(COND_LT_MOV, + ASTPtr::make("<", inps[0], inps[1]) * inps[2]), ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]), ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]), ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})), diff --git a/src/jit/impl/halide/ast_hl.cpp b/src/jit/impl/halide/ast_hl.cpp index 1aca7dde5..275f938cc 100644 --- a/src/jit/impl/halide/ast_hl.cpp +++ b/src/jit/impl/halide/ast_hl.cpp @@ -147,6 +147,8 @@ Halide::Expr dispatch_elemwise_mode( // ternary case Mode::COND_LEQ_MOV: return Halide::select(inp(0) <= inp(1), inp(2), cv(0)); + case Mode::COND_LT_MOV: + return Halide::select(inp(0) < inp(1), inp(2), cv(0)); case Mode::FUSE_MUL_ADD3: return inp(0) * inp(1) + inp(2); case Mode::FUSE_MUL_ADD4: diff --git a/src/jit/impl/mlir/ir/each_mode.cpp b/src/jit/impl/mlir/ir/each_mode.cpp index 15a8d8121..b89f55f69 100644 --- a/src/jit/impl/mlir/ir/each_mode.cpp +++ b/src/jit/impl/mlir/ir/each_mode.cpp @@ -388,6 +388,15 @@ mlir::Value lower_mode( helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); } +//! COND_LT_MOV: x < y ? z : ctype(0) +template <> +mlir::Value lower_mode( + mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.select( + helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f)); +} + //! FUSE_MUL_ADD3: x * y + z template <> mlir::Value lower_mode( diff --git a/src/jit/impl/mlir/ir/each_mode.h b/src/jit/impl/mlir/ir/each_mode.h index 7c2f994fc..830ef6fc3 100644 --- a/src/jit/impl/mlir/ir/each_mode.h +++ b/src/jit/impl/mlir/ir/each_mode.h @@ -60,6 +60,7 @@ #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ cb(CondLeqMovOp, COND_LEQ_MOV) \ + cb(CondLtMovOp, COND_LT_MOV) \ cb(FuseMulAdd3Op, FUSE_MUL_ADD3) // clang-format on diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp index a363ae9ee..37bbb8aa1 100644 --- a/src/jit/test/codegen.cpp +++ b/src/jit/test/codegen.cpp @@ -449,6 +449,7 @@ TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { // clang-format off #define FOREACH_TERNARY_MODE(cb) \ cb(COND_LEQ_MOV) \ + cb(COND_LT_MOV) \ cb(FUSE_MUL_ADD3) \ // clang-format on template diff --git a/src/jit/test/fusion.cpp b/src/jit/test/fusion.cpp index 26d15a4b6..fca6c633b 100644 --- a/src/jit/test/fusion.cpp +++ b/src/jit/test/fusion.cpp @@ -452,6 +452,7 @@ void run(Backend backend, CompNode cn) { CHECK_ELEM2(ATAN2, true, gt0); CHECK_ELEM3(COND_LEQ_MOV, false, none); + CHECK_ELEM3(COND_LT_MOV, false, none); CHECK_ELEM3(FUSE_MUL_ADD3, true, none); CHECK_ELEM4(FUSE_MUL_ADD4, true, none); diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index 1639a5abb..15b87f815 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -601,9 +601,17 @@ MGB_IMPL_OPR_GRAD(Elemwise) { case Mode::FLOOR_DIV: return nullptr; case Mode::MAX: - RET(EL3(COND_LEQ_MOV, i[!wrt_idx], i[wrt_idx], og)); + if (wrt_idx) { + RET(EL3(COND_LT_MOV, i[0], i[1], og)); + } else { + RET(EL3(COND_LEQ_MOV, i[1], i[0], og)); + } case Mode::MIN: - RET(EL3(COND_LEQ_MOV, i[wrt_idx], i[!wrt_idx], og)); + if (wrt_idx) { + RET(EL3(COND_LT_MOV, i[1], i[0], og)); + } else { + RET(EL3(COND_LEQ_MOV, i[0], i[1], og)); + } case Mode::MOD: if (wrt_idx == 0) { RET(og); @@ -661,7 +669,10 @@ MGB_IMPL_OPR_GRAD(Elemwise) { if (wrt_idx <= 1) return nullptr; RET(EL3(COND_LEQ_MOV, i0, i1, og)); - + case Mode::COND_LT_MOV: + if (wrt_idx <= 1) + return nullptr; + RET(EL3(COND_LT_MOV, i0, i1, og)); // fuse oprs case Mode::FUSE_MUL_ADD3: if (wrt_idx < 2) { diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index e0ccf54ff..c220102cd 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -571,6 +571,8 @@ struct CheckerConfig : public NoGradCheckerConfig {}; /* ======================= ternary config ======================= */ template <> struct CheckerConfig : public BinaryInputMinGap {}; +template <> +struct CheckerConfig : public BinaryInputMinGap {}; /* ======================= test runner ======================= */ namespace detail { diff --git a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl index fb4191863..e9e5cd8ec 100644 --- a/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_ternary_trait_def.inl @@ -13,6 +13,7 @@ #define _ALLOW_FLOAT true #define _ALLOW_INT true DEF_TRAIT(COND_LEQ_MOV, x <= y ? z : 0) +DEF_TRAIT(COND_LT_MOV, x < y ? z : 0) DEF_TRAIT(FUSE_MUL_ADD3, x* y + z) #undef _ALLOW_INT #undef _ALLOW_FLOAT diff --git a/src/opr/test/nn_int.cpp b/src/opr/test/nn_int.cpp index 0175dcf74..f8b2fe4d8 100644 --- a/src/opr/test/nn_int.cpp +++ b/src/opr/test/nn_int.cpp @@ -589,6 +589,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_IS8_OS8) { switch (mode) { MAKE_TERNARY(FUSE_MUL_ADD3); MAKE_TERNARY(COND_LEQ_MOV); + MAKE_TERNARY(COND_LT_MOV); default: mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); break; @@ -646,6 +647,7 @@ TEST(TestOprElemwiseMultiType, QuantizedModeTernary_I8Asymm_O8Asymm) { switch (mode) { MAKE_TERNARY(FUSE_MUL_ADD3); MAKE_TERNARY(COND_LEQ_MOV); + MAKE_TERNARY(COND_LT_MOV); default: mgb_throw(InternalError, "Unknown ElemwiseMultiType Mode\n"); break; -- GitLab