未验证 提交 5fb9cf60 编写于 作者: C Chen Weihang 提交者: GitHub

support set fp32 input for fp16 kernel (#39625)

上级 d63ece1f
......@@ -184,7 +184,7 @@ struct KernelRegistrar {
KernelKey kernel_key(backend, layout, dtype);
Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel);
args_def_fn(kernel_key, &kernel);
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
}
};
......@@ -231,7 +231,7 @@ struct KernelRegistrar {
kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
......@@ -240,7 +240,7 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#else
/**
* `template decltype(fn) fn` can work on gcc and clang,
......@@ -257,7 +257,7 @@ struct KernelRegistrar {
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \
kernel_name, \
backend, \
......@@ -266,7 +266,7 @@ struct KernelRegistrar {
meta_kernel_fn, \
__VA_ARGS__)); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \
......@@ -786,7 +786,7 @@ struct KernelRegistrar {
kernel_name, backend, layout, kernel_fn, dtype) \
template decltype(kernel_fn) kernel_fn; \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
......@@ -800,12 +800,12 @@ struct KernelRegistrar {
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#else
#define _PT_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel); \
static const ::pten::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \
#kernel_name, \
......@@ -819,7 +819,7 @@ struct KernelRegistrar {
return 0; \
} \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
const ::pten::KernelKey& kernel_key, ::pten::Kernel* kernel)
#endif
/** PT_DECLARE_KERNEL
......
......@@ -27,7 +27,7 @@ class ArgumentMappingContext;
class InferMetaContext;
using KernelFn = std::function<void(KernelContext* ctx)>;
using KernelArgsDefFn = void (*)(Kernel* kernel);
using KernelArgsDefFn = void (*)(const KernelKey& kernel_key, Kernel* kernel);
using KernelArgsParseFn = void (*)(const KernelKey& default_key,
KernelArgsDef* args_def);
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#include <iostream>
#include <sstream>
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -47,5 +49,42 @@ TEST(KernelFactory, SelectedKernelMap) {
}
}
template <typename T, typename Context>
void TestKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& param,
DenseTensor* out) {}
TEST(KernelRegistry, SetFP32Input) {
pten::KernelKey kernel_key(pten::Backend::CPU,
pten::DataLayout::ALL_LAYOUT,
pten::DataType::FLOAT16);
auto test_kernel =
pten::KernelFactory::Instance().SelectKernel("test", kernel_key);
EXPECT_TRUE(test_kernel.IsValid());
auto& arg_defs = test_kernel.args_def();
auto& input_defs = arg_defs.input_defs();
auto& attr_defs = arg_defs.attribute_defs();
auto& output_defs = arg_defs.output_defs();
EXPECT_EQ(input_defs.size(), 2UL);
EXPECT_EQ(attr_defs.size(), 0UL);
EXPECT_EQ(output_defs.size(), 1UL);
EXPECT_EQ(input_defs.at(0).dtype, pten::DataType::FLOAT16);
EXPECT_EQ(input_defs.at(1).dtype, pten::DataType::FLOAT32);
EXPECT_EQ(output_defs.at(0).dtype, pten::DataType::FLOAT16);
}
} // namespace tests
} // namespace pten
PT_REGISTER_KERNEL(test,
CPU,
ALL_LAYOUT,
pten::tests::TestKernel,
float,
double,
pten::dtype::float16) {
if (kernel_key.dtype() == pten::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(pten::DataType::FLOAT32);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册