未验证 提交 af498677 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add register_ctx_kernel marco and move scale kernel (#38121)

* add register_ctx_kernel and move scale kernel

* polish details by reviewer comment

* fix xpu compile failed

* fix cmake error
上级 58b4bc72
......@@ -33,30 +33,30 @@ static void ScaleDeviceDispatch(const pten::DenseTensor& dense_tensor,
pten::DenseTensor* dense_out) {
switch (dense_tensor.dtype()) {
case pten::DataType::FLOAT64: {
pten::Scale<double>(dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */,
bias_after_scale /* bias_after_scale */,
pten::Scale<double, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
break;
}
case pten::DataType::FLOAT32: {
pten::Scale<float>(dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */,
pten::Scale<float, DeviceContext>(dev_ctx, dense_tensor /* tensor */,
scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
break;
}
case pten::DataType::INT64: {
pten::Scale<int64_t>(dev_ctx, dense_tensor /* tensor */,
scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */,
pten::Scale<int64_t, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
break;
}
case pten::DataType::INT32: {
pten::Scale<int32_t>(dev_ctx, dense_tensor /* tensor */,
scale /* scale */, bias /* bias */,
bias_after_scale /* bias_after_scale */,
pten::Scale<int32_t, DeviceContext>(
dev_ctx, dense_tensor /* tensor */, scale /* scale */,
bias /* bias */, bias_after_scale /* bias_after_scale */,
dense_out /* out tensor */);
break;
}
......
......@@ -20,7 +20,7 @@ limitations under the License. */
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
#include "paddle/pten/kernels/scale_kernel.h"
namespace paddle {
namespace operators {
......
......@@ -24,6 +24,7 @@ add_subdirectory(tests)
# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context)
set(PTEN_DEPS scale_kernel_eigen)
set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu)
set(PTEN_DEPS ${PTEN_DEPS} nary unary binary)
if(WITH_GPU OR WITH_ROCM)
......
......@@ -24,6 +24,7 @@ PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
......
......@@ -772,6 +772,558 @@ struct KernelRegistrar {
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
/** PT_REGISTER_CTX_KERNEL
*
* Used for kernel registration with device context and data type as
* template parameter.
*/
#define PT_REGISTER_CTX_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_ctx_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_CTX_KERNEL must be called in global namespace."); \
_PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, __VA_ARGS__)
#ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#else
#define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel*); \
PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, \
backend, \
layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__); \
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
::pten::Kernel* kernel)
#endif
#define PT_KERNEL_INSTANTIATION2(meta_kernel_fn, backend, cpp_dtype, ...) \
_PT_KERNEL_INSTANTIATION2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
meta_kernel_fn, \
backend, \
cpp_dtype, \
__VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2(N, meta_kernel_fn, backend, cpp_dtype, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION2_, N) \
(meta_kernel_fn, backend, cpp_dtype, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>
#define _PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_1(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_2(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_3(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_4(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_5(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_6(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_7(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_8(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_9(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_10(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_11(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_12(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_13(meta_kernel_fn, backend, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION2_15(meta_kernel_fn, backend, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>) \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION2_14(meta_kernel_fn, backend, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT2( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \
_PT_KERNEL_REGISTRAR_INIT2(PT_NARGS(cpp_dtype, __VA_ARGS__), \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format off
/* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT2(N, \
kernel_name, \
backend, \
layout, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT2_, N) ( \
kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
__VA_ARGS__)
// clang-format on
#define _PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_1(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_2(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_3(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_4(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_5(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_6(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_7(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_8(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_9(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_10(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_11(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_12(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_13(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT2_15(kernel_name, \
backend, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
cpp_dtype, \
...) \
static const ::pten::KernelRegistrar PT_CONCATENATE( \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
#kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::pten::backend##Context>), \
PT_VARIADIC_KERNEL( \
meta_kernel_fn<cpp_dtype, ::pten::backend##Context>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT2_14(kernel_name, \
backend, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__))
/** PT_DECLARE_KERNEL
*
* Used to export the symbols of the file where the kernel is located,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/cpu/math.h"
#include "paddle/pten/kernels/cuda/math.h"
#include "paddle/pten/kernels/scale_kernel.h"
namespace pten {
......@@ -86,7 +87,7 @@ DenseTensor Scale(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Scale<T>(dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
Scale<T, ContextT>(dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
return dense_out;
}
......
......@@ -17,7 +17,6 @@
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/kernels/hybird/cpu/elementwise.h"
#include "paddle/pten/kernels/hybird/eigen/reduce.h"
#include "paddle/pten/kernels/hybird/eigen/scale.h"
#include "paddle/pten/kernels/hybird/eigen/sign.h"
#include "paddle/pten/kernels/hybird/general/elementwise_functor.h"
#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
......@@ -47,17 +46,6 @@ void Mean(const CPUContext& dev_ctx,
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
template <typename T>
void Scale(const CPUContext& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
DenseTensor* out) {
eigen::Scale<CPUContext, T>(
dev_ctx, x, scale.to<float>(), bias, bias_after_scale, out);
}
template <typename T>
void Divide(const CPUContext& dev_ctx,
const DenseTensor& x,
......@@ -113,18 +101,6 @@ using complex128 = ::paddle::platform::complex<double>;
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ALL_LAYOUT,
pten::Scale,
float,
double,
paddle::platform::bfloat16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(add,
CPU,
ALL_LAYOUT,
......
......@@ -38,14 +38,6 @@ void Mean(const CPUContext& dev_ctx,
DataType out_dtype,
DenseTensor* out);
template <typename T>
void Scale(const CPUContext& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
DenseTensor* out);
template <typename T>
void Add(const CPUContext& dev_ctx,
const DenseTensor& x,
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
#include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
#include "paddle/pten/kernels/hybird/eigen/scale.h"
#include "paddle/pten/kernels/hybird/eigen/sign.h"
#include "paddle/pten/kernels/hybird/general/elementwise_functor.h"
#include "paddle/pten/kernels/hybird/general/reduce_impl.h"
......@@ -76,17 +75,6 @@ void Mean(const CUDAContext& dev_ctx,
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
template <typename T>
void Scale(const CUDAContext& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
DenseTensor* out) {
eigen::Scale<CUDAContext, T>(
dev_ctx, x, scale.to<float>(), bias, bias_after_scale, out);
}
// Create the definition of Add
DEFINE_CUDA_ELEMENTWISE_OP(Add)
// Create the definition of Subtract
......@@ -118,18 +106,6 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL(sign, CUDA, ALL_LAYOUT, pten::Sign, float, double, float16) {
}
PT_REGISTER_KERNEL(mean, CUDA, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CUDA,
ALL_LAYOUT,
pten::Scale,
float,
double,
float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(add,
CUDA,
ALL_LAYOUT,
......
if(WITH_GPU)
nv_library(scale_kernel_eigen SRCS scale_kernel.cc DEPS dense_tensor eigen_function)
elseif(WITH_ROCM)
hip_library(scale_kernel_eigen SRCS scale_kernel.cc DEPS dense_tensor eigen_function)
else()
cc_library(scale_kernel_eigen SRCS scale_kernel.cc DEPS dense_tensor eigen_function)
endif()
......@@ -12,21 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/kernels/scale_kernel.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/hybird/eigen/common.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
namespace pten {
namespace eigen {
template <typename DevCtx, typename T>
void Scale(const DevCtx& dev_ctx,
template <typename T, typename ContextT>
void Scale(const ContextT& dev_ctx,
const DenseTensor& x,
float scale,
const Scalar& scale,
float bias,
bool bias_after_scale,
DenseTensor* out) {
......@@ -42,10 +44,41 @@ void Scale(const DevCtx& dev_ctx,
dev,
eigen_out,
eigen_x,
static_cast<T>(scale),
scale.to<T>(),
static_cast<T>(bias),
bias_after_scale);
}
} // namespace eigen
} // namespace pten
// TODO(chenweihang): Use EigenContext to specialize the ContextT parameter,
// and only register the backend as Eigen's kernel during registration,
// instead of using macros to register the CPU and CUDA kernels separately
PT_REGISTER_CTX_KERNEL(scale,
CPU,
ALL_LAYOUT,
pten::Scale,
float,
double,
paddle::platform::bfloat16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(scale,
CUDA,
ALL_LAYOUT,
pten::Scale,
float,
double,
paddle::platform::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename ContextT>
void Scale(const ContextT& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale,
DenseTensor* out);
} // namespace pten
......@@ -21,4 +21,4 @@ cc_test(test_to_api SRCS test_to_api.cc DEPS pten_tensor pten_api pten_api_utils
cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils scale_kernel_eigen)
......@@ -25,8 +25,7 @@
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/cpu/math.h"
#include "paddle/pten/kernels/cuda/math.h"
#include "paddle/pten/kernels/scale_kernel.h"
namespace paddle {
namespace experimental {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册