From e16ab42b9276679e86c3b3fa48e03e354b11e388 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 21 Feb 2022 10:07:38 +0800 Subject: [PATCH] [Pten] Add copy_to wrapped infermeta (#39703) * add copy_to wrapped infermeta * test=allcases * test=allcases * test=allcases --- paddle/phi/core/infermeta_utils.cc | 10 ++++ paddle/phi/core/infermeta_utils.h | 16 +++++- paddle/phi/infermeta/unary.cc | 10 ++++ paddle/phi/infermeta/unary.h | 5 ++ paddle/phi/tests/core/test_meta_fn_utils.cc | 57 +++++++++++++++++++++ 5 files changed, 97 insertions(+), 1 deletion(-) diff --git a/paddle/phi/core/infermeta_utils.cc b/paddle/phi/core/infermeta_utils.cc index 4644420d42..d21232ed82 100644 --- a/paddle/phi/core/infermeta_utils.cc +++ b/paddle/phi/core/infermeta_utils.cc @@ -83,6 +83,16 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { return outputs_.at(idx).get(); } +std::vector InferMetaContext::MutableOutputBetween(size_t start, + size_t end) { + std::vector result; + result.reserve(end - start); + for (size_t i = start; i < end; ++i) { + result.emplace_back(*outputs_.at(i)); + } + return result; +} + MetaFnFactory& MetaFnFactory::Instance() { static MetaFnFactory g_meta_fn_map; return g_meta_fn_map; diff --git a/paddle/phi/core/infermeta_utils.h b/paddle/phi/core/infermeta_utils.h index ad1f1b4848..2b98ab22bc 100644 --- a/paddle/phi/core/infermeta_utils.h +++ b/paddle/phi/core/infermeta_utils.h @@ -52,6 +52,7 @@ class InferMetaContext { const MetaTensor& InputAt(size_t idx) const; std::vector InputsBetween(size_t start, size_t end) const; MetaTensor* MutableOutputAt(size_t idx); + std::vector MutableOutputBetween(size_t start, size_t end); template AttrType AttrAt(size_t idx) { @@ -186,7 +187,20 @@ struct InferMetaFnImpl { } }; - // TODO(chenweihang): support vector output later + template + struct InferMetaFnCallHelper*, Tail...> { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + const std::pair range = ctx->OutputRangeAt(out_idx); + std::vector tmp = + ctx->MutableOutputBetween(range.first, range.second); + std::vector* arg = &tmp; + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; template struct InferMetaFnCallHelper { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index a1d2c147d2..2f01174dff 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -79,6 +79,13 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) { out->set_layout(x.layout()); } +void CopyToInferMeta(const MetaTensor& x, + Backend backend, + bool blocking, + MetaTensor* out) { + UnchangedInferMeta(x, out); +} + void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype); @@ -497,3 +504,6 @@ void TraceInferMeta( } } // namespace phi + +PT_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); +PT_REGISTER_INFER_META_FN(split, phi::SplitInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 83c2075b0b..560ce0d2d4 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -41,6 +41,11 @@ void FlattenInferMeta(const MetaTensor& x, void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); +void CopyToInferMeta(const MetaTensor& x, + Backend backend, + bool blocking, + MetaTensor* out); + void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out); void InferMetaFromVecValue(const MetaTensor& x, diff --git a/paddle/phi/tests/core/test_meta_fn_utils.cc b/paddle/phi/tests/core/test_meta_fn_utils.cc index 97d91f952f..f4288c2aa2 100644 --- a/paddle/phi/tests/core/test_meta_fn_utils.cc +++ b/paddle/phi/tests/core/test_meta_fn_utils.cc @@ -60,5 +60,62 @@ TEST(MetaFnFactory, InferMetaFnExists) { EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); } +TEST(MetaFnFactory, CopyInferMetaFn) { + phi::DenseTensor dense_x; + dense_x.Resize({3, 4}); + + phi::MetaTensor meta_x(&dense_x); + phi::DenseTensor dense_out1; + phi::MetaTensor meta_out(&dense_out1); + phi::UnchangedInferMeta(meta_x, &meta_out); + + auto shared_meat_x = std::make_shared(&dense_x); + phi::DenseTensor dense_out2; + auto shared_meta_out = std::make_shared(&dense_out2); + + phi::InferMetaContext ctx; + ctx.EmplaceBackInput(shared_meat_x); + ctx.EmplaceBackAttr(Backend::CPU); + ctx.EmplaceBackAttr(false); + ctx.EmplaceBackOutput(shared_meta_out); + ctx.SetMetaConfig(/*is_runtime=*/true); + phi::MetaFnFactory::Instance().Get("copy_to")(&ctx); + + EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size()); + EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]); + EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); +} + +TEST(MetaFnFactory, SplitInferMetaFn) { + phi::DenseTensor dense_x; + dense_x.Resize({4, 10}); + phi::MetaTensor meta_x(&dense_x); + auto shared_meat_x = std::make_shared(&dense_x); + + phi::DenseTensor dense_out1; + phi::DenseTensor dense_out2; + paddle::SmallVector> out; + out.push_back(std::make_shared(&dense_out1)); + out.push_back(std::make_shared(&dense_out2)); + + phi::InferMetaContext ctx; + ctx.EmplaceBackInput(shared_meat_x); + ScalarArray num_or_sections{2, 2}; + Scalar axis{0}; + ctx.EmplaceBackAttr(num_or_sections); + ctx.EmplaceBackAttr(axis); + ctx.EmplaceBackOutputs(out); + ctx.SetMetaConfig(/*is_runtime=*/true); + phi::MetaFnFactory::Instance().Get("split")(&ctx); + + ASSERT_EQ(dense_out1.dims().size(), 2); + ASSERT_EQ(dense_out1.dims()[0], 2); + ASSERT_EQ(dense_out1.dims()[1], 10); + + ASSERT_EQ(dense_out2.dims().size(), 2); + ASSERT_EQ(dense_out2.dims()[0], 2); + ASSERT_EQ(dense_out2.dims()[1], 10); +} + } // namespace tests } // namespace phi -- GitLab