未验证 提交 e92e3aab 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Unify Fluid and PHI kernel (#49328)

* unify_kernel

* fix compile bugs

* modify macro name

* perfect code according comment

* fix compile bugs

* fix compile bugs

* fix ci bugs

* fix ci bug

* fix ci bugs

* fix ci bugs

* modify code according comment

* rm conv_fusion_op
上级 766a4ca9
...@@ -28,11 +28,7 @@ endfunction() ...@@ -28,11 +28,7 @@ endfunction()
function(find_phi_register FILENAME ADD_PATH PATTERN) function(find_phi_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT # set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
set(multiValueArgs "")
file(READ ${FILENAME} CONTENT) file(READ ${FILENAME} CONTENT)
string( string(
REGEX REGEX
MATCH MATCH
...@@ -402,6 +398,7 @@ function(op_library TARGET) ...@@ -402,6 +398,7 @@ function(op_library TARGET)
set(op_name "") set(op_name "")
# Add PHI Kernel Registry Message # Add PHI Kernel Registry Message
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL") find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cc_src} "REGISTER_OPERATOR" op_name) find_register(${cc_src} "REGISTER_OPERATOR" op_name)
if(NOT ${op_name} EQUAL "") if(NOT ${op_name} EQUAL "")
...@@ -453,6 +450,7 @@ function(op_library TARGET) ...@@ -453,6 +450,7 @@ function(op_library TARGET)
set(op_name "") set(op_name "")
# Add PHI Kernel Registry Message # Add PHI Kernel Registry Message
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL") find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name) find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "") if(NOT ${op_name} EQUAL "")
......
...@@ -827,7 +827,9 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -827,7 +827,9 @@ bool BuildOpFuncList(const platform::Place& place,
} }
// step 5. run kernel // step 5. run kernel
if (run_phi_kernel) { if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext( op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context); runtime_context, dev_ctx, &phi_kernel_context);
...@@ -838,6 +840,12 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -838,6 +840,12 @@ bool BuildOpFuncList(const platform::Place& place,
op_with_kernel->PhiKernelSignature(), op_with_kernel->PhiKernelSignature(),
&phi_kernel_context); &phi_kernel_context);
} }
} else if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
(*op_func_node.phi_kernel_)(&execution_context);
} else { } else {
// the place of exec_ctx maybe has changed. // the place of exec_ctx maybe has changed.
if (!skip_run) { if (!skip_run) {
......
...@@ -34,6 +34,7 @@ limitations under the License. */ ...@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx<PlaceType, ...@@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx<PlaceType,
USE_OP_KERNEL(op_type) USE_OP_KERNEL(op_type)
// clang-format on // clang-format on
template <typename StructureKernel>
struct StructKernelImpl {
static void Compute(phi::KernelContext* ctx) {
auto exe_ctx = static_cast<paddle::framework::ExecutionContext*>(ctx);
StructureKernel().Compute(*exe_ctx);
}
};
#define PHI_STRUCTURE_KERNEL(...) \
::paddle::framework::StructKernelImpl<__VA_ARGS__>::Compute
#define PHI_STRUCTURE_VARIADIC_KERNEL(...) nullptr
#define STRUCTURE_ARG_PARSE_FUNCTOR(...) nullptr
#define STRUCTURE_KERNEL_INSTANTIATION( \
meta_kernel_structure, cpp_dtype, context) \
template class meta_kernel_structure<cpp_dtype, context>;
#define PD_REGISTER_STRUCT_KERNEL( \
kernel_name, backend, layout, meta_kernel_structure, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::INNER, \
kernel_name, \
backend, \
::phi::backend##Context, \
layout, \
meta_kernel_structure, \
STRUCTURE_KERNEL_INSTANTIATION, \
STRUCTURE_ARG_PARSE_FUNCTOR, \
PHI_STRUCTURE_KERNEL, \
PHI_STRUCTURE_VARIADIC_KERNEL, \
__VA_ARGS__)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::string phi_kernel_name; std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) { if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) {
if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset(new phi::KernelSignature( kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx)))); std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *kernel_signature_.get(); }
VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_); dev_ctx = pool.Get(kernel_type_->place_);
phi_kernel_name = kernel_signature_->name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work. // library_type here, otherwise it can't work.
...@@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
} else { } else {
phi_kernel_name = kernel_signature_->name; phi_kernel_name = kernel_signature_->name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type = // NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default // LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_ // values are kPlain, so we need to modify the library_type and data_layout_
...@@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (run_phi_kernel_) { if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) { !need_prepare_data_) {
...@@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context); BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context);
(*phi_kernel_)(&phi_kernel_context); (*phi_kernel_)(&phi_kernel_context);
} }
} else if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*this, exec_scope, *dev_ctx, *runtime_ctx);
(*phi_kernel_)(&execution_context);
} else { } else {
(*kernel_func_)( (*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
...@@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
phi::KernelKey OperatorWithKernel::ChoosePhiKernel( phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset( kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
}
VLOG(6) << *kernel_signature_.get(); VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto phi_kernel_name = kernel_signature_->name;
auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_kernel_key))); phi_kernel_name, phi_kernel_key)));
...@@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData(
} }
}; };
if (run_phi_kernel_) { if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
const auto& input_names = kernel_signature_->input_names; const auto& input_names = kernel_signature_->input_names;
const auto& input_defs = phi_kernel_->args_def().input_defs(); const auto& input_defs = phi_kernel_->args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), PADDLE_ENFORCE_EQ(input_names.size(),
......
...@@ -41,6 +41,7 @@ limitations under the License. */ ...@@ -41,6 +41,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
...@@ -290,7 +291,7 @@ class OperatorBase { ...@@ -290,7 +291,7 @@ class OperatorBase {
const platform::Place& place) const = 0; const platform::Place& place) const = 0;
}; };
class ExecutionContext { class ExecutionContext : public phi::KernelContext {
public: public:
ExecutionContext(const OperatorBase& op, ExecutionContext(const OperatorBase& op,
const Scope& scope, const Scope& scope,
......
...@@ -272,6 +272,10 @@ PreparedOp PrepareImpl( ...@@ -272,6 +272,10 @@ PreparedOp PrepareImpl(
has_phi_kernel = true; has_phi_kernel = true;
kernel_signature = (*arg_map_fn)( kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
if (phi::KernelFactory::Instance().HasStructuredKernel(op.Type())) {
has_phi_kernel = true;
kernel_signature = phi::KernelSignature(op.Type().c_str());
} else { } else {
default_kernel_signature = default_kernel_signature =
default_phi_kernel_sig_map.GetNullable(op.Type()); default_phi_kernel_sig_map.GetNullable(op.Type());
...@@ -280,10 +284,12 @@ PreparedOp PrepareImpl( ...@@ -280,10 +284,12 @@ PreparedOp PrepareImpl(
kernel_signature = *default_kernel_signature; kernel_signature = *default_kernel_signature;
} }
} }
}
if (has_phi_kernel) { if (has_phi_kernel) {
VLOG(6) << kernel_signature; VLOG(6) << kernel_signature;
phi_kernel_name = kernel_signature.name; phi_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work. // library_type here, otherwise it can't work.
...@@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl( ...@@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl(
const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature* default_kernel_signature,
const phi::KernelSignature& kernel_signature, const phi::KernelSignature& kernel_signature,
const phi::Kernel& phi_kernel, const phi::Kernel& phi_kernel,
const framework::RuntimeContext& ctx,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const NameVarMap<VarType>& outs,
...@@ -678,8 +685,9 @@ static void PreparedOpRunPtImpl( ...@@ -678,8 +685,9 @@ static void PreparedOpRunPtImpl(
1, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (phi_kernel.GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins); PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
BuildDygraphPhiKernelContext<VarType>(kernel_signature, BuildDygraphPhiKernelContext<VarType>(kernel_signature,
phi_kernel, phi_kernel,
...@@ -691,6 +699,11 @@ static void PreparedOpRunPtImpl( ...@@ -691,6 +699,11 @@ static void PreparedOpRunPtImpl(
&phi_kernel_context); &phi_kernel_context);
phi_kernel(&phi_kernel_context); phi_kernel(&phi_kernel_context);
} else {
DygraphExecutionContext<VarType> exe_ctx(
op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs);
phi_kernel(&exe_ctx);
}
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
...@@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins, ...@@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
default_kernel_signature_, default_kernel_signature_,
kernel_signature_, kernel_signature_,
phi_kernel_, phi_kernel_,
ctx_,
dev_ctx_, dev_ctx_,
ins, ins,
outs, outs,
...@@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
default_kernel_signature_, default_kernel_signature_,
kernel_signature_, kernel_signature_,
phi_kernel_, phi_kernel_,
ctx_,
dev_ctx_, dev_ctx_,
ins, ins,
outs, outs,
...@@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins, ...@@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
default_kernel_signature_, default_kernel_signature_,
kernel_signature_, kernel_signature_,
phi_kernel_, phi_kernel_,
ctx_,
dev_ctx_, dev_ctx_,
ins, ins,
outs, outs,
......
...@@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature( ...@@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
"This op type:`%s` is not a OperatorWithKernel, only " "This op type:`%s` is not a OperatorWithKernel, only "
"OperatorWithKernel can get KernelSignature", "OperatorWithKernel can get KernelSignature",
type)); type));
return phi::KernelSignature( if (phi::KernelFactory::Instance().HasStructuredKernel(type)) {
std::move(opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx))); return phi::KernelSignature(op->Type().c_str());
} else {
return phi::KernelSignature(std::move(
opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
}
} }
} // namespace imperative } // namespace imperative
......
...@@ -34,7 +34,7 @@ limitations under the License. */ ...@@ -34,7 +34,7 @@ limitations under the License. */
phi::RegType::INNER, \ phi::RegType::INNER, \
#kernel_name, \ #kernel_name, \
dev_type, \ dev_type, \
DATALAYOUT(layout), \ DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \ [](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \ PHI_KERNEL(kernel_fn), \
......
...@@ -240,12 +240,15 @@ REGISTER_OPERATOR(rank_loss, ...@@ -240,12 +240,15 @@ REGISTER_OPERATOR(rank_loss,
ops::RankLossGradMaker<paddle::framework::OpDesc>, ops::RankLossGradMaker<paddle::framework::OpDesc>,
ops::RankLossGradMaker<paddle::imperative::OpBase>); ops::RankLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp); REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp);
REGISTER_OP_CPU_KERNEL(rank_loss, ops::RankLossKernel<phi::CPUContext, float>);
REGISTER_OP_CPU_KERNEL(rank_loss_grad, PD_REGISTER_STRUCT_KERNEL(
ops::RankLossGradKernel<phi::CPUContext, float>); rank_loss, CPU, ALL_LAYOUT, ops::RankLossKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
REGISTER_OP_CUDA_KERNEL( rank_loss_grad, CPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {}
rank_loss, paddle::operators::RankLossKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL( #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
rank_loss_grad, PD_REGISTER_STRUCT_KERNEL(
paddle::operators::RankLossGradKernel<phi::GPUContext, float>); rank_loss, GPU, ALL_LAYOUT, ops::RankLossKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
rank_loss_grad, GPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {}
#endif
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class RankLossKernel : public framework::OpKernel<T> { class RankLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel<T> { ...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class RankLossGradKernel : public framework::OpKernel<T> { class RankLossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -56,6 +56,8 @@ struct KernelSignature { ...@@ -56,6 +56,8 @@ struct KernelSignature {
attr_names(attrs), attr_names(attrs),
output_names(outputs) {} output_names(outputs) {}
explicit KernelSignature(const char* kernel_name) : name(kernel_name) {}
// TODO(chenweihang): add assign constructor to solve windows compile // TODO(chenweihang): add assign constructor to solve windows compile
// problem, remove it later // problem, remove it later
KernelSignature(const KernelSignature& other) KernelSignature(const KernelSignature& other)
......
...@@ -62,6 +62,21 @@ bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { ...@@ -62,6 +62,21 @@ bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
return false; return false;
} }
bool KernelFactory::HasStructuredKernel(const std::string& op_type) const {
auto phi_kernel_name = phi::OpUtilsMap::Instance().GetBaseKernelName(op_type);
auto kernel_iter = kernels_.find(phi_kernel_name);
if (deprecated_op_names.find(op_type) == deprecated_op_names.end() &&
kernel_iter != kernels_.end()) {
return std::any_of(kernel_iter->second.begin(),
kernel_iter->second.end(),
[](phi::KernelKeyMap::const_reference kernel_pair) {
return kernel_pair.second.GetKernelRegisteredType() ==
KernelRegisteredType::STRUCTURE;
});
}
return false;
}
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const { const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
......
...@@ -238,13 +238,21 @@ class KernelArgsDef { ...@@ -238,13 +238,21 @@ class KernelArgsDef {
{}}; {}};
}; };
enum class KernelRegisteredType { FUNCTION, STRUCTURE };
class Kernel { class Kernel {
public: public:
// for map element construct // for map element construct
Kernel() = default; Kernel() = default;
explicit Kernel(KernelFn fn, void* variadic_fn) explicit Kernel(KernelFn fn, void* variadic_fn)
: fn_(fn), variadic_fn_(variadic_fn) {} : fn_(fn), variadic_fn_(variadic_fn) {
if (variadic_fn == nullptr) {
kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
} else {
kernel_registered_type_ = KernelRegisteredType::FUNCTION;
}
}
void operator()(KernelContext* ctx) const { fn_(ctx); } void operator()(KernelContext* ctx) const { fn_(ctx); }
...@@ -272,10 +280,15 @@ class Kernel { ...@@ -272,10 +280,15 @@ class Kernel {
bool IsValid() const { return fn_ != nullptr; } bool IsValid() const { return fn_ != nullptr; }
KernelRegisteredType GetKernelRegisteredType() const {
return kernel_registered_type_;
}
private: private:
KernelFn fn_{nullptr}; KernelFn fn_{nullptr};
void* variadic_fn_ = nullptr; void* variadic_fn_ = nullptr;
KernelArgsDef args_def_; KernelArgsDef args_def_;
KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION;
}; };
using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>; using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
...@@ -304,6 +317,8 @@ class KernelFactory { ...@@ -304,6 +317,8 @@ class KernelFactory {
bool HasCompatiblePhiKernel(const std::string& op_type) const; bool HasCompatiblePhiKernel(const std::string& op_type) const;
bool HasStructuredKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name, KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
......
此差异已折叠。
...@@ -35,4 +35,6 @@ KernelSignature SaveCombineOpArgumentMapping( ...@@ -35,4 +35,6 @@ KernelSignature SaveCombineOpArgumentMapping(
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(save_combine, save_combine_tensor);
PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册