diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 9ea3f05d64f52a09ed8292e7046952265f18855f..f59bb2503a57093d9de7a00b4646ad6fa59743e1 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -459,6 +459,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, infer_meta_context.EmplaceBackInputs(std::move(inputs)); } } else { + // Note: Because the input of InferMetaFn is const MetaTensor&, + // so when we prepare input MetaTensor by InferMetaContext->InputAt(), + // we need to return a const reference of empty MetaTensor infer_meta_context.EmplaceBackInput( std::move(CompatMetaTensor(ctx->IsRuntime()))); } diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 9d2b85435c7f3915155599f67734c67a80a32737..20fc773a2f133716a06c8f99af67c06fd99d1a75 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -68,8 +68,13 @@ const MetaTensor& InferMetaContext::InputAt(size_t idx) const { std::vector InferMetaContext::InputsBetween( size_t start, size_t end) const { std::vector result; - result.reserve(end - start); + // If vector only contains one input that is not initialized, + // we should return a empty vector + if (end - start == 1 && !inputs_.at(start).initialized()) { + return result; + } + result.reserve(end - start); for (size_t i = start; i < end; ++i) { auto& in = inputs_.at(i); result.emplace_back(in.initialized() ? &in : nullptr); @@ -104,6 +109,12 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { std::vector InferMetaContext::MutableOutputBetween(size_t start, size_t end) { std::vector result; + // If vector only contains one output that is not initialized, + // we should return a empty vector + if (end - start == 1 && !outputs_.at(start).initialized()) { + return result; + } + result.reserve(end - start); for (size_t i = start; i < end; ++i) { auto& out = outputs_.at(i); diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc b/paddle/phi/tests/core/test_meta_fn_utils.cc index afdd3bc0d9ad043c6f04b6ec966da4cfa9b8ea8e..f3148d95a6afbe3d636a8c63fda22aedb4df96d1 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -91,5 +91,19 @@ TEST(MetaFnFactory, SplitInferMetaFn) { ASSERT_EQ(dense_out2.dims()[1], 10); } +void TestEmptyVectorInputInferMeta(const std::vector& inputs, + std::vector outputs) { + ASSERT_EQ(inputs.size(), 0UL); + ASSERT_EQ(outputs.size(), 0UL); +} + +TEST(MetaFnFactory, EmptyVectorInputInferMetaFn) { + phi::InferMetaContext ctx; + ctx.EmplaceBackInput(MetaTensor()); + ctx.EmplaceBackOutput(MetaTensor()); + + PD_INFER_META(TestEmptyVectorInputInferMeta)(&ctx); +} + } // namespace tests } // namespace phi