未验证 提交 420d186a 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Fix infermeta bug for vector input and output (#45810)

* fix infermeta bug for vector input and output

* add unittest
上级 ba653e7b
......@@ -459,6 +459,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackInputs(std::move(inputs));
}
} else {
// Note: Because the input of InferMetaFn is const MetaTensor&,
// so when we prepare input MetaTensor by InferMetaContext->InputAt(),
// we need to return a const reference of empty MetaTensor
infer_meta_context.EmplaceBackInput(
std::move(CompatMetaTensor(ctx->IsRuntime())));
}
......
......@@ -68,8 +68,13 @@ const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
std::vector<const MetaTensor*> InferMetaContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const MetaTensor*> result;
result.reserve(end - start);
// If vector only contains one input that is not initialized,
// we should return a empty vector
if (end - start == 1 && !inputs_.at(start).initialized()) {
return result;
}
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = inputs_.at(i);
result.emplace_back(in.initialized() ? &in : nullptr);
......@@ -104,6 +109,12 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
size_t end) {
std::vector<MetaTensor*> result;
// If vector only contains one output that is not initialized,
// we should return a empty vector
if (end - start == 1 && !outputs_.at(start).initialized()) {
return result;
}
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& out = outputs_.at(i);
......
......@@ -91,5 +91,19 @@ TEST(MetaFnFactory, SplitInferMetaFn) {
ASSERT_EQ(dense_out2.dims()[1], 10);
}
void TestEmptyVectorInputInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) {
ASSERT_EQ(inputs.size(), 0UL);
ASSERT_EQ(outputs.size(), 0UL);
}
TEST(MetaFnFactory, EmptyVectorInputInferMetaFn) {
phi::InferMetaContext ctx;
ctx.EmplaceBackInput(MetaTensor());
ctx.EmplaceBackOutput(MetaTensor());
PD_INFER_META(TestEmptyVectorInputInferMeta)(&ctx);
}
} // namespace tests
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册