Commit bf1a0fb7 authored by Megvii Engine Team

fix(imperative): fix logsigmoid bwd implementation

GitOrigin-RevId: 86de18760c1a298a7f5265e0959693c30366dd3f
Parent 3a35827d
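
For context (standard identities, not part of the commit message itself): both the Python
reference below and the new C++ gradient rely on the numerically stable decompositions

    \operatorname{softplus}(x) = \log(1 + e^{x})
        = \operatorname{log1p}\bigl(e^{-|x|}\bigr) + \operatorname{relu}(x),
    \qquad
    \operatorname{logsigmoid}(x) = \log \sigma(x) = -\operatorname{softplus}(-x)
        = -\bigl(\operatorname{log1p}\bigl(e^{-|x|}\bigr) + \operatorname{relu}(-x)\bigr),

with derivatives

    \frac{d}{dx}\operatorname{softplus}(x) = \sigma(x),
    \qquad
    \frac{d}{dx}\operatorname{logsigmoid}(x) = \sigma(-x) = \frac{1}{1 + e^{x}}.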
@@ -2,11 +2,13 @@
import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.functional.elemwise as elemwise
from megengine import tensor
from megengine.core.tensor import dtype
from megengine.core.tensor.utils import subgraph_fn
from megengine.functional.elemwise import Elemwise
from megengine.jit import trace
@@ -316,3 +318,113 @@ def test_maximum_grad_consistency(is_trace):
            run(trace(symbolic=symbolic)(f))
    else:
        run(f)


def _get_logsigmoid_op(dtype=None, device=None):
    @subgraph_fn(
        "LogSigmoid",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=False,
        custom_grad=True,
    )
    def logsigmoid(inputs, f, c):
        (inp,) = inputs[0:1]
        # stable forward: logsigmoid(x) = -(log1p(exp(-|x|)) + relu(-x))
        neg_abs = f("-", f("abs", inp))
        exp = f("exp", neg_abs)
        oup0 = f("log1p", exp)
        oup1 = f("relu", f("-", inp))
        oup = f("+", oup0, oup1)
        oup = f("-", oup)
        (oup_grad,) = yield (oup,)
        # backward: flip the incoming grad through the outer negation
        oup_grad = f("-", oup_grad)
        # relu(-x) branch: d/dx relu(-x) = -1 where x < 0
        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
        inp_grad_0 = f("-", inp_grad_0)
        # log1p(exp(-|x|)) branch: local grad -exp(-|x|) / (1 + exp(-|x|)),
        # routed through abs_grad to pick up sign(x)
        inp_grad_1 = oup_grad
        inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
        inp_grad_1 = f("*", inp_grad_1, exp)
        inp_grad_1 = f("-", inp_grad_1)
        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)

    return logsigmoid


def origin_logsigmoid(inp: mge.tensor) -> mge.tensor:
    logsigmoid = _get_logsigmoid_op(inp.dtype, inp.device)
    (oup,) = logsigmoid(inp)
    return oup
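
As a standalone sanity check (an editor's sketch, not part of the diff; it assumes
abs_grad(x, g) acts as g * sign(x) and switch_gt0(m, g) passes g through where m > 0),
the decomposed backward above can be replayed in plain NumPy and compared with the
analytic derivative d/dx logsigmoid(x) = sigmoid(-x):

import numpy as np

def decomposed_logsigmoid_grad(x, og=1.0):
    g = -og                                          # grad through the outer negation
    e = np.exp(-np.abs(x))
    relu_branch = -np.where(-x > 0, g, 0.0)          # switch_gt0(relu(-x), g), then negate
    log_branch = np.sign(x) * -(g * e / (1.0 + e))   # abs_grad(x, -(g * e / (1 + e)))
    return relu_branch + log_branch

x = np.linspace(-30.0, 30.0, 10000)  # even point count keeps x == 0 off the grid
np.testing.assert_allclose(decomposed_logsigmoid_grad(x), 1.0 / (1.0 + np.exp(x)))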


def _get_softplus_op(dtype=None, device=None):
    @subgraph_fn(
        "Softplus",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=False,
        custom_grad=True,
    )
    def softplus(inputs, f, c):
        (inp,) = inputs[0:1]
        # stable forward: softplus(x) = log1p(exp(-|x|)) + relu(x)
        neg_abs = f("-", f("abs", inp))
        exp = f("exp", neg_abs)
        oup0 = f("log1p", exp)
        oup1 = f("relu", inp)
        oup = f("+", oup0, oup1)
        (oup_grad,) = yield (oup,)
        # relu(x) branch: d/dx relu(x) = 1 where x > 0
        inp_grad_0 = f("switch_gt0", oup1, oup_grad)
        # log1p(exp(-|x|)) branch: local grad -exp(-|x|) / (1 + exp(-|x|)),
        # routed through abs_grad to pick up sign(x)
        inp_grad_1 = oup_grad
        inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
        inp_grad_1 = f("*", inp_grad_1, exp)
        inp_grad_1 = f("-", inp_grad_1)
        inp_grad_1 = f("abs_grad", inp, inp_grad_1)
        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)

    return softplus


def origin_softplus(inp: mge.tensor) -> mge.tensor:
    softplus = _get_softplus_op(inp.dtype, inp.device)
    (oup,) = softplus(inp)
    return oup
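
The same replay works for the softplus reference (same assumptions about abs_grad and
switch_gt0 as in the sketch above); its backward should reduce to sigmoid(x):

import numpy as np

def decomposed_softplus_grad(x, og=1.0):
    e = np.exp(-np.abs(x))
    relu_branch = np.where(x > 0, og, 0.0)           # switch_gt0(relu(x), og)
    log_branch = np.sign(x) * -(og * e / (1.0 + e))  # abs_grad(x, -(og * e / (1 + e)))
    return relu_branch + log_branch

x = np.linspace(-30.0, 30.0, 10000)  # x == 0 again kept off the grid
np.testing.assert_allclose(decomposed_softplus_grad(x), 1.0 / (1.0 + np.exp(-x)))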


def test_subgraph_elemwise_mode():
    def _test_allclose(func, ori_func):
        targets = np.array(2)
        inp = np.random.randn(2, 256, 10, 16).astype("float32")
        ori_inp = mge.tensor(inp)
        mge_inp = mge.tensor(inp)
        mge_gm = mge.autodiff.GradManager().attach(mge_inp)
        ori_gm = mge.autodiff.GradManager().attach(ori_inp)
        for _ in range(2):
            with mge_gm:
                mge_output = func(mge_inp)
                loss = F.loss.square_loss(
                    mge_output.sum(), mge.tensor(targets, dtype=np.float32)
                )
                mge_gm.backward(loss)
            with ori_gm:
                ori_output = ori_func(ori_inp)
                loss = F.loss.square_loss(
                    ori_output.sum(), mge.tensor(targets, dtype=np.float32)
                )
                ori_gm.backward(loss)
            np.testing.assert_allclose(
                mge_output.numpy(), ori_output.numpy(), rtol=1e-06
            )
            np.testing.assert_allclose(
                ori_inp.grad.numpy(), mge_inp.grad.numpy(), rtol=1e-06
            )

    _test_allclose(F.logsigmoid, origin_logsigmoid)
    _test_allclose(F.softplus, origin_softplus)

@@ -559,12 +559,21 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
         }
         case Mode::RELU6:
             RET(EL2(RELU6_GRAD, i0, og));
-        case Mode::SOFTPLUS:
-            RET(EL2(SOFTPLUS_GRAD, i0, og));
+        case Mode::SOFTPLUS: {
+            auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0)));
+            auto logg = og * abse / (1 + abse);
+            auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg));
+            RET(EL2(ADD, absg, EL2(SWITCH_GT0, EL1(RELU, i0), og)));
+        }
         case Mode::HSIGMOID:
             RET(EL2(HSIGMOID_GRAD, i0, og));
-        case Mode::LOGSIGMOID:
-            RET(EL2(SOFTPLUS_GRAD, EL1(NEGATE, i0), og));
+        case Mode::LOGSIGMOID: {
+            og = EL1(NEGATE, og);
+            auto abse = EL1(EXP, EL1(NEGATE, EL1(ABS, i0)));
+            auto logg = og * abse / (1 + abse);
+            auto absg = EL2(ABS_GRAD, i0, EL1(NEGATE, logg));
+            RET(EL2(SUB, absg, EL2(SWITCH_GT0, EL1(RELU, EL1(NEGATE, i0)), og)));
+        }
         case Mode::SQRT:
             RET(og / EL1(SQRT, i0) / 2);
         case Mode::SQUARE:
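
A worked check of the new LOGSIGMOID branch (an editor's derivation, assuming
ABS_GRAD(x, g) = g * sign(x) and SWITCH_GT0(m, g) = g where m > 0, else 0, which is
how the decomposition uses them): writing g = -og and e = e^{-|x|}, the returned
expression absg - switch_gt0(relu(-x), g) evaluates in the two regimes to

    x > 0:\quad -g\,\frac{e}{1+e} - 0 = og\,\frac{e^{-x}}{1 + e^{-x}} = og\,\sigma(-x),
    \qquad
    x < 0:\quad g\,\frac{e}{1+e} - g = \frac{-g}{1 + e^{x}} = og\,\sigma(-x),

which matches d/dx logsigmoid(x) = \sigma(-x) everywhere; the SOFTPLUS branch reduces
to og * \sigma(x) by the same argument, without the two extra negations.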