diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 4644420d4227efcf9d4630fd15bc0f876c4d75e9..d21232ed82296cb48af5c72a32264e5c8fd76085 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -83,6 +83,16 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { return outputs_.at(idx).get(); } +std::vector InferMetaContext::MutableOutputBetween(size_t start, + size_t end) { + std::vector result; + result.reserve(end - start); + for (size_t i = start; i < end; ++i) { + result.emplace_back(*outputs_.at(i)); + } + return result; +} + MetaFnFactory& MetaFnFactory::Instance() { static MetaFnFactory g_meta_fn_map; return g_meta_fn_map; diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index ad1f1b484885fc560221ab1099cdbc764c068d78..2b98ab22bcdbd43a1863c2d59d93e31c510368b8 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -52,6 +52,7 @@ class InferMetaContext { const MetaTensor& InputAt(size_t idx) const; std::vector InputsBetween(size_t start, size_t end) const; MetaTensor* MutableOutputAt(size_t idx); + std::vector MutableOutputBetween(size_t start, size_t end); template AttrType AttrAt(size_t idx) { @@ -186,7 +187,20 @@ struct InferMetaFnImpl { } }; - // TODO(chenweihang): support vector output later + template + struct InferMetaFnCallHelper*, Tail...> { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + const std::pair range = ctx->OutputRangeAt(out_idx); + std::vector tmp = + ctx->MutableOutputBetween(range.first, range.second); + std::vector* arg = &tmp; + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; template struct InferMetaFnCallHelper { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a1d2c147d288bf42eece32ec7a16d77edd0072eb..2f01174dff9b34c56f3c59d861ca25d0ffbbc4f5 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -79,6 +79,13 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { out->set_layout(x.layout()); } +void CopyToInferMeta(const MetaTensor& x, + Backend backend, + bool blocking, + MetaTensor* out) { + UnchangedInferMeta(x, out); +} + void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); @@ -497,3 +504,6 @@ void TraceInferMeta( } } // namespace phi + +PT_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); +PT_REGISTER_INFER_META_FN(split, phi::SplitInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 83c2075b0b988b36a5b25a31e8b80cef85580de3..560ce0d2d4c489fb4537b426c6ca45a1407a2853 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -41,6 +41,11 @@ void FlattenInferMeta(const MetaTensor& x, void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); +void CopyToInferMeta(const MetaTensor& x, + Backend backend, + bool blocking, + MetaTensor* out); + void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); void InferMetaFromVecValue(const MetaTensor& x, diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc b/paddle/phi/tests/core/test_meta_fn_utils.cc index 97d91f952f9addc127092d2f85e4c2e2f6e9cde4..f4288c2aa2f9418eeff489aa53fe685aa4a155ec 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -60,5 +60,62 @@ TEST(MetaFnFactory, InferMetaFnExists) { EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); } +TEST(MetaFnFactory, CopyInferMetaFn) { + phi::DenseTensor dense_x; + dense_x.Resize({3, 4}); + + phi::MetaTensor meta_x(&dense_x); + phi::DenseTensor dense_out1; + phi::MetaTensor meta_out(&dense_out1); + phi::UnchangedInferMeta(meta_x, &meta_out); + + auto shared_meat_x = std::make_shared(&dense_x); + phi::DenseTensor dense_out2; + auto shared_meta_out = std::make_shared(&dense_out2); + + phi::InferMetaContext ctx; + ctx.EmplaceBackInput(shared_meat_x); + ctx.EmplaceBackAttr(Backend::CPU); + ctx.EmplaceBackAttr(false); + ctx.EmplaceBackOutput(shared_meta_out); + ctx.SetMetaConfig(/*is_runtime=*/true); + phi::MetaFnFactory::Instance().Get("copy_to")(&ctx); + + EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size()); + EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]); + EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); +} + +TEST(MetaFnFactory, SplitInferMetaFn) { + phi::DenseTensor dense_x; + dense_x.Resize({4, 10}); + phi::MetaTensor meta_x(&dense_x); + auto shared_meat_x = std::make_shared(&dense_x); + + phi::DenseTensor dense_out1; + phi::DenseTensor dense_out2; + paddle::SmallVector> out; + out.push_back(std::make_shared(&dense_out1)); + out.push_back(std::make_shared(&dense_out2)); + + phi::InferMetaContext ctx; + ctx.EmplaceBackInput(shared_meat_x); + ScalarArray num_or_sections{2, 2}; + Scalar axis{0}; + ctx.EmplaceBackAttr(num_or_sections); + ctx.EmplaceBackAttr(axis); + ctx.EmplaceBackOutputs(out); + ctx.SetMetaConfig(/*is_runtime=*/true); + phi::MetaFnFactory::Instance().Get("split")(&ctx); + + ASSERT_EQ(dense_out1.dims().size(), 2); + ASSERT_EQ(dense_out1.dims()[0], 2); + ASSERT_EQ(dense_out1.dims()[1], 10); + + ASSERT_EQ(dense_out2.dims().size(), 2); + ASSERT_EQ(dense_out2.dims()[0], 2); + ASSERT_EQ(dense_out2.dims()[1], 10); +} + } // namespace tests } // namespace phi