Commit 7bd848ce authored by Megvii Engine Team

fix(subgraph): fix hand-written backward for several jit-elemwise ops

GitOrigin-RevId: ea3a40d96efb6dd083fa4278c041837aff3833d0
Parent 7be7656c
@@ -909,9 +909,9 @@ def _get_prelu_op(dtype=None, device=None):
min_0 = f("min", inp, c(0))
oup = f("fma3", min_0, weight, max_0)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", inp, c(0), oup_grad)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, weight)
inp_grad_1 = f("cond_leq_mov", c(0), inp, inp_grad_1)
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
weight_grad = f("*", oup_grad, min_0)
yield (inp_grad, weight_grad)
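Note on the change above: with max_0 = max(inp, 0) and min_0 = min(inp, 0), the forward is prelu(x) = max(x, 0) + weight * min(x, 0), so the pass-through branch of the gradient belongs to x >= 0 and the weight branch to x <= 0; the swapped cond_leq_mov operands select exactly those masks (assuming cond_leq_mov(a, b, g) forwards g where a <= b). A minimal NumPy sketch of the corrected formulas, for illustration only and not part of the commit:

# Illustration only (not part of the commit): the PReLU decomposition the
# hand-written backward implements, checked against a finite difference.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(16)
w = 0.25                      # a scalar weight for simplicity
g = rng.standard_normal(16)   # upstream gradient (oup_grad)

# forward: prelu(x) = max(x, 0) + w * min(x, 0)
y = np.maximum(x, 0) + w * np.minimum(x, 0)

# backward as in the fixed subgraph
dx = g * (x >= 0) + g * w * (x <= 0)   # inp_grad_0 + inp_grad_1
dw = g * np.minimum(x, 0)              # weight_grad = oup_grad * min_0

# finite-difference check of dx
eps = 1e-6
fd = (np.maximum(x + eps, 0) + w * np.minimum(x + eps, 0)
      - np.maximum(x - eps, 0) - w * np.minimum(x - eps, 0)) / (2 * eps)
assert np.allclose(dx, g * fd, atol=1e-4)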
@@ -925,7 +925,7 @@ def prelu(inp: Tensor, weight: Tensor) -> Tensor:
Refer to :class:`~.PReLU` for more information.
"""
prelu = _get_prelu_op(dtype=inp.dtype, device=inp.device)
- (oup,) = prelu(inp, weight)
+ (oup,) = prelu(inp, broadcast_to(weight, inp.shape))
return oup
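Broadcasting weight to inp.shape before calling the fused op gives every subgraph input the same shape, which lines up with the stricter input checks added for JITFusionOp further down.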
@@ -947,7 +947,7 @@ def _get_leagk_relu_op(negative_slope, *, dtype=None, device=None):
(oup_grad,) = yield (oup,)
inp_grad_0 = f("cond_leq_mov", c(0), inp, oup_grad)
inp_grad_1 = f("*", oup_grad, c(negative_slope))
inp_grad_1 = f("cond_leq_mov", inp, c(negative_slope), inp_grad_1)
inp_grad_1 = f("cond_leq_mov", inp, c(0), inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
yield (inp_grad,)
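Here the forward is x for x > 0 and negative_slope * x otherwise, so the scaled branch of the gradient must be gated on x <= 0; comparing x against negative_slope itself, as the removed line did, routes the scaled gradient to the wrong elements whenever negative_slope is nonzero.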
@@ -994,13 +994,14 @@ def _get_softplus_op(dtype=None, device=None):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup = f("log1p", exp)
oup = f("+", oup, f("relu", inp))
oup0 = f("log1p", exp)
oup1 = f("relu", inp)
oup = f("+", oup0, oup1)
(oup_grad,) = yield (oup,)
inp_grad_0 = f("switch_gt0", inp, oup_grad)
inp_grad_0 = f("switch_gt0", oup1, oup_grad)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", oup_grad, exp)
inp_grad_1 = f("*", inp_grad_1, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
@@ -1098,16 +1099,17 @@ def _get_logsigmoid_op(dtype=None, device=None):
(inp,) = inputs[0:1]
neg_abs = f("-", f("abs", inp))
exp = f("exp", neg_abs)
oup = f("log1p", exp)
oup = f("+", oup, f("relu", f("-", inp)))
oup0 = f("log1p", exp)
oup1 = f("relu", f("-", inp))
oup = f("+", oup0, oup1)
oup = f("-", oup)
(oup_grad,) = yield (oup,)
oup_grad = f("-", oup_grad)
inp_grad_0 = f("switch_gt0", inp, oup_grad)
inp_grad_0 = f("switch_gt0", oup1, oup_grad)
inp_grad_0 = f("-", inp_grad_0)
inp_grad_1 = oup_grad
inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
inp_grad_1 = f("*", oup_grad, exp)
inp_grad_1 = f("/", inp_grad_1, f("+", exp, c(1)))
inp_grad_1 = f("*", inp_grad_1, exp)
inp_grad_1 = f("-", inp_grad_1)
inp_grad_1 = f("abs_grad", inp, inp_grad_1)
inp_grad = f("+", inp_grad_0, inp_grad_1)
...
@@ -726,6 +726,28 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
+ TensorShape shape;
+ DType dtype = input_descs[0].layout.dtype;
+ CompNode comp_node = input_descs[0].comp_node;
+ for (auto&& desc : input_descs) {
+ if (desc.layout.ndim) {
+ shape = desc.layout;
+ break;
+ }
+ }
+ for (size_t i = 0; i < input_descs.size(); ++i) {
+ if (input_descs[i].layout.ndim) {
+ mgb_assert(
+ input_descs[i].layout.eq_shape(shape),
+ "inputs of JITFusionOp should have same shapes");
+ }
+ mgb_assert(
+ input_descs[i].layout.dtype == dtype,
+ "inputs of JITFusionOp should have same dtypes");
+ mgb_assert(
+ input_descs[i].comp_node == comp_node,
+ "inputs of JITFusionOp should have same devices");
+ }
return OpDef::infer_output_attrs_fallible(
*def.cast_final_safe<JITFusionOp>().op, input_descs);
}
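The added checks make JITFusionOp's fallible shape/dtype inference reject inconsistent inputs early: every input must share one dtype and one comp_node, and every input whose layout already has a known ndim must match the first known shape; only then is inference delegated to the wrapped op. This is the invariant that the broadcast_to call in the prelu wrapper above establishes on the Python side.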
...