diff --git a/paddle/phi/ops/compat/expand_sig.cc b/paddle/phi/ops/compat/expand_sig.cc index c3df1595a210824c7f39a9282fbd5a39922e0b0b..b0f4ff79b4c5c268b4d2aac840bcab9529a2e8e1 100644 --- a/paddle/phi/ops/compat/expand_sig.cc +++ b/paddle/phi/ops/compat/expand_sig.cc @@ -17,6 +17,11 @@ namespace phi { KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature("expand", {"X"}, {"shape"}, {"Out"}); + } if (ctx.HasInput("Shape")) { return KernelSignature("expand", {"X"}, {"Shape"}, {"Out"}); } else if (ctx.InputSize("expand_shapes_tensor") > 0) { @@ -27,6 +32,12 @@ KernelSignature ExpandOpArgumentMapping(const ArgumentMappingContext& ctx) { } KernelSignature ExpandGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + const auto& shape = paddle::any_cast>(ctx.Attr("shape")); + // Infer output shape by Attr("shape") in CompileTime if it is specified. + if (!ctx.IsRuntime() && !shape.empty()) { + return KernelSignature( + "expand_grad", {"X", "Out@GRAD"}, {"shape"}, {"X@GRAD"}); + } if (ctx.HasInput("Shape")) { return KernelSignature( "expand_grad", {"X", "Out@GRAD"}, {"Shape"}, {"X@GRAD"}); diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py index fd46b41c5f07e2b1481ba657451bd8545fc8478b..592a635ddcccc587cba766e00525fc9c8f3c6639 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -231,6 +231,18 @@ class TestExpandV2API(unittest.TestCase): assert np.array_equal(res_3, np.tile(input, (1, 1))) +class TestExpandInferShape(unittest.TestCase): + def test_shape_with_var(self): + with program_guard(Program(), Program()): + x = paddle.static.data(shape=[-1, 1, 3], name='x') + fake_var = paddle.randn([2, 3]) + target_shape = [ + -1, paddle.shape(fake_var)[0], paddle.shape(fake_var)[1] + ] + out = paddle.expand(x, shape=target_shape) + self.assertListEqual(list(out.shape), [-1, -1, -1]) + + if __name__ == "__main__": paddle.enable_static() unittest.main()