From 17a03629482d35feb291c7282a3b1f8c2a058ddb Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Wed, 26 Oct 2022 10:32:17 +0800
Subject: [PATCH] fix slice_assign_p (#47324)

* fix slice_assign_p jvp and transpose

* modify code style

* modify test_jvp_and_transpose for slice_assign_p

* modify code style

* add unit test
---
 .../autograd/test_jvp_and_transpose.py        | 93 ++++++++++++++++++-
 .../tests/unittests/autograd/test_primapi.py  |  2 +-
 python/paddle/incubate/autograd/primrules.py  | 72 +++++++++++---
 3 files changed, 149 insertions(+), 18 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
index e90d6871c30..b3beebfa72f 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_jvp_and_transpose.py
@@ -883,7 +883,7 @@ class TestSliceSelectPJVPAndTranspose(TestAddPJVPAndTranspose):
         ]
 
 
-class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose):
+class TestSliceAssignPJVPAndTranspose1(TestAddPJVPAndTranspose):
     def init_data(self):
         # Set prim op
         self.op_type = 'slice_assign_p'
@@ -909,23 +909,110 @@ class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose):
         self.jvp_out_shape_map = {0: self.prim_output['Z']}
 
         # Set transpose
-        check_dot = lambda v: v is X or v is Y
+        check_dot = lambda v: v is X
         Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64')
         self.transpose_args = (check_dot, Z_BAR)
-        self.transpose_out_shape_map = {0: X, 1: Y}
+        self.transpose_out_shape_map = {0: X}
 
         self.all_ops = [
             # prim op:
             'slice_assign_p',
             # jvp op:
             'slice_assign_p',
+            "slice_assign_p",
+            "add_p",
+            "fill_constant_p",
+            "fill_constant_p",
             # transpose op:
             'slice_assign_p',
+            'fill_constant_p',
+        ]
+
+
+class TestSliceAssignPJVPAndTranspose2(TestAddPJVPAndTranspose):
+    def init_data(self):
+        # Set prim op
+        self.op_type = 'slice_assign_p'
+        X = paddle.static.data(name='X', shape=[3, 20], dtype='float64')
+        Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
+        self.prim_input = {'X': X, 'Y': Y}
+        self.prim_output = {
+            'Z': self.layer_help.create_variable_for_type_inference(
+                dtype=X.dtype
+            )
+        }
+        self.prim_attrs = {
+            'axis': [1],
+            'starts': [0],
+            'ends': [10],
+            'strides': [2],
+        }
+
+        # Set JVP
+        Y_DOT = paddle.static.data(name='Y_DOT', shape=[3, 5], dtype='float64')
+        self.jvp_args = (None, Y_DOT)
+        self.jvp_out_shape_map = {0: self.prim_output['Z']}
+
+        # Set transpose
+        check_dot = lambda v: v is Y
+        Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64')
+        self.transpose_args = (check_dot, Z_BAR)
+        self.transpose_out_shape_map = {1: Y}
+
+        self.all_ops = [
+            # prim op:
+            'slice_assign_p',
+            # jvp op:
+            'slice_assign_p',
+            "fill_constant_p",
+            # transpose op:
             'slice_select_p',
             'fill_constant_p',
         ]
 
 
+class TestSliceAssignPJVPAndTranspose3(TestAddPJVPAndTranspose):
+    def init_data(self):
+        # Set prim op
+        self.op_type = 'slice_assign_p'
+        X = paddle.static.data(name='X', shape=[3, 20], dtype='float64')
+        Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
+        self.prim_input = {'X': X, 'Y': Y}
+        self.prim_output = {
+            'Z': self.layer_help.create_variable_for_type_inference(
+                dtype=X.dtype
+            )
+        }
+        self.prim_attrs = {
+            'axis': [1],
+            'starts': [0],
+            'ends': [10],
+            'strides': [2],
+        }
+
+        # Set JVP
+        X_DOT = paddle.static.data(name='X_DOT', shape=[3, 20], dtype='float64')
+        self.jvp_args = (X_DOT, None)
+        self.jvp_out_shape_map = {0: self.prim_output['Z']}
+
+        # Set transpose
+        check_dot = lambda v: v is X
+        Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64')
+        self.transpose_args = (check_dot, Z_BAR)
+        self.transpose_out_shape_map = {0: X}
+
+        self.all_ops = [
+            # prim op:
+            'slice_assign_p',
+            # jvp op:
+            'slice_assign_p',
+            "fill_constant_p",
+            # transpose op:
+            'slice_assign_p',
+            'fill_constant_p',
+        ]
+
+
 class TestGatherPJVPAndTranspose(TestAddPJVPAndTranspose):
     def init_data(self):
         # Set prim op
diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index 7f1a06c9240..7e576f95ca5 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -100,7 +100,7 @@ class TestFowardApi(unittest.TestCase):
             actual = actual()
         self.assertEqual(type(actual), type(expected))
         for i, j in zip(actual, expected):
-            np.testing.assert_allclose(i, j, atol=1e-3, rtol=1e-3)
+            np.testing.assert_allclose(i, j, rtol=1e-6)
 
 
 @utils.place(config.DEVICES)
diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py
index 1d98d62cef3..badd8476463 100644
--- a/python/paddle/incubate/autograd/primrules.py
+++ b/python/paddle/incubate/autograd/primrules.py
@@ -1026,17 +1026,53 @@ def slice_select_jvp(op, x_dot):
 
 @REGISTER_JVP('slice_assign_p')
 def slice_assign_jvp(op, x_dot, y_dot):
-    if x_dot is None:
-        assert y_dot is None, 'y_dot must be None.'
-        return None
-    else:
-        assert y_dot is not None, 'y_dot should not be None.'
+    x, y = op_position_inputs(op)
+    assert (
+        x_dot is not None or y_dot is not None
+    ), "x_dot and y_dot can't be None at the same time. "
     axis = op.attr('axis')
     starts = op.attr('starts')
     ends = op.attr('ends')
     strides = op.attr('strides')
-    return linear_jvp(
-        op, x_dot, y_dot, axis=axis, starts=starts, ends=ends, strides=strides
+    if x_dot is None:
+        return linear_jvp(
+            op,
+            fill_const(value=0.0, shape=x.shape, dtype=x.dtype),
+            y_dot,
+            axis=axis,
+            starts=starts,
+            ends=ends,
+            strides=strides,
+        )
+    elif y_dot is None:
+        return linear_jvp(
+            op,
+            x_dot,
+            fill_const(value=0.0, shape=y.shape, dtype=y.dtype),
+            axis=axis,
+            starts=starts,
+            ends=ends,
+            strides=strides,
+        )
+    return add(
+        linear_jvp(
+            op,
+            fill_const(value=0.0, shape=x.shape, dtype=x.dtype),
+            y_dot,
+            axis=axis,
+            starts=starts,
+            ends=ends,
+            strides=strides,
+        ),
+        linear_jvp(
+            op,
+            x_dot,
+            fill_const(value=0.0, shape=y.shape, dtype=y.dtype),
+            axis=axis,
+            starts=starts,
+            ends=ends,
+            strides=strides,
+        ),
     )
 
 
@@ -1311,8 +1347,8 @@ def slice_select_transpose(op, check_dot, y_bar):
 @REGISTER_TRANSPOSE('slice_assign_p')
 def slice_assign_transpose(op, check_dot, z_bar):
     x, y = op_position_inputs(op)
-    assert check_dot(x) and check_dot(y), (
-        f'(check_dot(x) and check_dot(y)) must be True, '
+    assert check_dot(x) ^ check_dot(y), (
+        f'(check_dot(x) ^ check_dot(y)) must be True, '
         f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.'
     )
     zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
@@ -1320,13 +1356,21 @@ def slice_assign_transpose(op, check_dot, z_bar):
     starts = op.attr('starts')
     ends = op.attr('ends')
     strides = op.attr('strides')
-    x_bar = slice_assign(
-        z_bar, zeros, axis=axis, starts=starts, ends=ends, strides=strides
-    )
-    y_bar = slice_select(
+    if check_dot(x):
+        return (
+            slice_assign(
+                z_bar,
+                zeros,
+                axis=axis,
+                starts=starts,
+                ends=ends,
+                strides=strides,
+            ),
+            None,
+        )
+    return None, slice_select(
         z_bar, axis=axis, starts=starts, ends=ends, strides=strides
     )
-    return x_bar, y_bar
 
 
 @REGISTER_TRANSPOSE('gather_p')
-- 
GitLab
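Note (not part of the patch): the rewritten slice_assign_jvp rule splits the tangent of z = slice_assign(x, y) into an x-only term and a y-only term and adds them, which is valid because slice_assign_p is linear in both inputs. Below is a minimal NumPy sketch of that identity; slice_assign here is a hypothetical stand-in for the slice_assign_p primitive, not Paddle's actual implementation, and the slice parameters mirror the attrs used in the new unit tests.

# Illustration only: NumPy stand-in for slice_assign_p (copy x, overwrite a
# strided slice with y), used to check the split that slice_assign_jvp relies on.
import numpy as np


def slice_assign(x, y, sl):
    """Return a copy of x with x[sl] replaced by y."""
    z = x.copy()
    z[sl] = y
    return z


rng = np.random.default_rng(0)
sl = (slice(None), slice(0, 10, 2))  # axis=[1], starts=[0], ends=[10], strides=[2]
x, x_dot = rng.standard_normal((3, 20)), rng.standard_normal((3, 20))
y, y_dot = rng.standard_normal((3, 5)), rng.standard_normal((3, 5))

# Tangent of z = slice_assign(x, y) taken directly ...
jvp_direct = slice_assign(x_dot, y_dot, sl)
# ... equals the sum of the two one-sided terms built with zero fill-ins,
# which is what the patched rule computes via add(linear_jvp(...), linear_jvp(...)).
jvp_split = slice_assign(np.zeros_like(x), y_dot, sl) + slice_assign(
    x_dot, np.zeros_like(y), sl
)
np.testing.assert_allclose(jvp_direct, jvp_split)

The same decomposition explains the transpose rule: the cotangent flowing to x is z_bar with the assigned slice zeroed out (slice_assign with zeros), while the cotangent flowing to y is the slice of z_bar itself (slice_select), and the XOR assert enforces that exactly one of the two inputs is being differentiated per call.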