未验证 提交 edc3ba13 编写于 作者: A Aganlengzi 提交者: GitHub

[custom kernel]Delete useless and upgrade (#39791)

* [custom kernel]Delete useless

* change RegType enum names

* mod notes

* merge

* update
上级 a167a143
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/api/ext/op_kernel_info.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace framework {
class OpKernelInfoHelper {
public:
static const std::string& GetOpName(const paddle::OpKernelInfo& info) {
return info.op_name_;
}
static const phi::Backend& GetBackend(const paddle::OpKernelInfo& info) {
return info.backend_;
}
static const phi::DataLayout& GetDataLayout(
const paddle::OpKernelInfo& info) {
return info.layout_;
}
static const phi::DataType& GetDataType(const paddle::OpKernelInfo& info) {
return info.dtype_;
}
static phi::KernelKey GetKernelKey(const paddle::OpKernelInfo& info) {
return phi::KernelKey(info.backend_, info.layout_, info.dtype_);
}
static const CustomKernelFunc& GetKernelFn(const paddle::OpKernelInfo& info) {
return info.kernel_fn_;
}
static void* GetVariadicKernelFn(const paddle::OpKernelInfo& info) {
return info.variadic_kernel_fn_;
}
static const paddle::SmallVector<TensorArgDef>& GetInputDefs(
const paddle::OpKernelInfo& info) {
return info.input_defs_;
}
static const paddle::SmallVector<TensorArgDef>& GetOutputDefs(
const paddle::OpKernelInfo& info) {
return info.output_defs_;
}
static const paddle::SmallVector<AttributeArgDef>& GetAttributeDefs(
const paddle::OpKernelInfo& info) {
return info.attribute_defs_;
}
};
} // namespace framework
} // namespace paddle
此差异已折叠。
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/api/ext/op_kernel_info.h"
#include "paddle/fluid/framework/custom_kernel.h"
namespace paddle {
////////////////////// Op Kernel Info //////////////////////
OpKernelInfo& OpKernelInfo::SetKernelFn(CustomKernelFunc&& func) {
kernel_fn_ = std::forward<CustomKernelFunc>(func);
return *this;
}
OpKernelInfo& OpKernelInfo::SetVariadicKernelFn(void* func) {
variadic_kernel_fn_ = func;
return *this;
}
//////////////// Op Kernel Info Map /////////////////
std::vector<OpKernelInfo>& OpKernelInfoMap::operator[](
const std::string& name) {
return map_[name];
}
const std::unordered_map<std::string, std::vector<OpKernelInfo>>&
OpKernelInfoMap::GetMap() const {
return map_;
}
//////////////// Op Kernel Info Builder /////////////////
OpKernelInfoBuilder::OpKernelInfoBuilder(std::string&& op_name,
phi::Backend backend,
phi::DataLayout data_layout,
phi::DataType data_type) {
// 1. member assign
op_name_ = std::forward<std::string>(op_name);
backend_ = backend;
layout_ = data_layout;
dtype_ = data_type;
// 2. info parse
auto& info_vector = OpKernelInfoMap::Instance()[op_name_];
auto op_kernel_info = OpKernelInfo(op_name_, backend_, layout_, dtype_);
info_vector.emplace_back(std::move(op_kernel_info));
// 3. get current info ptr
info_ptr_ = &(info_vector.back());
}
OpKernelInfoBuilder& OpKernelInfoBuilder::SetKernelFn(CustomKernelFunc func) {
info_ptr_->SetKernelFn(std::forward<CustomKernelFunc>(func));
return *this;
}
OpKernelInfoBuilder& OpKernelInfoBuilder::SetVariadicKernelFn(void* func) {
info_ptr_->SetVariadicKernelFn(func);
return *this;
}
OpKernelInfoBuilder& OpKernelInfoBuilder::ArgsParse(
CustomKernelArgsParseFn func) {
func(this->info_ptr_);
return *this;
}
OpKernelInfoBuilder& OpKernelInfoBuilder::ArgsDef(CustomKernelArgsDefFn func) {
func(this->info_ptr_);
return *this;
}
/////////////////////// Op register API /////////////////////////
// For inference: compile directly with framework
// Call after PD_REGISTER_BUILTIN_KERNEL(...)
void RegisterAllCustomKernel() {
auto& op_kernel_info_map = OpKernelInfoMap::Instance();
framework::RegisterKernelWithMetaInfoMap(op_kernel_info_map);
}
} // namespace paddle
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global OpKernelInfoMap.
paddle::OpKernelInfoMap& PD_GetOpKernelInfoMap() {
return paddle::OpKernelInfoMap::Instance();
}
#ifdef __cplusplus
} // end extern "C"
#endif
...@@ -129,10 +129,10 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -129,10 +129,10 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
} }
}; };
// NOTE: used for making a difference between kernels compiled with phi or not. // NOTE: used for making a difference between inner or outer registration.
enum class RegType : uint8_t { enum class RegType : uint8_t {
BUILTIN = 0, // compiled with phi INNER = 0,
PLUGIN, // separate compiled and registered OUTER,
}; };
// TODO(chenweihang): Polish the kernel selection logic, support the selection // TODO(chenweihang): Polish the kernel selection logic, support the selection
...@@ -205,7 +205,7 @@ struct KernelRegistrar { ...@@ -205,7 +205,7 @@ struct KernelRegistrar {
Kernel kernel(kernel_fn, variadic_kernel_fn); Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(kernel_key, &kernel); args_def_fn(kernel_key, &kernel);
if (reg_type == RegType::BUILTIN) { if (reg_type == RegType::INNER) {
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
} else { } else {
CustomKernelMap::Instance().Kernels()[kernel_name][kernel_key] = kernel; CustomKernelMap::Instance().Kernels()[kernel_name][kernel_key] = kernel;
...@@ -244,7 +244,7 @@ struct KernelRegistrar { ...@@ -244,7 +244,7 @@ struct KernelRegistrar {
* Note: `2TA` means `2 template argument` * Note: `2TA` means `2 template argument`
*/ */
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \ #define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::BUILTIN, \ _PD_REGISTER_KERNEL(::phi::RegType::INNER, \
kernel_name, \ kernel_name, \
backend, \ backend, \
::phi::backend##Context, \ ::phi::backend##Context, \
...@@ -918,7 +918,7 @@ struct KernelRegistrar { ...@@ -918,7 +918,7 @@ struct KernelRegistrar {
#define PD_REGISTER_GENERAL_KERNEL( \ #define PD_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \ kernel_name, backend, layout, kernel_fn, dtype) \
_PD_REGISTER_GENERAL_KERNEL( \ _PD_REGISTER_GENERAL_KERNEL( \
::phi::RegType::BUILTIN, kernel_name, backend, layout, kernel_fn, dtype) ::phi::RegType::INNER, kernel_name, backend, layout, kernel_fn, dtype)
#define _PD_REGISTER_GENERAL_KERNEL( \ #define _PD_REGISTER_GENERAL_KERNEL( \
reg_type, kernel_name, backend, layout, kernel_fn, dtype) \ reg_type, kernel_name, backend, layout, kernel_fn, dtype) \
...@@ -992,7 +992,7 @@ struct KernelRegistrar { ...@@ -992,7 +992,7 @@ struct KernelRegistrar {
*/ */
#define PD_REGISTER_BUILTIN_KERNEL( \ #define PD_REGISTER_BUILTIN_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \ kernel_name, backend, layout, meta_kernel_fn, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::PLUGIN, \ _PD_REGISTER_KERNEL(::phi::RegType::OUTER, \
kernel_name, \ kernel_name, \
backend, \ backend, \
::phi::backend##Context, \ ::phi::backend##Context, \
...@@ -1007,7 +1007,7 @@ struct KernelRegistrar { ...@@ -1007,7 +1007,7 @@ struct KernelRegistrar {
*/ */
#define PD_REGISTER_PLUGIN_KERNEL( \ #define PD_REGISTER_PLUGIN_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \ kernel_name, backend, layout, meta_kernel_fn, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::PLUGIN, \ _PD_REGISTER_KERNEL(::phi::RegType::OUTER, \
kernel_name, \ kernel_name, \
backend, \ backend, \
::phi::CustomContext, \ ::phi::CustomContext, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册