未验证 提交 f3f16126 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add infermeta registry (#39204)

* add infermeta registry

* add infermeta registry

* add unittest

* polish details
上级 56410b4a
......@@ -70,4 +70,9 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
}
MetaFunctionMap& MetaFunctionMap::Instance() {
static MetaFunctionMap g_meta_fn_map;
return g_meta_fn_map;
}
} // namespace pten
......@@ -17,7 +17,11 @@ limitations under the License. */
#include <string>
#include <utility>
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/macros.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
namespace pten {
......@@ -59,7 +63,7 @@ class InferMetaContext {
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast&) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
PADDLE_THROW(pten::errors::InvalidArgument(
"Attribute cast error in InferMeta Context."));
}
}
......@@ -167,4 +171,73 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};
};
class MetaFunctionMap {
public:
static MetaFunctionMap& Instance();
bool Contains(const std::string& kernel_name_prefix) const {
return meta_fn_map_.count(kernel_name_prefix) > 0;
}
void Insert(std::string kernel_name_prefix, InferMetaFn infer_meta_fn) {
PADDLE_ENFORCE_NE(
Contains(kernel_name_prefix),
true,
pten::errors::AlreadyExists(
"`%s`'s Series Kernel's InferMetaFn has been registered.",
kernel_name_prefix));
meta_fn_map_.insert(
{std::move(kernel_name_prefix), std::move(infer_meta_fn)});
}
const InferMetaFn& Get(const std::string& kernel_name_prefix) const {
auto it = meta_fn_map_.find(kernel_name_prefix);
PADDLE_ENFORCE_NE(
it,
meta_fn_map_.end(),
pten::errors::NotFound(
"`%s`'s Series Kernel's InferMetaFn is not registered.",
kernel_name_prefix));
return it->second;
}
private:
MetaFunctionMap() = default;
/**
* [ Why use kernel name prefix? ]
*
* one op -> a matrix of kernels
*
* such as, scale op, it may correspond to the following kernels:
*
* - scale, scale_sr, scale_dnnl
* - scale_raw, scale_raw_sr, scale_raw_dnnl
*
* All the kernels in each row correspond to the same infershape function,
* the number of kernel arguments in the same row is the same, and only
* the tensor types in the arguments are different.
*/
paddle::flat_hash_map<std::string, InferMetaFn> meta_fn_map_;
DISABLE_COPY_AND_ASSIGN(MetaFunctionMap);
};
struct InferMetaFnRegistrar {
InferMetaFnRegistrar(const char* kernel_name_prefix,
InferMetaFn infer_meta_fn) {
MetaFunctionMap::Instance().Insert(kernel_name_prefix,
std::move(infer_meta_fn));
}
};
#define PT_REGISTER_INFER_META_FN(kernel_name_prefix, variadic_infer_meta_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_infer_meta_fn_ns_check_##kernel_name_prefix, \
"PT_REGISTER_INFER_META_FN must be called in global namespace."); \
static const ::pten::InferMetaFnRegistrar \
__registrar_arg_map_fn_for_##kernel_name_prefix( \
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \
int TouchInferMetaFnSymbol_##op_type() { return 0; }
} // namespace pten
......@@ -305,3 +305,5 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
}
} // namespace pten
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMetaNew);
......@@ -7,6 +7,7 @@ cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor
cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor)
cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos)
cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor infermeta infermeta_utils)
cc_test(test_ddim SRCS test_ddim.cc DEPS ddim)
if(WITH_GPU)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h"
namespace pten {
namespace tests {
TEST(MetaFunctionMap, InferMetaFnExists) {
pten::DenseTensor dense_x;
dense_x.Resize(pten::framework::make_ddim({3, 4}));
pten::MetaTensor meta_x(&dense_x);
pten::DenseTensor dense_out1;
pten::MetaTensor meta_out(&dense_out1);
pten::UnchangedInferMetaNew(/*is_runtime=*/true, meta_x, &meta_out);
auto shared_meat_x = std::make_shared<pten::MetaTensor>(&dense_x);
pten::DenseTensor dense_out2;
auto shared_meta_out = std::make_shared<pten::MetaTensor>(&dense_out2);
pten::InferMetaContext ctx;
ctx.EmplaceBackInput(shared_meat_x);
ctx.EmplaceBackOutput(shared_meta_out);
ctx.SetMetaConfig(/*is_runtime=*/true);
pten::MetaFunctionMap::Instance().Get("sign")(&ctx);
EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册