diff --git a/paddle/pten/core/infermeta_utils.cc b/paddle/pten/core/infermeta_utils.cc index 9f0037d18edf6a3256bae791459cb7b4bd3ab53e..3edd1eb87267457cf9fb2132119839b76ae62674 100644 --- a/paddle/pten/core/infermeta_utils.cc +++ b/paddle/pten/core/infermeta_utils.cc @@ -70,4 +70,9 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { return outputs_.at(idx).get(); } +MetaFunctionMap& MetaFunctionMap::Instance() { + static MetaFunctionMap g_meta_fn_map; + return g_meta_fn_map; +} + } // namespace pten diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h index bfc9d29e63709f7ad6eff498953027003c677edf..47f55f85ac2a3d2bf55d92ad57a71d98a837b433 100644 --- a/paddle/pten/core/infermeta_utils.h +++ b/paddle/pten/core/infermeta_utils.h @@ -17,7 +17,11 @@ limitations under the License. */ #include #include +#include "paddle/pten/core/enforce.h" +#include "paddle/pten/core/kernel_def.h" +#include "paddle/pten/core/macros.h" #include "paddle/pten/core/meta_tensor.h" +#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" namespace pten { @@ -59,7 +63,7 @@ class InferMetaContext { try { return paddle::any_cast(attrs_.at(idx)); } catch (paddle::bad_any_cast&) { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( + PADDLE_THROW(pten::errors::InvalidArgument( "Attribute cast error in InferMeta Context.")); } } @@ -167,4 +171,73 @@ struct InferMetaFnImpl { }; }; +class MetaFunctionMap { + public: + static MetaFunctionMap& Instance(); + + bool Contains(const std::string& kernel_name_prefix) const { + return meta_fn_map_.count(kernel_name_prefix) > 0; + } + + void Insert(std::string kernel_name_prefix, InferMetaFn infer_meta_fn) { + PADDLE_ENFORCE_NE( + Contains(kernel_name_prefix), + true, + pten::errors::AlreadyExists( + "`%s`'s Series Kernel's InferMetaFn has been registered.", + kernel_name_prefix)); + meta_fn_map_.insert( + {std::move(kernel_name_prefix), std::move(infer_meta_fn)}); + } + + const InferMetaFn& Get(const std::string& kernel_name_prefix) const { + auto it = meta_fn_map_.find(kernel_name_prefix); + PADDLE_ENFORCE_NE( + it, + meta_fn_map_.end(), + pten::errors::NotFound( + "`%s`'s Series Kernel's InferMetaFn is not registered.", + kernel_name_prefix)); + return it->second; + } + + private: + MetaFunctionMap() = default; + + /** + * [ Why use kernel name prefix? ] + * + * one op -> a matrix of kernels + * + * such as, scale op, it may correspond to the following kernels: + * + * - scale, scale_sr, scale_dnnl + * - scale_raw, scale_raw_sr, scale_raw_dnnl + * + * All the kernels in each row correspond to the same infershape function, + * the number of kernel arguments in the same row is the same, and only + * the tensor types in the arguments are different. + */ + paddle::flat_hash_map meta_fn_map_; + + DISABLE_COPY_AND_ASSIGN(MetaFunctionMap); +}; + +struct InferMetaFnRegistrar { + InferMetaFnRegistrar(const char* kernel_name_prefix, + InferMetaFn infer_meta_fn) { + MetaFunctionMap::Instance().Insert(kernel_name_prefix, + std::move(infer_meta_fn)); + } +}; + +#define PT_REGISTER_INFER_META_FN(kernel_name_prefix, variadic_infer_meta_fn) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_infer_meta_fn_ns_check_##kernel_name_prefix, \ + "PT_REGISTER_INFER_META_FN must be called in global namespace."); \ + static const ::pten::InferMetaFnRegistrar \ + __registrar_arg_map_fn_for_##kernel_name_prefix( \ + #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \ + int TouchInferMetaFnSymbol_##op_type() { return 0; } + } // namespace pten diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index fec50d528dfc42f357c006ef895549465a02f3e7..3f6b559f5604fa72758f609c14a2a2fda3e8ac88 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -305,3 +305,5 @@ DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, } } // namespace pten + +PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMetaNew); diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 60a0ca285412fe01e0b740c89fdd69f4e16ad3df..1404b9921f3dab2f3981cfdf7fcf5324dae1de69 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -7,6 +7,7 @@ cc_test(test_sparse_coo_tensor SRCS test_sparse_coo_tensor.cc DEPS dense_tensor cc_test(test_sparse_csr_tensor SRCS test_sparse_csr_tensor.cc DEPS dense_tensor sparse_csr_tensor) cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos) cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) +cc_test(test_meta_fn_utils SRCS test_meta_fn_utils.cc DEPS dense_tensor infermeta infermeta_utils) cc_test(test_ddim SRCS test_ddim.cc DEPS ddim) if(WITH_GPU) diff --git a/paddle/pten/tests/core/test_meta_fn_utils.cc b/paddle/pten/tests/core/test_meta_fn_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..e25fdd3a204dce31f730182928552e760958e181 --- /dev/null +++ b/paddle/pten/tests/core/test_meta_fn_utils.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "gtest/gtest.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/infermeta/unary.h" + +namespace pten { +namespace tests { + +TEST(MetaFunctionMap, InferMetaFnExists) { + pten::DenseTensor dense_x; + dense_x.Resize(pten::framework::make_ddim({3, 4})); + + pten::MetaTensor meta_x(&dense_x); + pten::DenseTensor dense_out1; + pten::MetaTensor meta_out(&dense_out1); + pten::UnchangedInferMetaNew(/*is_runtime=*/true, meta_x, &meta_out); + + auto shared_meat_x = std::make_shared(&dense_x); + pten::DenseTensor dense_out2; + auto shared_meta_out = std::make_shared(&dense_out2); + pten::InferMetaContext ctx; + ctx.EmplaceBackInput(shared_meat_x); + ctx.EmplaceBackOutput(shared_meta_out); + ctx.SetMetaConfig(/*is_runtime=*/true); + pten::MetaFunctionMap::Instance().Get("sign")(&ctx); + + EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size()); + EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]); + EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]); +} + +} // namespace tests +} // namespace pten