From 70115083798664b86b22b7486bd367300406d739 Mon Sep 17 00:00:00 2001 From: fangzehua Date: Mon, 3 Aug 2020 14:29:36 +0800 Subject: [PATCH] add squreasumall grad --- mindspore/ops/_grad/grad_math_ops.py | 16 ++++++++++++++++ tests/ut/python/ops/test_ops.py | 3 ++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 446c634a2..e3ca27f65 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -397,6 +397,22 @@ def get_bprop_xlogy(self): return bprop +@bprop_getters.register(P.SquareSumAll) +def get_bprop_square_sum_all(self): + """Grad definition for `Square` operation.""" + mul_func = P.Mul() + fill_func = P.Fill() + dtype = P.DType() + + def bprop(x, y, out, dout): + temp_x = mul_func(dout[0], x) + temp_y = mul_func(dout[1], y) + dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x) + dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y) + return (dx, dy) + + return bprop + @bprop_getters.register(P.Sqrt) def get_bprop_sqrt(self): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index f22366e13..3ef6df6dc 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1188,7 +1188,8 @@ test_case_math_ops = [ 'block': P.SquareSumAll(), 'desc_inputs': [Tensor(np.array([0, 1, 4, 5]).astype(np.float32)), Tensor(np.array([1, 1, 3, 7]).astype(np.float32))], - 'skip': ['backward']}), + 'desc_bprop': [Tensor(np.array(0.1).astype(np.float32)), + Tensor(np.array(0.1).astype(np.float32))]}), ('Cos', { 'block': P.Cos(), 'desc_inputs': [[2, 3]], -- GitLab