From 59e8382d2197ce036ce65885fcaa146ab5b21652 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Tue, 19 Apr 2022 11:37:16 +0800 Subject: [PATCH] [Phi]Fix expand_sig infershape BUG under static graph mode (#41936) * [Phi]Fix expand_sig infershape BUG under static graph mode * [Phi]Fix expand_sig infershape BUG under static graph mode * [Phi]Fix unittest * [Phi]Fix unittest --- paddle/phi/ops/compat/expand_sig.cc | 11 +++++++++++ .../fluid/tests/unittests/test_expand_v2_op.py | 12 ++++++++++++ 2 files changed, 23 insertions(+) diff --git a/paddle/phi/ops/compat/expand_sig.cc b/paddle/phi/ops/compat/expand_sig.cc index c3df1595a2..b0f4ff79b4 100644 --- a/paddle/phi/ops/compat/expand_sig.cc +++ b/paddle/phi/ops/compat/expand_sig.cc @@ -17,6 +17,11 @@ namespace phi { KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature("expand", {"X"}, {"shape"}, {"Out"}); + } if (ctx.HasInput("Shape")) { return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"}); } else if (ctx.InputSize("expand_shapes_tensor") > 0) { @@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature( + "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"}); + } if (ctx.HasInput("Shape")) { return KernelSignature( "expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"}); diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py index fd46b41c5f..592a635ddc 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase): assert np.array_equal(res_3, np.tile(input, (1, 1))) +class TestExpandInferShape(unittest.TestCase): + def test_shape_with_var(self): + with program_guard(Program(), Program()): + x = paddle.static.data(shape=[-1, 1, 3], name='x') + fake_var = paddle.randn([2, 3]) + target_shape = [ + -1, paddle.shape(fake_var)[0], paddle.shape(fake_var)[1] + ] + out = paddle.expand(x, shape=target_shape) + self.assertListEqual(list(out.shape), [-1, -1, -1]) + + if __name__ == "__main__": paddle.enable_static() unittest.main() -- GitLab