From 93f0e594284c1c49199eee4884e73a0e1f8618d1 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 20 Apr 2022 11:27:34 +0800 Subject: [PATCH] [Cherry-Pick]Fix expand_sig infershape BUG under static graph mode and NeedTransformPlace behavior if set skip_transform in yaml (#41973) * [Phi]Fix expand_sig infershape BUG under static graph mode (#41936) * [Phi]Fix expand_sig infershape BUG under static graph mode * [Phi]Fix expand_sig infershape BUG under static graph mode * [Phi]Fix unittest * [Phi]Fix unittest * [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml (#41920) * [Eager]Fix NeedTransformPlace behavior if set skip_transform in yaml * add unittest for full_like * fix unittest --- paddle/phi/api/lib/data_transform.cc | 16 +++++++++++----- paddle/phi/ops/compat/expand_sig.cc | 11 +++++++++++ .../fluid/tests/unittests/test_expand_v2_op.py | 12 ++++++++++++ .../fluid/tests/unittests/test_full_like_op.py | 15 +++++++++++++++ python/paddle/utils/code_gen/api.yaml | 4 +++- 5 files changed, 52 insertions(+), 6 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 82d2e741e9..4fd429fbd3 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -37,11 +37,17 @@ inline bool NeedTransformDataType(const DataType& input, inline bool NeedTransformPlace(const paddle::platform::Place& input, const Backend& target, const TransformFlag& transform_flag) { - bool ret = - input.GetType() == AllocationType::GPUPINNED || - (transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND && - phi::TransToPhiBackend(input) != - (target != Backend::GPUDNN ? target : Backend::GPU)); + // NOTE(dev): The default value of TransformFlag is True, if it is set with + // False + // somewhere such as api.yaml or backward.yaml that means we should skip data + // transform. Because "stop_transform_" has highest priority. + if (!transform_flag.need_trans_backend()) { + return false; + } + bool ret = input.GetType() == AllocationType::GPUPINNED || + (target != Backend::ALL_BACKEND && + phi::TransToPhiBackend(input) != + (target != Backend::GPUDNN ? target : Backend::GPU)); return ret; } diff --git a/paddle/phi/ops/compat/expand_sig.cc b/paddle/phi/ops/compat/expand_sig.cc index 3b2e468267..9b0a1f5ab7 100644 --- a/paddle/phi/ops/compat/expand_sig.cc +++ b/paddle/phi/ops/compat/expand_sig.cc @@ -17,6 +17,11 @@ namespace phi { KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature("expand", {"X"}, {"shape"}, {"Out"}); + } if (ctx.HasInput("Shape")) { return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"}); } else if (ctx.InputSize("expand_shapes_tensor") > 0) { @@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature( + "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"}); + } if (ctx.HasInput("Shape")) { return KernelSignature("expand_grad", {"X", GradVarName("Out")}, diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py index fd46b41c5f..592a635ddc 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase): assert np.array_equal(res_3, np.tile(input, (1, 1))) +class TestExpandInferShape(unittest.TestCase): + def test_shape_with_var(self): + with program_guard(Program(), Program()): + x = paddle.static.data(shape=[-1, 1, 3], name='x') + fake_var = paddle.randn([2, 3]) + target_shape = [ + -1, paddle.shape(fake_var)[0], paddle.shape(fake_var)[1] + ] + out = paddle.expand(x, shape=target_shape) + self.assertListEqual(list(out.shape), [-1, -1, -1]) + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_full_like_op.py b/python/paddle/fluid/tests/unittests/test_full_like_op.py index 05a310a9c5..d3fea677a4 100644 --- a/python/paddle/fluid/tests/unittests/test_full_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_like_op.py @@ -22,6 +22,7 @@ import unittest import numpy as np from op_test import OpTest from paddle.fluid.framework import convert_np_dtype_to_dtype_ +from paddle.fluid.framework import _test_eager_guard class TestFullOp(unittest.TestCase): @@ -133,5 +134,19 @@ class TestFullLikeOp3(TestFullLikeOp1): self.dtype = np.int64 +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFullLikeOp4(unittest.TestCase): + def test_skip_data_transform(self): + paddle.disable_static() + with _test_eager_guard(): + x = paddle.to_tensor( + [1., 2., 3., 4.], place=paddle.CUDAPinnedPlace()) + out = paddle.full_like(x, 1.) + self.assertTrue( + (out.numpy() == np.ones([4]).astype(np.float32)).all(), True) + paddle.enable_static() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 0808eb6c69..4779750b5b 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -448,7 +448,7 @@ - api : deformable_conv args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) output : Tensor(out) - infer_meta : + infer_meta : func : DeformableConvInferMeta kernel : func : deformable_conv @@ -781,6 +781,8 @@ param : [x, value, dtype] data_type : dtype > x backend : place > x + data_transform : + skip_transform : x - api : gather args : (Tensor x, Tensor index, Scalar axis=0) -- GitLab