未验证 提交 17a03629 作者: Charles-hit 提交者: GitHub

fix slice_assign_p (#47324)

* fix slice_assign_p jvp and transpose

* modify code style

* modify test_jvp_and_transpose for slice_assign_p

* modify code style

* add unit test
上级 cb09cf99
...@@ -883,7 +883,7 @@ class TestSliceSelectPJVPAndTranspose(TestAddPJVPAndTranspose): ...@@ -883,7 +883,7 @@ class TestSliceSelectPJVPAndTranspose(TestAddPJVPAndTranspose):
] ]
class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose): class TestSliceAssignPJVPAndTranspose1(TestAddPJVPAndTranspose):
def init_data(self): def init_data(self):
# Set prim op # Set prim op
self.op_type = 'slice_assign_p' self.op_type = 'slice_assign_p'
...@@ -909,23 +909,110 @@ class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose): ...@@ -909,23 +909,110 @@ class TestSliceAssignPJVPAndTranspose(TestAddPJVPAndTranspose):
self.jvp_out_shape_map = {0: self.prim_output['Z']} self.jvp_out_shape_map = {0: self.prim_output['Z']}
# Set transpose # Set transpose
check_dot = lambda v: v is X or v is Y check_dot = lambda v: v is X
Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64') Z_BAR = paddle.static.data(name='Z_BAR', shape=[3, 20], dtype='float64')
self.transpose_args = (check_dot, Z_BAR) self.transpose_args = (check_dot, Z_BAR)
self.transpose_out_shape_map = {0: X, 1: Y} self.transpose_out_shape_map = {0: X}
self.all_ops = [ self.all_ops = [
# prim op: # prim op:
'slice_assign_p', 'slice_assign_p',
# jvp op: # jvp op:
'slice_assign_p', 'slice_assign_p',
"slice_assign_p",
"add_p",
"fill_constant_p",
"fill_constant_p",
# transpose op: # transpose op:
'slice_assign_p', 'slice_assign_p',
'fill_constant_p',
]
class TestSliceAssignPJVPAndTranspose2(TestAddPJVPAndTranspose):
    """slice_assign_p case in which only the second input (Y) has a tangent.

    The JVP arguments are ``(None, Y_DOT)`` and the transpose predicate
    accepts only ``Y``.  How these attributes are consumed is defined by the
    base class, which is not part of this block.
    """

    def init_data(self):
        def declare(var_name, var_shape):
            # Every placeholder in this case is a float64 static variable.
            return paddle.static.data(
                name=var_name, shape=var_shape, dtype='float64'
            )

        # Primitive op: inputs, output placeholder and slicing attributes.
        self.op_type = 'slice_assign_p'
        src = declare('X', [3, 20])
        patch = declare('Y', [3, 5])
        self.prim_input = {'X': src, 'Y': patch}
        out = self.layer_help.create_variable_for_type_inference(
            dtype=src.dtype
        )
        self.prim_output = {'Z': out}
        # Indices 0, 2, 4, 6, 8 along axis 1 -> five columns, matching Y.
        self.prim_attrs = dict(axis=[1], starts=[0], ends=[10], strides=[2])

        # JVP: no tangent for X, a tangent only for Y.
        patch_dot = declare('Y_DOT', [3, 5])
        self.jvp_args = (None, patch_dot)
        self.jvp_out_shape_map = {0: out}

        # Transpose: only Y is flagged by the predicate.
        def check_dot(v):
            return v is patch

        out_bar = declare('Z_BAR', [3, 20])
        self.transpose_args = (check_dot, out_bar)
        self.transpose_out_shape_map = {1: patch}

        prim_ops = ['slice_assign_p']
        jvp_ops = ['slice_assign_p', "fill_constant_p"]
        transpose_ops = ['slice_select_p', 'fill_constant_p']
        self.all_ops = prim_ops + jvp_ops + transpose_ops
class TestSliceAssignPJVPAndTranspose3(TestAddPJVPAndTranspose):
    """slice_assign_p case in which only the first input (X) has a tangent.

    The JVP arguments are ``(X_DOT, None)`` and the transpose predicate
    accepts only ``X``.  How these attributes are consumed is defined by the
    base class, which is not part of this block.
    """

    def init_data(self):
        def declare(var_name, var_shape):
            # Every placeholder in this case is a float64 static variable.
            return paddle.static.data(
                name=var_name, shape=var_shape, dtype='float64'
            )

        # Primitive op: inputs, output placeholder and slicing attributes.
        self.op_type = 'slice_assign_p'
        src = declare('X', [3, 20])
        patch = declare('Y', [3, 5])
        self.prim_input = {'X': src, 'Y': patch}
        out = self.layer_help.create_variable_for_type_inference(
            dtype=src.dtype
        )
        self.prim_output = {'Z': out}
        # Indices 0, 2, 4, 6, 8 along axis 1 -> five columns, matching Y.
        self.prim_attrs = dict(axis=[1], starts=[0], ends=[10], strides=[2])

        # JVP: a tangent only for X, none for Y.
        src_dot = declare('X_DOT', [3, 20])
        self.jvp_args = (src_dot, None)
        self.jvp_out_shape_map = {0: out}

        # Transpose: only X is flagged by the predicate.
        def check_dot(v):
            return v is src

        out_bar = declare('Z_BAR', [3, 20])
        self.transpose_args = (check_dot, out_bar)
        self.transpose_out_shape_map = {0: src}

        prim_ops = ['slice_assign_p']
        jvp_ops = ['slice_assign_p', "fill_constant_p"]
        transpose_ops = ['slice_assign_p', 'fill_constant_p']
        self.all_ops = prim_ops + jvp_ops + transpose_ops
class TestGatherPJVPAndTranspose(TestAddPJVPAndTranspose): class TestGatherPJVPAndTranspose(TestAddPJVPAndTranspose):
def init_data(self): def init_data(self):
# Set prim op # Set prim op
......
...@@ -100,7 +100,7 @@ class TestFowardApi(unittest.TestCase): ...@@ -100,7 +100,7 @@ class TestFowardApi(unittest.TestCase):
actual = actual() actual = actual()
self.assertEqual(type(actual), type(expected)) self.assertEqual(type(actual), type(expected))
for i, j in zip(actual, expected): for i, j in zip(actual, expected):
np.testing.assert_allclose(i, j, atol=1e-3, rtol=1e-3) np.testing.assert_allclose(i, j, rtol=1e-6)
@utils.place(config.DEVICES) @utils.place(config.DEVICES)
......
...@@ -1026,17 +1026,53 @@ def slice_select_jvp(op, x_dot): ...@@ -1026,17 +1026,53 @@ def slice_select_jvp(op, x_dot):
@REGISTER_JVP('slice_assign_p')
def slice_assign_jvp(op, x_dot, y_dot):
    """JVP rule registered for ``slice_assign_p``.

    Args:
        op: the ``slice_assign_p`` operator; its positional inputs and its
            ``axis``/``starts``/``ends``/``strides`` attributes are read.
        x_dot: tangent of the first input, or ``None`` if it has none.
        y_dot: tangent of the second input, or ``None`` if it has none.

    Returns:
        The result of ``linear_jvp`` when exactly one tangent is given (the
        missing one is replaced by a zero tensor with the shape and dtype of
        the corresponding input); otherwise the ``add`` of the two
        single-tangent ``linear_jvp`` results.

    Raises:
        AssertionError: if both ``x_dot`` and ``y_dot`` are ``None``.
    """
    x, y = op_position_inputs(op)
    assert not (
        x_dot is None and y_dot is None
    ), "x_dot and y_dot can't be None at the same time. "

    # Slice attributes are forwarded unchanged to every linear_jvp call.
    slice_kwargs = {
        key: op.attr(key) for key in ('axis', 'starts', 'ends', 'strides')
    }

    def zeros_like(var):
        # Stand-in tangent for an input that has no tangent of its own.
        return fill_const(value=0.0, shape=var.shape, dtype=var.dtype)

    if x_dot is None:
        return linear_jvp(op, zeros_like(x), y_dot, **slice_kwargs)
    if y_dot is None:
        return linear_jvp(op, x_dot, zeros_like(y), **slice_kwargs)

    # Both tangents present: build the Y-only term first, then the X-only
    # term, and sum them (same construction order as the branches above).
    y_term = linear_jvp(op, zeros_like(x), y_dot, **slice_kwargs)
    x_term = linear_jvp(op, x_dot, zeros_like(y), **slice_kwargs)
    return add(y_term, x_term)
...@@ -1311,8 +1347,8 @@ def slice_select_transpose(op, check_dot, y_bar): ...@@ -1311,8 +1347,8 @@ def slice_select_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('slice_assign_p') @REGISTER_TRANSPOSE('slice_assign_p')
def slice_assign_transpose(op, check_dot, z_bar): def slice_assign_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op) x, y = op_position_inputs(op)
assert check_dot(x) and check_dot(y), ( assert check_dot(x) ^ check_dot(y), (
f'(check_dot(x) and check_dot(y)) must be True, ' f'(check_dot(x) ^ check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.' f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.'
) )
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
...@@ -1320,13 +1356,21 @@ def slice_assign_transpose(op, check_dot, z_bar): ...@@ -1320,13 +1356,21 @@ def slice_assign_transpose(op, check_dot, z_bar):
starts = op.attr('starts') starts = op.attr('starts')
ends = op.attr('ends') ends = op.attr('ends')
strides = op.attr('strides') strides = op.attr('strides')
x_bar = slice_assign( if check_dot(x):
z_bar, zeros, axis=axis, starts=starts, ends=ends, strides=strides return (
slice_assign(
z_bar,
zeros,
axis=axis,
starts=starts,
ends=ends,
strides=strides,
),
None,
) )
y_bar = slice_select( return None, slice_select(
z_bar, axis=axis, starts=starts, ends=ends, strides=strides z_bar, axis=axis, starts=starts, ends=ends, strides=strides
) )
return x_bar, y_bar
@REGISTER_TRANSPOSE('gather_p') @REGISTER_TRANSPOSE('gather_p')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册