Unverified commit b76ef045, authored by zyfncg, committed via GitHub

【PTen】Add variadic args kernel for PTen API to replace KernelContext (#37942)

* add variadic_args kernel in pten

* merge develop code

* add variadic_args kernel and benchmark

* change dynamic_cast to static_cast for DeviceContext

* merge the code

* modify code format

* refactor variadic kernel function
Parent commit: 512e4339
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
// This header is used to cast kernel function from void* to original form of
// function Currnetly.
// It may be generated automatically in the future.
namespace pten {

// Shorten the fluid DeviceContext type so the signatures below stay readable.
using DeviceContext = paddle::platform::DeviceContext;

// Function-pointer aliases for the variadic ("KernelContext-free") form of
// each PTen kernel. The API layer retrieves a kernel as a type-erased void*
// and casts it back to one of these aliases before calling it directly.

// Elementwise binary kernels: (ctx, x, y, axis, out).
using add_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const DenseTensor&, int, DenseTensor*);

// Cast: (ctx, x, out_dtype, in_dtype, out).
using cast_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, DataType, DataType, DenseTensor*);

using divide_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const DenseTensor&, int, DenseTensor*);

// Dot product: (ctx, x, y, out).
using dot_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const DenseTensor&, DenseTensor*);

// Flatten: (ctx, x, start_axis, stop_axis, out).
using flatten_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, int, int, DenseTensor*);

// Full: (ctx, shape, fill_value, out).
using full_kernel = void (*)(
    const DeviceContext&, const ScalarArray&, const Scalar&, DenseTensor*);

// Full-like: (ctx, fill_value, out).
using full_like_kernel = void (*)(
    const DeviceContext&, const Scalar&, DenseTensor*);

// Matmul: (ctx, x, y, transpose_x, transpose_y, out).
using matmul_kernel = void (*)(const DeviceContext&,
                               const DenseTensor&,
                               const DenseTensor&,
                               bool,
                               bool,
                               DenseTensor*);

// Reduction kernels: (ctx, x, axes, keep_dim, reduce_all, dtype, dtype, out).
using mean_kernel = void (*)(const DeviceContext&,
                             const DenseTensor&,
                             const std::vector<int64_t>&,
                             bool,
                             bool,
                             DataType,
                             DataType,
                             DenseTensor*);

using multiply_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const DenseTensor&, int, DenseTensor*);

// Reshape: (ctx, x, shape, out).
using reshape_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const std::vector<int64_t>&, DenseTensor*);

// Scale: (ctx, x, scale, bias, bias_after_scale, out).
using scale_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const Scalar&, float, bool, DenseTensor*);

using sum_kernel = void (*)(const DeviceContext&,
                            const DenseTensor&,
                            const std::vector<int64_t>&,
                            bool,
                            bool,
                            DataType,
                            DataType,
                            DenseTensor*);

using subtract_kernel = void (*)(
    const DeviceContext&, const DenseTensor&, const DenseTensor&, int, DenseTensor*);

}  // namespace pten
...@@ -228,10 +228,17 @@ class Kernel { ...@@ -228,10 +228,17 @@ class Kernel {
// for map element contruct // for map element contruct
Kernel() = default; Kernel() = default;
explicit Kernel(KernelFn fn) : fn_(fn) {} explicit Kernel(KernelFn fn, void* variadic_fn)
: fn_(fn), variadic_fn_(variadic_fn) {}
void operator()(KernelContext* ctx) const { fn_(ctx); } void operator()(KernelContext* ctx) const { fn_(ctx); }
template <typename Fn>
Fn GetVariadicKernelFn() const {
auto* func = reinterpret_cast<Fn>(variadic_fn_);
return func;
}
KernelArgsDef* mutable_args_def() { return &args_def_; } KernelArgsDef* mutable_args_def() { return &args_def_; }
const KernelArgsDef& args_def() const { return args_def_; } const KernelArgsDef& args_def() const { return args_def_; }
...@@ -244,6 +251,7 @@ class Kernel { ...@@ -244,6 +251,7 @@ class Kernel {
private: private:
KernelFn fn_{nullptr}; KernelFn fn_{nullptr};
void* variadic_fn_ = nullptr;
KernelArgsDef args_def_; KernelArgsDef args_def_;
}; };
......
...@@ -101,14 +101,16 @@ struct KernelRegistrar { ...@@ -101,14 +101,16 @@ struct KernelRegistrar {
DataType dtype, DataType dtype,
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn, KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) { KernelFn kernel_fn,
void* variadic_kernel_fn) {
ConstructKernel(kernel_name_cstr, ConstructKernel(kernel_name_cstr,
backend, backend,
layout, layout,
dtype, dtype,
args_parse_fn, args_parse_fn,
args_def_fn, args_def_fn,
kernel_fn); kernel_fn,
variadic_kernel_fn);
} }
KernelRegistrar(const char* kernel_name_cstr, KernelRegistrar(const char* kernel_name_cstr,
...@@ -116,7 +118,8 @@ struct KernelRegistrar { ...@@ -116,7 +118,8 @@ struct KernelRegistrar {
DataLayout layout, DataLayout layout,
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn, KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) { KernelFn kernel_fn,
void* variadic_kernel_fn) {
for (size_t dtype = static_cast<size_t>(DataType::BOOL); for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES); dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) { dtype++) {
...@@ -126,7 +129,8 @@ struct KernelRegistrar { ...@@ -126,7 +129,8 @@ struct KernelRegistrar {
static_cast<DataType>(dtype), static_cast<DataType>(dtype),
args_parse_fn, args_parse_fn,
args_def_fn, args_def_fn,
kernel_fn); kernel_fn,
variadic_kernel_fn);
} }
} }
...@@ -137,10 +141,11 @@ struct KernelRegistrar { ...@@ -137,10 +141,11 @@ struct KernelRegistrar {
DataType dtype, DataType dtype,
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn, KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) { KernelFn kernel_fn,
void* variadic_kernel_fn) {
KernelName kernel_name(kernel_name_cstr); KernelName kernel_name(kernel_name_cstr);
KernelKey kernel_key(backend, layout, dtype); KernelKey kernel_key(backend, layout, dtype);
Kernel kernel(kernel_fn); Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel); args_def_fn(&kernel);
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
...@@ -356,7 +361,8 @@ struct KernelRegistrar { ...@@ -356,7 +361,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
registrar_id, \ registrar_id, \
...@@ -375,7 +381,8 @@ struct KernelRegistrar { ...@@ -375,7 +381,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -400,7 +407,8 @@ struct KernelRegistrar { ...@@ -400,7 +407,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -425,7 +433,8 @@ struct KernelRegistrar { ...@@ -425,7 +433,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -450,7 +459,8 @@ struct KernelRegistrar { ...@@ -450,7 +459,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -475,7 +485,8 @@ struct KernelRegistrar { ...@@ -475,7 +485,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -500,7 +511,8 @@ struct KernelRegistrar { ...@@ -500,7 +511,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -525,7 +537,8 @@ struct KernelRegistrar { ...@@ -525,7 +537,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -550,7 +563,8 @@ struct KernelRegistrar { ...@@ -550,7 +563,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -575,7 +589,8 @@ struct KernelRegistrar { ...@@ -575,7 +589,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -600,7 +615,8 @@ struct KernelRegistrar { ...@@ -600,7 +615,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -625,7 +641,8 @@ struct KernelRegistrar { ...@@ -625,7 +641,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -650,7 +667,8 @@ struct KernelRegistrar { ...@@ -650,7 +667,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -675,7 +693,8 @@ struct KernelRegistrar { ...@@ -675,7 +693,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -700,7 +719,8 @@ struct KernelRegistrar { ...@@ -700,7 +719,8 @@ struct KernelRegistrar {
::pten::KernelArgsParseFunctor<decltype( \ ::pten::KernelArgsParseFunctor<decltype( \
&meta_kernel_fn<cpp_dtype>)>::Parse, \ &meta_kernel_fn<cpp_dtype>)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \ PT_KERNEL(meta_kernel_fn<cpp_dtype>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype>)); \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \
PT_ID, \ PT_ID, \
backend, \ backend, \
...@@ -728,7 +748,8 @@ struct KernelRegistrar { ...@@ -728,7 +748,8 @@ struct KernelRegistrar {
DATATYPE(dtype), \ DATATYPE(dtype), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
args_def_fn, \ args_def_fn, \
PT_KERNEL(kernel_fn)); \ PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*) void __PT_SINGLE_KERNEL_args_def_FN_##kernel_name(::pten::Kernel*)
...@@ -750,7 +771,8 @@ struct KernelRegistrar { ...@@ -750,7 +771,8 @@ struct KernelRegistrar {
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::pten::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \ &__PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name, \
PT_KERNEL(kernel_fn)); \ PT_KERNEL(kernel_fn), \
PT_VARIADIC_KERNEL(kernel_fn)); \
int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \ int TouchKernelSymbolFor_##kernel_name##_##backend() { return 0; } \
void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel) void __PT_KERNEL_ALL_DTYPE_args_def_FN_##kernel_name(::pten::Kernel* kernel)
......
...@@ -44,6 +44,10 @@ using XPUContext = paddle::platform::XPUDeviceContext; ...@@ -44,6 +44,10 @@ using XPUContext = paddle::platform::XPUDeviceContext;
#define PT_KERNEL(...) \ #define PT_KERNEL(...) \
::pten::KernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute ::pten::KernelImpl<decltype(&__VA_ARGS__), &__VA_ARGS__>::Compute
#define PT_VARIADIC_KERNEL(...) \
reinterpret_cast<void*>(&::pten::KernelImpl<decltype(&__VA_ARGS__), \
&__VA_ARGS__>::VariadicCompute)
#define PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx) \ #define PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx) \
template <typename... Tail> \ template <typename... Tail> \
struct KernelCallHelper<const dev_ctx&, Tail...> { \ struct KernelCallHelper<const dev_ctx&, Tail...> { \
...@@ -169,10 +173,19 @@ struct TypeTag {}; ...@@ -169,10 +173,19 @@ struct TypeTag {};
template <typename Fn, Fn fn> template <typename Fn, Fn fn>
struct KernelImpl; struct KernelImpl;
template <typename Return, typename... Args, Return (*kernel_fn)(Args...)> template <typename Return,
struct KernelImpl<Return (*)(Args...), kernel_fn> { typename DevCtx,
typename... Args,
Return (*kernel_fn)(DevCtx, Args...)>
struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
static void Compute(KernelContext* ctx) { static void Compute(KernelContext* ctx) {
KernelCallHelper<Args..., TypeTag<int>>::template Compute<0, 0, 0, 0>(ctx); KernelCallHelper<DevCtx,
Args...,
TypeTag<int>>::template Compute<0, 0, 0, 0>(ctx);
}
static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) {
return kernel_fn(static_cast<DevCtx>(dev_ctx), std::forward<Args>(args)...);
} }
private: private:
...@@ -224,12 +237,12 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> { ...@@ -224,12 +237,12 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
template <typename T> template <typename T>
struct KernelCallHelper<TypeTag<T>> { struct KernelCallHelper<TypeTag<T>> {
template <int dev_ctx_idx, int in_idx, int attr_idx, int out_idx> template <int dev_ctx_idx, int in_idx, int attr_idx, int out_idx>
static void Compute(KernelContext* ctx, Args&... args) { static void Compute(KernelContext* ctx, DevCtx dev_ctx, Args&... args) {
static_assert(dev_ctx_idx > 0, static_assert(dev_ctx_idx > 0,
"Kernel should pass DeviceContext as argument."); "Kernel should pass DeviceContext as argument.");
static_assert(out_idx > 0, "Kernel should have output argument."); static_assert(out_idx > 0, "Kernel should have output argument.");
// TODO(chenweihang): check dev_ctx, in, attr, out number // TODO(chenweihang): check dev_ctx, in, attr, out number
return kernel_fn(args...); return kernel_fn(dev_ctx, args...);
} }
}; };
}; };
......
...@@ -21,3 +21,4 @@ cc_test(test_to_api SRCS test_to_api.cc DEPS pten_tensor pten_api pten_api_utils ...@@ -21,3 +21,4 @@ cc_test(test_to_api SRCS test_to_api.cc DEPS pten_tensor pten_api pten_api_utils
cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_slice_api SRCS test_slice_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "glog/logging.h"
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/cpu/math.h"
#include "paddle/pten/kernels/cuda/math.h"
namespace paddle {
namespace experimental {
// Runs the "scale" kernel through the generic KernelContext path:
// inputs/attributes are packed into a pten::KernelContext, and the kernel is
// invoked via its type-erased function pointer. Used as the baseline in the
// scale benchmark.
PADDLE_API Tensor scale_kernel_context(const Tensor& x,
                                       const Scalar& scale,
                                       float bias,
                                       bool bias_after_scale) {
  // Derive the kernel key (backend / layout / dtype) from the input tensor.
  // All three fields start out UNDEFINED in this API, so each is taken from
  // the highest-priority key parsed from `x`.
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
  const Backend kernel_backend = kernel_key.backend();
  const DataLayout kernel_layout = kernel_key.layout();
  const DataType kernel_data_type = kernel_key.dtype();

  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
      "scale", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "scale API kernel: " << kernel;

  // Pack the input and attributes into a KernelContext.
  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
  auto kernel_context = pten::KernelContext(dev_ctx);

  auto dense_input = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
  kernel_context.EmplaceBackInput(dense_input);
  kernel_context.EmplaceBackAttr(pten::Scalar(scale));
  kernel_context.EmplaceBackAttr(bias);
  kernel_context.EmplaceBackAttr(bias_after_scale);

  // scale preserves shape and dtype, so the output metadata mirrors the input.
  auto out_meta = pten::UnchangedInferMeta(dense_input->meta());
  const auto allocator =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          pten::TransToFluidPlace(kernel_backend));
  auto dense_output = std::make_shared<pten::DenseTensor>(allocator, out_meta);
  kernel_context.EmplaceBackOutput(dense_output);

  Tensor out;
  out.set_impl(dense_output);
  kernel(&kernel_context);

  return out;
}
// Dispatches the CPU scale computation on the runtime dtype to the matching
// statically-typed pten::Scale<T> instantiation. Each case invokes the kernel
// and returns immediately; an unsupported dtype raises a fatal error.
static void ScaleCPU(DataType kernel_dtype,
                     const pten::CPUContext& dev_ctx,
                     const pten::DenseTensor& x,
                     const Scalar& scale,
                     float bias,
                     bool bias_after_scale,
                     pten::DenseTensor* dense_out) {
  switch (kernel_dtype) {
    case pten::DataType::FLOAT64:
      pten::Scale<double>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::FLOAT32:
      pten::Scale<float>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::BFLOAT16:
      pten::Scale<paddle::platform::bfloat16>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT64:
      pten::Scale<int64_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT32:
      pten::Scale<int32_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT16:
      pten::Scale<int16_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT8:
      pten::Scale<int8_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::UINT8:
      pten::Scale<uint8_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    default:
      // No Scale instantiation exists for this dtype.
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Detected unsupported data type."
          "Only Float64, Float32, BFloat16, Int64, Int32, Int16, Int8, UInt8 "
          "are supported for now."));
  }
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Dispatches the CUDA scale computation on the runtime dtype to the matching
// statically-typed pten::Scale<T> instantiation (float16 here instead of the
// CPU path's bfloat16). Each case invokes the kernel and returns immediately;
// an unsupported dtype raises a fatal error.
static void ScaleCUDA(DataType kernel_dtype,
                      const pten::CUDAContext& dev_ctx,
                      const pten::DenseTensor& x,
                      const Scalar& scale,
                      float bias,
                      bool bias_after_scale,
                      pten::DenseTensor* dense_out) {
  switch (kernel_dtype) {
    case pten::DataType::FLOAT64:
      pten::Scale<double>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::FLOAT32:
      pten::Scale<float>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::FLOAT16:
      pten::Scale<paddle::platform::float16>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT64:
      pten::Scale<int64_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT32:
      pten::Scale<int32_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT16:
      pten::Scale<int16_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::INT8:
      pten::Scale<int8_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    case pten::DataType::UINT8:
      pten::Scale<uint8_t>(
          dev_ctx, x, pten::Scalar(scale), bias, bias_after_scale, dense_out);
      return;
    default:
      // No Scale instantiation exists for this dtype.
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Detected unsupported data type."
          "Only Float64, Float32, Float16, Int64, Int32, Int16, Int8, UInt8 "
          "are "
          "supported for now."));
  }
}
#endif
// Runs the "scale" kernel through a hand-written backend switch
// (ScaleCPU / ScaleCUDA) instead of a KernelContext. Used as the
// switch/case variant in the scale benchmark.
Tensor scale_switch_case(const Tensor& x,
                         const Scalar& scale,
                         float bias,
                         bool bias_after_scale) {
  // Derive the kernel key (backend / layout / dtype) from the input tensor.
  // All three fields start out UNDEFINED in this API, so each is taken from
  // the highest-priority key parsed from `x`.
  auto kernel_key_set = ParseKernelKeyByInputArgs(x);
  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
  const Backend kernel_backend = kernel_key.backend();
  const DataLayout kernel_layout = kernel_key.layout();
  const DataType kernel_data_type = kernel_key.dtype();

  // Kernel lookup is kept (even though dispatch below is manual) so lookup
  // cost stays comparable with the other benchmark variants.
  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
      "scale", {kernel_backend, kernel_layout, kernel_data_type});
  VLOG(6) << "scale API kernel key: [" << kernel_backend << ", "
          << kernel_layout << ", " << kernel_data_type << "]";
  VLOG(6) << "scale API kernel: " << kernel;

  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
  auto dense_input = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());

  // scale preserves shape and dtype, so the output metadata mirrors the input.
  auto out_meta = pten::UnchangedInferMeta(dense_input->meta());
  const auto allocator =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          pten::TransToFluidPlace(kernel_backend));
  auto dense_output = std::make_shared<pten::DenseTensor>(allocator, out_meta);

  Tensor out;
  out.set_impl(dense_output);

  // Route to the per-backend dtype dispatcher.
  switch (kernel_backend) {
    case Backend::CPU:
      ScaleCPU(kernel_data_type,
               static_cast<const pten::CPUContext&>(*dev_ctx),
               *dense_input,
               scale,
               bias,
               bias_after_scale,
               dense_output.get());
      break;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    case Backend::CUDA:
      ScaleCUDA(kernel_data_type,
                static_cast<const pten::CUDAContext&>(*dev_ctx),
                *dense_input,
                scale,
                bias,
                bias_after_scale,
                dense_output.get());
      break;
#endif
    default:
      PADDLE_THROW(paddle::platform::errors::Fatal(
          "Detected unsupported backend."
          "Only CPU and CUDA Backend are supported for now."
          "Please double check if your backend falls into the above two "
          "categories."));
  }

  return out;
}
} // namespace experimental
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/tests/api/scale_api.h"
#include "paddle/pten/tests/core/timer.h"
namespace paddle {
namespace tests {
// Benchmarks the three scale-API dispatch strategies against each other:
// KernelContext packing, the variadic-args kernel, and a hand-written
// backend/dtype switch. Each strategy executes cycles * cycles calls.
// Fix: the inner loops previously redeclared `i`, shadowing the outer loop
// counter (-Wshadow / bugprone-shadow); counters are now distinct.
TEST(API, scale) {
  auto x = experimental::full(
      {3, 4}, 1.0, experimental::DataType::FLOAT32, experimental::Backend::CPU);

  const size_t cycles = 300;
  pten::tests::Timer timer;
  double t1{}, t2{}, t3{};

  for (size_t outer = 0; outer < cycles; ++outer) {
    timer.tic();
    for (size_t inner = 0; inner < cycles; ++inner) {
      auto out = experimental::scale_kernel_context(x, 2.0, 1.0, true);
    }
    t1 += timer.toc();

    timer.tic();
    for (size_t inner = 0; inner < cycles; ++inner) {
      auto out = experimental::scale(x, 2.0, 1.0, true);
    }
    t2 += timer.toc();

    timer.tic();
    for (size_t inner = 0; inner < cycles; ++inner) {
      auto out = experimental::scale_switch_case(x, 2.0, 1.0, true);
    }
    t3 += timer.toc();
  }

  LOG(INFO) << "The cost of kernel_context is " << t1 << "ms.";
  LOG(INFO) << "The cost of variadic_args_kernel_fn is " << t2 << "ms.";
  LOG(INFO) << "The cost of switch_case is " << t3 << "ms.";
}
} // namespace tests
} // namespace paddle
...@@ -65,7 +65,7 @@ ...@@ -65,7 +65,7 @@
param : [x, dtype, layout] param : [x, dtype, layout]
kernel : kernel :
func : full_like func : full_like
param : [x, value] param : [value]
data_type : dtype > x data_type : dtype > x
backend : place > x backend : place > x
layout : layout > x layout : layout > x
......
...@@ -263,60 +263,57 @@ PADDLE_API {self.output} {self.api}({self.args['args_declare']}); ...@@ -263,60 +263,57 @@ PADDLE_API {self.output} {self.api}({self.args['args_declare']});
auto out_meta = pten::{infer_meta['func']}({param_code}); auto out_meta = pten::{infer_meta['func']}({param_code});
""" """
def gene_kernel_context(self, input_names, attrs, infer_meta, kernel_param): def get_kernel_args(self, input_names, attrs, kernel_param):
input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {self.prefix_tensor_name}{input_name} = std::dynamic_pointer_cast<pten::DenseTensor>({input_name}.impl());"""
attr_names = attrs['names'] attr_names = attrs['names']
if kernel_param is None: if kernel_param is None:
kernel_param = input_names + attr_names kernel_param = input_names + attr_names
input_code_str = "" kernel_args = "*dev_ctx, "
attr_code_str = ""
for param in kernel_param: for param in kernel_param:
if param in input_names: if param in input_names:
# set input for kernel_context kernel_args = kernel_args + "*" + self.prefix_tensor_name + param + ", "
input_code_str = input_code_str + f"""
auto {self.prefix_tensor_name}{param} = std::dynamic_pointer_cast<pten::DenseTensor>({param}.impl());
kernel_context.EmplaceBackInput({self.prefix_tensor_name}{param});"""
elif param in attr_names: elif param in attr_names:
# set attr for kernel_context # set attr for kernel_context
if 'ScalarArray' in attrs['attr_info'][param][0]: if 'ScalarArray' in attrs['attr_info'][param][0]:
param = 'pten::ScalarArray(' + param + ')' param = 'pten::ScalarArray(' + param + ')'
elif 'Scalar' in attrs['attr_info'][param][0]: elif 'Scalar' in attrs['attr_info'][param][0]:
param = 'pten::Scalar(' + param + ')' param = 'pten::Scalar(' + param + ')'
attr_code_str = attr_code_str + f""" kernel_args = kernel_args + param + ", "
kernel_context.EmplaceBackAttr({param});"""
elif isinstance(param, bool): elif isinstance(param, bool):
attr_code_str = attr_code_str + f""" kernel_args = kernel_args + str(param).lower() + ", "
kernel_context.EmplaceBackAttr({str(param).lower()});"""
else: else:
attr_code_str = attr_code_str + f""" kernel_args = kernel_args + str(param) + ", "
kernel_context.EmplaceBackAttr({param});""" return input_tensor_code, kernel_args[:-2]
def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = self.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
self.kernel['param'])
return f"""
PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
{self.gene_kernel_select(self.args['inputs']['names'], self.args['attrs'], self.kernel)}
return f"""
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto kernel_context = pten::KernelContext(dev_ctx); {input_tensors}
{input_code_str} {self.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{attr_code_str}
{self.gene_infer_meta(input_names, attr_names, infer_meta)}
const auto allocator = const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>( std::make_shared<paddle::experimental::DefaultAllocator>(
pten::TransToFluidPlace(kernel_backend)); pten::TransToFluidPlace(kernel_backend));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta); auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
Tensor out; Tensor out;
out.set_impl(dense_out);""" out.set_impl(dense_out);
def gene_api_code(self): auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.api}_kernel>();
if self.is_base_api: (*kernel_fn)({kernel_args}, dense_out.get());
return f"""
PADDLE_API {self.output} {self.api}({self.args["args_define"]}) {{
{self.gene_kernel_select(self.args['inputs']['names'], self.args['attrs'], self.kernel)}
{self.gene_kernel_context(self.args['inputs']['names'], self.args['attrs'], self.infer_meta, self.kernel['param'])}
kernel(&kernel_context);
return out; return out;
}} }}
""" """
...@@ -344,6 +341,7 @@ def source_include(header_file_path): ...@@ -344,6 +341,7 @@ def source_include(header_file_path):
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h" #include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/kernel_declare.h" #include "paddle/pten/api/lib/kernel_declare.h"
#include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/kernel_dispatch.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册