提交 70115083 编写于 作者: F fangzehua

add squreasumall grad

上级 3e7ba14e
......@@ -397,6 +397,22 @@ def get_bprop_xlogy(self):
return bprop
@bprop_getters.register(P.SquareSumAll)
def get_bprop_square_sum_all(self):
"""Grad definition for `Square` operation."""
mul_func = P.Mul()
fill_func = P.Fill()
dtype = P.DType()
def bprop(x, y, out, dout):
temp_x = mul_func(dout[0], x)
temp_y = mul_func(dout[1], y)
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
return (dx, dy)
return bprop
@bprop_getters.register(P.Sqrt)
def get_bprop_sqrt(self):
......
......@@ -1188,7 +1188,8 @@ test_case_math_ops = [
'block': P.SquareSumAll(),
'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)),
Tensor(np.array([1, 1, 3, 7]).astype(np.float32))],
'skip': ['backward']}),
'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)),
Tensor(np.array(0.1).astype(np.float32))]}),
('Cos', {
'block': P.Cos(),
'desc_inputs': [[2, 3]],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册