From 420d186a38f9c967b478203e9c825500183c3945 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 7 Sep 2022 17:47:00 +0800 Subject: [PATCH] [Phi] Fix infermeta bug for vector input and output (#45810) * fix infermeta bug for vector input and output * add unittest --- paddle/fluid/framework/infershape_utils.cc | 3 +++ paddle/phi/core/infermeta_utils.cc | 13 ++++++++++++- paddle/phi/tests/core/test_meta_fn_utils.cc | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 9ea3f05d64f..f59bb2503a5 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -459,6 +459,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, infer_meta_context.EmplaceBackInputs(std::move(inputs)); } } else { + // Note: Because the input of InferMetaFn is const MetaTensor&, + // so when we prepare input MetaTensor by InferMetaContext->InputAt(), + // we need to return a const reference of empty MetaTensor infer_meta_context.EmplaceBackInput( std::move(CompatMetaTensor(ctx->IsRuntime()))); } diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 9d2b85435c7..20fc773a2f1 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -68,8 +68,13 @@ const MetaTensor& InferMetaContext::InputAt(size_t idx) const { std::vector InferMetaContext::InputsBetween( size_t start, size_t end) const { std::vector result; - result.reserve(end - start); + // If vector only contains one input that is not initialized, + // we should return a empty vector + if (end - start == 1 && !inputs_.at(start).initialized()) { + return result; + } + result.reserve(end - start); for (size_t i = start; i < end; ++i) { auto& in = inputs_.at(i); result.emplace_back(in.initialized() ? &in : nullptr); @@ -104,6 +109,12 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { std::vector InferMetaContext::MutableOutputBetween(size_t start, size_t end) { std::vector result; + // If vector only contains one output that is not initialized, + // we should return a empty vector + if (end - start == 1 && !outputs_.at(start).initialized()) { + return result; + } + result.reserve(end - start); for (size_t i = start; i < end; ++i) { auto& out = outputs_.at(i); diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc b/paddle/phi/tests/core/test_meta_fn_utils.cc index afdd3bc0d9a..f3148d95a6a 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -91,5 +91,19 @@ TEST(MetaFnFactory, SplitInferMetaFn) { ASSERT_EQ(dense_out2.dims()[1], 10); } +void TestEmptyVectorInputInferMeta(const std::vector& inputs, + std::vector outputs) { + ASSERT_EQ(inputs.size(), 0UL); + ASSERT_EQ(outputs.size(), 0UL); +} + +TEST(MetaFnFactory, EmptyVectorInputInferMetaFn) { + phi::InferMetaContext ctx; + ctx.EmplaceBackInput(MetaTensor()); + ctx.EmplaceBackOutput(MetaTensor()); + + PD_INFER_META(TestEmptyVectorInputInferMeta)(&ctx); +} + } // namespace tests } // namespace phi -- GitLab