Unverified commit e16ab42b, authored by zyfncg, committed by GitHub

[Pten] Add copy_to wrapped infermeta (#39703)

* add copy_to wrapped infermeta

* test=allcases

* test=allcases

* test=allcases
Parent commit: dcfe1986
@@ -83,6 +83,16 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
  return outputs_.at(idx).get();
}
std::vector<MetaTensor> InferMetaContext::MutableOutputBetween(size_t start,
                                                               size_t end) {
  std::vector<MetaTensor> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    result.emplace_back(*outputs_.at(i));
  }
  return result;
}
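For orientation, a minimal sketch (not part of this commit; the range index and shape are made up) of how this accessor pairs with OutputRangeAt to hand a multi-output infermeta a writable slice of the context's outputs. Because MetaTensor is a thin wrapper over the underlying tensor, the copies in the returned vector still write through to the real outputs:

// Hypothetical caller: fetch the MetaTensor handles for the outputs
// registered at positions [range.first, range.second) and set their dims.
// Assumes phi::make_ddim is available, as elsewhere in phi.
const std::pair<int, int> range = ctx.OutputRangeAt(/*out_idx=*/0);
std::vector<phi::MetaTensor> outs =
    ctx.MutableOutputBetween(range.first, range.second);
for (auto& out : outs) {
  out.set_dims(phi::make_ddim({2, 10}));  // illustrative shape only
}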
MetaFnFactory& MetaFnFactory::Instance() {
  static MetaFnFactory g_meta_fn_map;
  return g_meta_fn_map;
...
@@ -52,6 +52,7 @@ class InferMetaContext {
  const MetaTensor& InputAt(size_t idx) const;
  std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
  MetaTensor* MutableOutputAt(size_t idx);
  std::vector<MetaTensor> MutableOutputBetween(size_t start, size_t end);
  template <typename AttrType>
  AttrType AttrAt(size_t idx) {
@@ -186,7 +187,20 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
    }
  };

  template <typename... Tail>
  struct InferMetaFnCallHelper<std::vector<MetaTensor>*, Tail...> {
    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
    static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
      const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
      std::vector<MetaTensor> tmp =
          ctx->MutableOutputBetween(range.first, range.second);
      std::vector<MetaTensor>* arg = &tmp;
      InferMetaFnCallHelper<Tail...>::template Call<in_idx, attr_idx, out_idx + 1>(
          ctx, pargs..., arg);
    }
  };
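This specialization follows phi's usual call-helper pattern: it resolves the out_idx-th output range, materializes a std::vector<MetaTensor> from it, passes the vector's address as the next argument, and recurses with out_idx advanced. As an aside (not from the diff), here is a stripped-down, self-contained toy of the same compile-time unpacking; the names Ctx, FnImpl, and TypeTag are hypothetical, not phi APIs:

// Toy model of the call-helper recursion (hypothetical names, heavily
// simplified): each partial specialization materializes one argument from
// the context and recurses; the TypeTag case fires once every parameter
// has been materialized and finally invokes the wrapped function.
#include <iostream>
#include <vector>

struct Ctx {
  std::vector<int> attrs{7, 8};
};

template <typename T>
struct TypeTag {};

template <typename Fn, Fn fn>
struct FnImpl;

template <typename... Args, void (*fn)(Args...)>
struct FnImpl<void (*)(Args...), fn> {
  template <typename... Tail>
  struct Helper;

  // One specialization per supported parameter type; here only `int`.
  template <typename... Tail>
  struct Helper<int, Tail...> {
    template <int attr_idx, typename... Prev>
    static void Call(Ctx* ctx, Prev&... prev) {
      int arg = ctx->attrs[attr_idx];  // materialize the next argument
      Helper<Tail...>::template Call<attr_idx + 1>(ctx, prev..., arg);
    }
  };

  // Terminal case: all arguments collected, invoke the wrapped function.
  template <typename T>
  struct Helper<TypeTag<T>> {
    template <int attr_idx, typename... Prev>
    static void Call(Ctx* ctx, Prev&... prev) {
      fn(prev...);
    }
  };

  static void Run(Ctx* ctx) {
    Helper<Args..., TypeTag<int>>::template Call<0>(ctx);
  }
};

void AddAndPrint(int a, int b) { std::cout << a + b << std::endl; }

int main() {
  Ctx ctx;
  FnImpl<decltype(&AddAndPrint), &AddAndPrint>::Run(&ctx);  // prints 15
  return 0;
}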
  template <typename... Tail>
  struct InferMetaFnCallHelper<MetaConfig, Tail...> {
...
@@ -79,6 +79,13 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
  out->set_layout(x.layout());
}
void CopyToInferMeta(const MetaTensor& x,
                     Backend backend,
                     bool blocking,
                     MetaTensor* out) {
  UnchangedInferMeta(x, out);
}
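Aside (not part of the diff): copy_to moves data between backends without touching shape, dtype, or layout, so its infermeta simply forwards to UnchangedInferMeta. A sketch of what that forwarding amounts to, assuming UnchangedInferMeta mirrors the input's meta fields onto the output:

// Assumed behavior of UnchangedInferMeta (simplified sketch, not the
// verbatim phi implementation):
void UnchangedInferMetaSketch(const phi::MetaTensor& x, phi::MetaTensor* out) {
  out->set_dims(x.dims());      // same shape as the input
  out->set_dtype(x.dtype());    // same element type
  out->set_layout(x.layout());  // same memory layout
}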
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
  out->set_dims(x.dims());
  out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
@@ -497,3 +504,6 @@ void TraceInferMeta(
}
}  // namespace phi
PT_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
PT_REGISTER_INFER_META_FN(split, phi::SplitInferMeta);
@@ -41,6 +41,11 @@ void FlattenInferMeta(const MetaTensor& x,
void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CopyToInferMeta(const MetaTensor& x,
                     Backend backend,
                     bool blocking,
                     MetaTensor* out);
void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
void InferMetaFromVecValue(const MetaTensor& x,
...
@@ -60,5 +60,62 @@ TEST(MetaFnFactory, InferMetaFnExists) {
  EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}
TEST(MetaFnFactory, CopyInferMetaFn) {
  phi::DenseTensor dense_x;
  dense_x.Resize({3, 4});
  phi::MetaTensor meta_x(&dense_x);
  phi::DenseTensor dense_out1;
  phi::MetaTensor meta_out(&dense_out1);
  phi::UnchangedInferMeta(meta_x, &meta_out);

  auto shared_meta_x = std::make_shared<phi::MetaTensor>(&dense_x);
  phi::DenseTensor dense_out2;
  auto shared_meta_out = std::make_shared<phi::MetaTensor>(&dense_out2);

  phi::InferMetaContext ctx;
  ctx.EmplaceBackInput(shared_meta_x);
  ctx.EmplaceBackAttr(Backend::CPU);
  ctx.EmplaceBackAttr(false);
  ctx.EmplaceBackOutput(shared_meta_out);
  ctx.SetMetaConfig(/*is_runtime=*/true);
  phi::MetaFnFactory::Instance().Get("copy_to")(&ctx);

  EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size());
  EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]);
  EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}
TEST(MetaFnFactory, SplitInferMetaFn) {
  phi::DenseTensor dense_x;
  dense_x.Resize({4, 10});
  phi::MetaTensor meta_x(&dense_x);
  auto shared_meta_x = std::make_shared<phi::MetaTensor>(&dense_x);

  phi::DenseTensor dense_out1;
  phi::DenseTensor dense_out2;
  paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> out;
  out.push_back(std::make_shared<phi::MetaTensor>(&dense_out1));
  out.push_back(std::make_shared<phi::MetaTensor>(&dense_out2));

  phi::InferMetaContext ctx;
  ctx.EmplaceBackInput(shared_meta_x);
  ScalarArray num_or_sections{2, 2};
  Scalar axis{0};
  ctx.EmplaceBackAttr(num_or_sections);
  ctx.EmplaceBackAttr(axis);
  ctx.EmplaceBackOutputs(out);
  ctx.SetMetaConfig(/*is_runtime=*/true);
  phi::MetaFnFactory::Instance().Get("split")(&ctx);

  ASSERT_EQ(dense_out1.dims().size(), 2);
  ASSERT_EQ(dense_out1.dims()[0], 2);
  ASSERT_EQ(dense_out1.dims()[1], 10);
  ASSERT_EQ(dense_out2.dims().size(), 2);
  ASSERT_EQ(dense_out2.dims()[0], 2);
  ASSERT_EQ(dense_out2.dims()[1], 10);
}
}  // namespace tests
}  // namespace phi