diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index 927c36e9e8f432f2d7f068c6f76142a2509faced..df95f50c98106b3bd6a2fc65b06d3bc96dbb547b 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -184,7 +184,7 @@ struct KernelRegistrar { KernelKey kernel_key(backend, layout, dtype); Kernel kernel(kernel_fn, variadic_kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); - args_def_fn(&kernel); + args_def_fn(kernel_key, &kernel); KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; } }; @@ -231,7 +231,7 @@ struct KernelRegistrar { kernel_name, backend, layout, meta_kernel_fn, ...) \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \ PT_KERNEL_REGISTRAR_INIT( \ kernel_name, \ backend, \ @@ -240,7 +240,7 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__); \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel) #else /** * `template decltype(fn) fn` can work on gcc and clang, @@ -257,7 +257,7 @@ struct KernelRegistrar { #define _PT_REGISTER_2TA_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, ...) \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \ PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \ kernel_name, \ backend, \ @@ -266,7 +266,7 @@ struct KernelRegistrar { meta_kernel_fn, \ __VA_ARGS__)); \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel) #endif #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \ @@ -786,7 +786,7 @@ struct KernelRegistrar { kernel_name, backend, layout, kernel_fn, dtype) \ template decltype(kernel_fn) kernel_fn; \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \ static const ::pten::KernelRegistrar \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ #kernel_name, \ @@ -800,12 +800,12 @@ struct KernelRegistrar { return 0; \ } \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel) #else #define _PT_REGISTER_GENERAL_KERNEL( \ kernel_name, backend, layout, kernel_fn, dtype) \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel*); \ + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \ static const ::pten::KernelRegistrar \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ #kernel_name, \ @@ -819,7 +819,7 @@ struct KernelRegistrar { return 0; \ } \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::pten::Kernel* kernel) + const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel) #endif /** PT_DECLARE_KERNEL diff --git a/paddle/pten/core/type_defs.h b/paddle/pten/core/type_defs.h index 9b91720d86f1e92c12df0a7836bd51c67f7383b6..4ecc12fcdef0174ae6f3efa5de6d1a9996bc2f7c 100644 --- a/paddle/pten/core/type_defs.h +++ b/paddle/pten/core/type_defs.h @@ -27,7 +27,7 @@ class ArgumentMappingContext; class InferMetaContext; using KernelFn = std::function; -using KernelArgsDefFn = void (*)(Kernel* kernel); +using KernelArgsDefFn = void (*)(const KernelKey& kernel_key, Kernel* kernel); using KernelArgsParseFn = void (*)(const KernelKey& default_key, KernelArgsDef* args_def); diff --git a/paddle/pten/tests/core/test_kernel_factory.cc b/paddle/pten/tests/core/test_kernel_factory.cc index 5355921ddbe018b24cf874d6bb7b54da48c967f3..c9e8dffe56ff958054c812dc860ec517418a06de 100644 --- a/paddle/pten/tests/core/test_kernel_factory.cc +++ b/paddle/pten/tests/core/test_kernel_factory.cc @@ -15,6 +15,8 @@ limitations under the License. */ #include #include +#include "paddle/pten/common/float16.h" +#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_registry.h" @@ -47,5 +49,42 @@ TEST(KernelFactory, SelectedKernelMap) { } } +template +void TestKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& param, + DenseTensor* out) {} + +TEST(KernelRegistry, SetFP32Input) { + pten::KernelKey kernel_key(pten::Backend::CPU, + pten::DataLayout::ALL_LAYOUT, + pten::DataType::FLOAT16); + auto test_kernel = + pten::KernelFactory::Instance().SelectKernel("test", kernel_key); + EXPECT_TRUE(test_kernel.IsValid()); + auto& arg_defs = test_kernel.args_def(); + auto& input_defs = arg_defs.input_defs(); + auto& attr_defs = arg_defs.attribute_defs(); + auto& output_defs = arg_defs.output_defs(); + EXPECT_EQ(input_defs.size(), 2UL); + EXPECT_EQ(attr_defs.size(), 0UL); + EXPECT_EQ(output_defs.size(), 1UL); + EXPECT_EQ(input_defs.at(0).dtype, pten::DataType::FLOAT16); + EXPECT_EQ(input_defs.at(1).dtype, pten::DataType::FLOAT32); + EXPECT_EQ(output_defs.at(0).dtype, pten::DataType::FLOAT16); +} + } // namespace tests } // namespace pten + +PT_REGISTER_KERNEL(test, + CPU, + ALL_LAYOUT, + pten::tests::TestKernel, + float, + double, + pten::dtype::float16) { + if (kernel_key.dtype() == pten::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(pten::DataType::FLOAT32); + } +}