提交 0c6cf98d 编写于 作者: B buxue

fix bug of brpop of FloorMod

上级 ef596f26
......@@ -255,13 +255,10 @@ def get_bprop_floordiv(self):
@bprop_getters.register(P.FloorMod)
def get_bprop_floormod(self):
"""Grad definition for `FloorMod` operation."""
div_op = P.FloorMod()
neg = P.Neg()
mul_op = P.Mul()
def bprop(x, y, out, dout):
bc_x = div_op(dout, y)
bc_y = neg(mul_op(bc_x, out))
bc_x = dout
bc_y = -dout * (x // y)
return binop_grad_common(x, y, bc_x, bc_y)
return bprop
......@@ -412,6 +409,7 @@ def get_bprop_reducesum(self):
def get_bprop_cumsum(self):
"""Grad definition for `CumSum` operation."""
cumsum = P.CumSum(exclusive=self.exclusive, reverse=not self.reverse)
def bprop(x, axis, out, dout):
return cumsum(dout, axis), zeros_like(axis)
return bprop
......@@ -787,6 +785,7 @@ def get_bprop_atan2(self):
"""Generate bprop for Atan2"""
square = P.Square()
def bprop(x, y, out, dout):
tmp = dout / (square(x) + square(y))
dx = tmp * y
......
......@@ -351,9 +351,8 @@ test_case_math_ops = [
'skip': ['backward']}),
('FloorMod', {
'block': P.FloorMod(),
'desc_inputs': [Tensor(np.random.rand(4).astype(np.float16)),
Tensor(np.random.rand(4).astype(np.float16))],
'skip': ['backward']}),
'desc_inputs': [[3, 4, 5], [2, 3, 4, 5]],
'desc_bprop': [[2, 3, 4, 5]]}),
('identity', {
'block': ops.functional.identity,
'desc_inputs': [[2, 2]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册