From 4e09e402732c02af6ecccb18bd59ddc72e41706c Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:04:30 +0800 Subject: [PATCH] =?UTF-8?q?=EF=BC=88cherry-pick=EF=BC=89Fix=20split=20infe?= =?UTF-8?q?rshape=20in=20static=20mode=20and=20add=20convert=20rules=20for?= =?UTF-8?q?=20fill=5Fany=5Flike=20op=20(#46079)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix split bug in static mode (#45906) * fix split bug in static mode * modify code style * modify code style * add unit test for split * add convert rules for fill_any_like op in paddle science (#45985) * add convert rules for fill_any_like op in paddle science * add unit test for fill_any_like op in paddle science * modify fill_any_like convert rule * modify fill_any_like convert rule dtype --- paddle/phi/infermeta/unary.cc | 22 ++++++---- .../unittests/autograd/test_orig2prim.py | 41 +++++++++++++++++++ .../fluid/tests/unittests/test_split_op.py | 15 +++++++ python/paddle/incubate/autograd/primrules.py | 13 ++++++ 4 files changed, 84 insertions(+), 7 deletions(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 006f77132f0..39db2579ecb 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -3205,10 +3205,14 @@ void SplitInferMeta(const MetaTensor& x, // fill out dims with -1 if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { - std::vector out_dims( - sections_data.size(), - phi::make_ddim(std::vector(x.dims().size(), -1))); - + std::vector out_dims; + if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1) { + out_dims = std::vector( + sections_data.size(), + phi::make_ddim(std::vector(x.dims().size(), -1))); + } else { + out_dims = std::vector(sections_data.size(), x.dims()); + } for (size_t i = 0; i < sections_data.size(); ++i) { if (axis_value != 0) { // Only pass LoD when not spliting along the first dim. @@ -3293,9 +3297,13 @@ void SplitWithNumInferMeta(const MetaTensor& x, int axis_value = GetSplitAxisValue(x, axis, config); // fill out dims with -1 if (axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { - std::vector out_dims( - num, phi::make_ddim(std::vector(x.dims().size(), -1))); - + std::vector out_dims; + if (axis_value == -1) { + out_dims = std::vector( + num, phi::make_ddim(std::vector(x.dims().size(), -1))); + } else { + out_dims = std::vector(num, x.dims()); + } for (int i = 0; i < num; ++i) { if (axis_value != 0) { // Only pass LoD when not spliting along the first dim. diff --git a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py index c9f1aa6c41a..5693520ef0a 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_orig2prim.py @@ -18,6 +18,7 @@ import paddle from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layers.utils import flatten from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose +import paddle.fluid.core as core paddle.enable_static() @@ -343,6 +344,46 @@ class TestFillZerosLikeOrig2Prim(TestElementWiseAddOrig2Prim): self.out_map = {0: self.output['Out']} +class TestFillAnyLikeOrig2Prim(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'fill_any_like' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {} + + self.orig2prim_args = (X, ) + self.all_ops = ['fill_any_like', 'fill_constant_p'] + self.out_map = {0: self.output['Out']} + + +class TestFillAnyLikeOrig2Prim2(TestElementWiseAddOrig2Prim): + + def init_data(self): + self.op_type = 'fill_any_like' + X = paddle.static.data(name='X', shape=[5, 6], dtype='int64') + + self.input = { + 'X': X, + } + self.output = { + 'Out': + self.layer_help.create_variable_for_type_inference(dtype=X.dtype) + } + self.attrs = {'dtype': paddle.float32, 'value': 5} + + self.orig2prim_args = (X, ) + self.all_ops = ['fill_any_like', 'fill_constant_p'] + self.out_map = {0: self.output['Out']} + + class TestSumOrig2Prim(TestElementWiseAddOrig2Prim): def init_data(self): diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index c31169feedb..37ea0d429ca 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -441,6 +441,21 @@ class API_TestSplit5(unittest.TestCase): np.testing.assert_allclose(ex_out, re, rtol=1e-05) +class API_TestSplit6(unittest.TestCase): + + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data = fluid.layers.data('data', shape=[-1, 10], dtype='float64') + x0, x1 = paddle.split(data, num_or_sections=[1, 1], axis=0) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + input1 = np.random.random([2, 10]).astype('float64') + r0, r1 = exe.run(feed={"data": input1}, fetch_list=[x0, x1]) + ex_x0, ex_x1 = np.split(input1, (1, ), axis=0) + np.testing.assert_allclose(ex_x0, r0, rtol=1e-05) + np.testing.assert_allclose(ex_x1, r1, rtol=1e-05) + + class API_TestDygraphFluidSplit(unittest.TestCase): def test_out1(self): diff --git a/python/paddle/incubate/autograd/primrules.py b/python/paddle/incubate/autograd/primrules.py index 9e14c863330..3fe40da787d 100644 --- a/python/paddle/incubate/autograd/primrules.py +++ b/python/paddle/incubate/autograd/primrules.py @@ -26,6 +26,8 @@ from .primreg import (REGISTER_JVP, REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, lookup_orig2prim, lookup_prim2orig, lookup_transpose, op_position_inputs, op_position_output) from .utils import INT_DTYPE_2_STRING, get_input_var_list, get_output_var_list +from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.framework import convert_np_dtype_to_dtype_ def _orig2prim(op, *args): @@ -63,6 +65,7 @@ elementwise_sub elementwise_mul tanh fill_zeros_like +fill_any_like sum index_select scale @@ -187,6 +190,16 @@ def fill_zeros_like_orig2prim(op, x): return fill_const(value=0.0, shape=x.shape, dtype=x.dtype) +@REGISTER_ORIG2PRIM('fill_any_like') +def fill_any_like_orig2prim(op, x): + if op.attr('dtype') == -1: + return fill_const(value=op.attr('value'), shape=x.shape, dtype=x.dtype) + return fill_const(value=op.attr('value'), + shape=x.shape, + dtype=convert_np_dtype_to_dtype_( + convert_dtype(INT_DTYPE_2_STRING[op.attr('dtype')]))) + + @REGISTER_ORIG2PRIM('sum') def sum_orig2prim(op, xs): x0 = xs[0] -- GitLab