未验证 提交 5fb9cf60 编写于 作者: C Chen Weihang 提交者: GitHub

Support setting an fp32 input for an fp16 kernel (#39625)

上级 d63ece1f
...@@ -184,7 +184,7 @@ struct KernelRegistrar { ...@@ -184,7 +184,7 @@ struct KernelRegistrar {
KernelKey kernel_key(backend, layout, dtype); KernelKey kernel_key(backend, layout, dtype);
Kernel kernel(kernel_fn, variadic_kernel_fn); Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel); args_def_fn(kernel_key, &kernel);
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
} }
}; };
...@@ -231,7 +231,7 @@ struct KernelRegistrar { ...@@ -231,7 +231,7 @@ struct KernelRegistrar {
kernel_name, backend, layout, meta_kernel_fn, ...) \ kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
PT_KERNEL_REGISTRAR_INIT( \ PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \ kernel_name, \
backend, \ backend, \
...@@ -240,7 +240,7 @@ struct KernelRegistrar { ...@@ -240,7 +240,7 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__); \ __VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel) const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#else #else
/** /**
* `template decltype(fn) fn` can work on gcc and clang, * `template decltype(fn) fn` can work on gcc and clang,
...@@ -257,7 +257,7 @@ struct KernelRegistrar { ...@@ -257,7 +257,7 @@ struct KernelRegistrar {
#define _PT_REGISTER_2TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \ kernel_name, backend, layout, meta_kernel_fn, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \ PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \ kernel_name, \
backend, \ backend, \
...@@ -266,7 +266,7 @@ struct KernelRegistrar { ...@@ -266,7 +266,7 @@ struct KernelRegistrar {
meta_kernel_fn, \ meta_kernel_fn, \
__VA_ARGS__)); \ __VA_ARGS__)); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel) const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#endif #endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \ #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
...@@ -786,7 +786,7 @@ struct KernelRegistrar { ...@@ -786,7 +786,7 @@ struct KernelRegistrar {
kernel_name, backend, layout, kernel_fn, dtype) \ kernel_name, backend, layout, kernel_fn, dtype) \
template decltype(kernel_fn) kernel_fn; \ template decltype(kernel_fn) kernel_fn; \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
static const ::pten::KernelRegistrar \ static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \ #kernel_name, \
...@@ -800,12 +800,12 @@ struct KernelRegistrar { ...@@ -800,12 +800,12 @@ struct KernelRegistrar {
return 0; \ return 0; \
} \ } \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel) const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#else #else
#define _PT_REGISTER_GENERAL_KERNEL( \ #define _PT_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \ kernel_name, backend, layout, kernel_fn, dtype) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \ const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
static const ::pten::KernelRegistrar \ static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \ #kernel_name, \
...@@ -819,7 +819,7 @@ struct KernelRegistrar { ...@@ -819,7 +819,7 @@ struct KernelRegistrar {
return 0; \ return 0; \
} \ } \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel) const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#endif #endif
/** PT_DECLARE_KERNEL /** PT_DECLARE_KERNEL
......
...@@ -27,7 +27,7 @@ class ArgumentMappingContext; ...@@ -27,7 +27,7 @@ class ArgumentMappingContext;
class InferMetaContext; class InferMetaContext;
using KernelFn = std::function<void(KernelContext* ctx)>; using KernelFn = std::function<void(KernelContext* ctx)>;
using KernelArgsDefFn = void (*)(Kernel* kernel); using KernelArgsDefFn = void (*)(const KernelKey& kernel_key, Kernel* kernel);
using KernelArgsParseFn = void (*)(const KernelKey& default_key, using KernelArgsParseFn = void (*)(const KernelKey& default_key,
KernelArgsDef* args_def); KernelArgsDef* args_def);
......
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
...@@ -47,5 +49,42 @@ TEST(KernelFactory, SelectedKernelMap) { ...@@ -47,5 +49,42 @@ TEST(KernelFactory, SelectedKernelMap) {
} }
} }
// No-op kernel that exists only so the registration at the bottom of this file
// has a function to register. Its signature declares two DenseTensor inputs
// (`x`, `param`), one DenseTensor output (`out`) and no attributes, which is
// the argument layout the SetFP32Input test checks for.
template <typename T, typename Context>
void TestKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& param,
                DenseTensor* out) {
  // Intentionally empty: nothing is computed here.
}
// Checks that the args-def hook attached to the "test" kernel registration can
// override the data type of a single input: for the FLOAT16 kernel, input 1 is
// expected to be FLOAT32 while input 0 and the output stay FLOAT16.
TEST(KernelRegistry, SetFP32Input) {
  pten::KernelKey kernel_key(pten::Backend::CPU,
                             pten::DataLayout::ALL_LAYOUT,
                             pten::DataType::FLOAT16);
  auto test_kernel =
      pten::KernelFactory::Instance().SelectKernel("test", kernel_key);
  // Fatal check: none of the argument definitions below are meaningful if the
  // kernel lookup did not produce a valid kernel.
  ASSERT_TRUE(test_kernel.IsValid());

  auto& arg_defs = test_kernel.args_def();
  auto& input_defs = arg_defs.input_defs();
  auto& attr_defs = arg_defs.attribute_defs();
  auto& output_defs = arg_defs.output_defs();

  // The size checks guarding the element accesses below are fatal so that a
  // wrong count stops the test instead of indexing out of range.
  ASSERT_EQ(input_defs.size(), 2UL);
  EXPECT_EQ(attr_defs.size(), 0UL);
  ASSERT_EQ(output_defs.size(), 1UL);

  EXPECT_EQ(input_defs.at(0).dtype, pten::DataType::FLOAT16);
  EXPECT_EQ(input_defs.at(1).dtype, pten::DataType::FLOAT32);
  EXPECT_EQ(output_defs.at(0).dtype, pten::DataType::FLOAT16);
}
} // namespace tests } // namespace tests
} // namespace pten } // namespace pten
// Registers pten::tests::TestKernel under the name "test" for the CPU backend
// and ALL_LAYOUT, listing float, double and float16 as data types.
//
// The braces that follow form the body of the args-def hook; the registration
// macros give it the parameters `kernel_key` (const KernelKey&) and `kernel`
// (Kernel*). NOTE(review): presumably invoked once per registered kernel key
// -- confirm against KernelRegistrar.
PT_REGISTER_KERNEL(test,
                   CPU,
                   ALL_LAYOUT,
                   pten::tests::TestKernel,
                   float,
                   double,
                   pten::dtype::float16) {
  // Only the FLOAT16 kernel is customized; every other dtype keeps the
  // argument definitions it already has.
  const bool is_fp16_kernel = kernel_key.dtype() == pten::DataType::FLOAT16;
  if (!is_fp16_kernel) {
    return;
  }
  // Declare the second input (index 1, `param` in TestKernel) as FLOAT32.
  kernel->InputAt(1).SetDataType(pten::DataType::FLOAT32);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册