diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py
index c826a0e1030f42d96f1680c0d2192515a1d8708b..bf3be4080a9fc8b3f4b70e9e1b2aa3161cca7791 100644
--- a/python/paddle/fluid/tests/unittests/test_split_op.py
+++ b/python/paddle/fluid/tests/unittests/test_split_op.py
@@ -19,6 +19,7 @@ import numpy as np
 from op_test import OpTest, convert_float_to_uint16
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard, core
+from paddle.fluid.framework import _test_eager_guard


 class TestSplitOp(OpTest):
@@ -402,12 +403,30 @@ class API_TestDygraphSplit(unittest.TestCase):
         with fluid.dygraph.guard():
             input_1 = np.random.random([4, 6, 6]).astype("int32")
             # input is a variable which shape is [4, 6, 6]
-            input = fluid.dygraph.to_variable(input_1)
+            input = paddle.to_tensor(input_1)
             x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
             x0_out = x0.numpy()
             x1_out = x1.numpy()
             x2_out = x2.numpy()
             ex_x0, ex_x1, ex_x2 = np.split(input_1, 3, axis=1)
+
+            with _test_eager_guard():
+                # input is a variable which shape is [4, 6, 6]
+                input = paddle.to_tensor(input_1)
+                input.stop_gradient = False
+                x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
+                eager_x0_out = x0.numpy()
+                eager_x1_out = x1.numpy()
+                eager_x2_out = x2.numpy()
+                loss = x0.sum()
+                loss.backward()
+                manual_grad = np.zeros_like(input_1)
+                manual_grad[:, :2, :] = 1
+                self.assertTrue(np.allclose(input.gradient(), manual_grad))
+                self.assertTrue(np.allclose(ex_x0, eager_x0_out))
+                self.assertTrue(np.allclose(ex_x1, eager_x1_out))
+                self.assertTrue(np.allclose(ex_x2, eager_x2_out))
+
             self.assertTrue(np.allclose(ex_x0, x0_out))
             self.assertTrue(np.allclose(ex_x1, x1_out))
             self.assertTrue(np.allclose(ex_x2, x2_out))
@@ -416,7 +435,7 @@ class API_TestDygraphSplit(unittest.TestCase):
         with fluid.dygraph.guard():
             input_1 = np.random.random([4, 6, 6]).astype("bool")
             # input is a variable which shape is [4, 6, 6]
-            input = fluid.dygraph.to_variable(input_1)
+            input = paddle.to_tensor(input_1)
             x0, x1, x2 = paddle.split(input, num_or_sections=3, axis=1)
             x0_out = x0.numpy()
             x1_out = x1.numpy()
@@ -430,7 +449,7 @@ class API_TestDygraphSplit(unittest.TestCase):
         with fluid.dygraph.guard():
             input_1 = np.random.random([4, 6, 6]).astype("int32")
             # input is a variable which shape is [4, 6, 6]
-            input = fluid.dygraph.to_variable(input_1)
+            input = paddle.to_tensor(input_1)
             num1 = paddle.full(shape=[1], fill_value=2, dtype='int32')
             x0, x1, x2 = paddle.split(
                 input, num_or_sections=[num1, 2, 2], axis=1)
@@ -446,7 +465,7 @@ class API_TestDygraphSplit(unittest.TestCase):
         with fluid.dygraph.guard():
             input_1 = np.random.random([4, 6, 6]).astype("int32")
             # input is a variable which shape is [4, 6, 6]
-            input = fluid.dygraph.to_variable(input_1)
+            input = paddle.to_tensor(input_1)
             num1 = paddle.full(shape=[1], fill_value=1, dtype='int32')
             x0, x1, x2 = paddle.split(
                 input, num_or_sections=[2, 2, 2], axis=num1)
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index b4abe5b303b8e6425ebbbe17931e8a6d1db7da16..f5245d59babd29f1983f71a8501f335bb017181b 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -1917,6 +1917,7 @@
   args : (Tensor x, IntArray num_or_sections, Scalar(int) axis)
   output : Tensor[]
   invoke : split_impl(x, num_or_sections, axis)
+  backward : split_grad

 - api : sqrt
   args : (Tensor x)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index d0f337cb054f46b6f6cecc2633e075e75ee72092..97c9c7ddf158442a56fc92195bb346c80502f73e 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -1523,7 +1523,7 @@

 - backward_api : split_grad
   forward : split (Tensor x, IntArray num_or_sections, Scalar axis) -> Tensor[](out)
-  args : (Tensor[] out_grad, Scalar axis)
+  args : (Tensor[] out_grad, Scalar axis = -1)
   output : Tensor(x_grad)
   invoke : concat( out_grad, axis)
   # TODO(zhangyunfei) The config of double grad and triple grad will be supported in the future.
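
Note (not part of the patch): the new `backward : split_grad` entry in api.yaml, together with the `invoke : concat( out_grad, axis)` config in backward.yaml, expresses that the gradient of `split` is just the per-output gradients concatenated back along the split axis. A minimal NumPy sketch of that identity, mirroring the `manual_grad` check in the eager-mode test above (the variable names here are illustrative, not from the patch):

import numpy as np

# Forward: split a [4, 6, 6] array into three [4, 2, 6] pieces along axis 1.
x = np.random.random([4, 6, 6]).astype("float32")
x0, x1, x2 = np.split(x, 3, axis=1)

# Suppose loss = x0.sum(): the gradient flowing into x0 is all ones,
# while x1 and x2 receive zero gradient.
g0, g1, g2 = np.ones_like(x0), np.zeros_like(x1), np.zeros_like(x2)

# split_grad is configured as concat(out_grad, axis): stitching the
# per-output gradients back together along the same axis yields x_grad.
x_grad = np.concatenate([g0, g1, g2], axis=1)

expected = np.zeros_like(x)
expected[:, :2, :] = 1  # matches manual_grad in the test
assert np.allclose(x_grad, expected)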