未验证 提交 e92e3aab 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Unify Fluid and PHI kernel (#49328)

* unify_kernel

* fix compile bugs

* modify macro name

* perfect code according comment

* fix compile bugs

* fix compile bugs

* fix ci bugs

* fix ci bug

* fix ci bugs

* fix ci bugs

* modify code according comment

* rm conv_fusion_op
上级 766a4ca9
......@@ -28,11 +28,7 @@ endfunction()
function(find_phi_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
set(multiValueArgs "")
file(READ ${FILENAME} CONTENT)
string(
REGEX
MATCH
......@@ -402,6 +398,7 @@ function(op_library TARGET)
set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cc_src} "REGISTER_OPERATOR" op_name)
if(NOT ${op_name} EQUAL "")
......@@ -453,6 +450,7 @@ function(op_library TARGET)
set(op_name "")
# Add PHI Kernel Registry Message
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "")
......
......@@ -827,7 +827,9 @@ bool BuildOpFuncList(const platform::Place& place,
}
// step 5. run kernel
if (run_phi_kernel) {
if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
......@@ -838,6 +840,12 @@ bool BuildOpFuncList(const platform::Place& place,
op_with_kernel->PhiKernelSignature(),
&phi_kernel_context);
}
} else if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
(*op_func_node.phi_kernel_)(&execution_context);
} else {
// the place of exec_ctx maybe has changed.
if (!skip_run) {
......
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle {
namespace framework {
......@@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx<PlaceType,
USE_OP_KERNEL(op_type)
// clang-format on
template <typename StructureKernel>
struct StructKernelImpl {
static void Compute(phi::KernelContext* ctx) {
auto exe_ctx = static_cast<paddle::framework::ExecutionContext*>(ctx);
StructureKernel().Compute(*exe_ctx);
}
};
#define PHI_STRUCTURE_KERNEL(...) \
::paddle::framework::StructKernelImpl<__VA_ARGS__>::Compute
#define PHI_STRUCTURE_VARIADIC_KERNEL(...) nullptr
#define STRUCTURE_ARG_PARSE_FUNCTOR(...) nullptr
#define STRUCTURE_KERNEL_INSTANTIATION( \
meta_kernel_structure, cpp_dtype, context) \
template class meta_kernel_structure<cpp_dtype, context>;
#define PD_REGISTER_STRUCT_KERNEL( \
kernel_name, backend, layout, meta_kernel_structure, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::INNER, \
kernel_name, \
backend, \
::phi::backend##Context, \
layout, \
meta_kernel_structure, \
STRUCTURE_KERNEL_INSTANTIATION, \
STRUCTURE_ARG_PARSE_FUNCTOR, \
PHI_STRUCTURE_KERNEL, \
PHI_STRUCTURE_VARIADIC_KERNEL, \
__VA_ARGS__)
} // namespace framework
} // namespace paddle
......@@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) {
kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
VLOG(6) << *kernel_signature_.get();
if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
}
VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_);
phi_kernel_name = kernel_signature_->name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
......@@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
} else {
phi_kernel_name = kernel_signature_->name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
......@@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::TracerEventType::OperatorInner,
1,
platform::EventRole::kInnerOp);
if (run_phi_kernel_) {
if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context;
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) {
......@@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context);
(*phi_kernel_)(&phi_kernel_context);
}
} else if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*this, exec_scope, *dev_ctx, *runtime_ctx);
(*phi_kernel_)(&execution_context);
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
......@@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const {
kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
}
VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto phi_kernel_name = kernel_signature_->name;
auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_kernel_key)));
......@@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData(
}
};
if (run_phi_kernel_) {
if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
const auto& input_names = kernel_signature_->input_names;
const auto& input_defs = phi_kernel_->args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(),
......
......@@ -41,6 +41,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"
......@@ -290,7 +291,7 @@ class OperatorBase {
const platform::Place& place) const = 0;
};
class ExecutionContext {
class ExecutionContext : public phi::KernelContext {
public:
ExecutionContext(const OperatorBase& op,
const Scope& scope,
......
......@@ -273,17 +273,23 @@ PreparedOp PrepareImpl(
kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
default_kernel_signature =
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
if (phi::KernelFactory::Instance().HasStructuredKernel(op.Type())) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
kernel_signature = phi::KernelSignature(op.Type().c_str());
} else {
default_kernel_signature =
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
}
}
}
if (has_phi_kernel) {
VLOG(6) << kernel_signature;
phi_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
......@@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl(
const phi::KernelSignature* default_kernel_signature,
const phi::KernelSignature& kernel_signature,
const phi::Kernel& phi_kernel,
const framework::RuntimeContext& ctx,
platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
......@@ -678,19 +685,25 @@ static void PreparedOpRunPtImpl(
1,
platform::EventRole::kInnerOp);
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext phi_kernel_context;
BuildDygraphPhiKernelContext<VarType>(kernel_signature,
phi_kernel,
ins,
outs,
attrs,
default_attrs,
dev_ctx,
&phi_kernel_context);
if (phi_kernel.GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
phi::KernelContext phi_kernel_context;
BuildDygraphPhiKernelContext<VarType>(kernel_signature,
phi_kernel,
ins,
outs,
attrs,
default_attrs,
dev_ctx,
&phi_kernel_context);
phi_kernel(&phi_kernel_context);
phi_kernel(&phi_kernel_context);
} else {
DygraphExecutionContext<VarType> exe_ctx(
op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs);
phi_kernel(&exe_ctx);
}
}
if (FLAGS_check_nan_inf) {
......@@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
ctx_,
dev_ctx_,
ins,
outs,
......@@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
ctx_,
dev_ctx_,
ins,
outs,
......@@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
default_kernel_signature_,
kernel_signature_,
phi_kernel_,
ctx_,
dev_ctx_,
ins,
outs,
......
......@@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
"This op type:`%s` is not a OperatorWithKernel, only "
"OperatorWithKernel can get KernelSignature",
type));
return phi::KernelSignature(
std::move(opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
if (phi::KernelFactory::Instance().HasStructuredKernel(type)) {
return phi::KernelSignature(op->Type().c_str());
} else {
return phi::KernelSignature(std::move(
opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
}
}
} // namespace imperative
......
......@@ -34,7 +34,7 @@ limitations under the License. */
phi::RegType::INNER, \
#kernel_name, \
dev_type, \
DATALAYOUT(layout), \
DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \
......
......@@ -240,12 +240,15 @@ REGISTER_OPERATOR(rank_loss,
ops::RankLossGradMaker<paddle::framework::OpDesc>,
ops::RankLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp);
REGISTER_OP_CPU_KERNEL(rank_loss, ops::RankLossKernel<phi::CPUContext, float>);
REGISTER_OP_CPU_KERNEL(rank_loss_grad,
ops::RankLossGradKernel<phi::CPUContext, float>);
REGISTER_OP_CUDA_KERNEL(
rank_loss, paddle::operators::RankLossKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL(
rank_loss_grad,
paddle::operators::RankLossGradKernel<phi::GPUContext, float>);
PD_REGISTER_STRUCT_KERNEL(
rank_loss, CPU, ALL_LAYOUT, ops::RankLossKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
rank_loss_grad, CPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_STRUCT_KERNEL(
rank_loss, GPU, ALL_LAYOUT, ops::RankLossKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
rank_loss_grad, GPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {}
#endif
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class RankLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
......@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
template <typename T, typename DeviceContext>
class RankLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
......
......@@ -56,6 +56,8 @@ struct KernelSignature {
attr_names(attrs),
output_names(outputs) {}
explicit KernelSignature(const char* kernel_name) : name(kernel_name) {}
// TODO(chenweihang): add assign constructor to solve windows compile
// problem, remove it later
KernelSignature(const KernelSignature& other)
......
......@@ -62,6 +62,21 @@ bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
return false;
}
bool KernelFactory::HasStructuredKernel(const std::string& op_type) const {
auto phi_kernel_name = phi::OpUtilsMap::Instance().GetBaseKernelName(op_type);
auto kernel_iter = kernels_.find(phi_kernel_name);
if (deprecated_op_names.find(op_type) == deprecated_op_names.end() &&
kernel_iter != kernels_.end()) {
return std::any_of(kernel_iter->second.begin(),
kernel_iter->second.end(),
[](phi::KernelKeyMap::const_reference kernel_pair) {
return kernel_pair.second.GetKernelRegisteredType() ==
KernelRegisteredType::STRUCTURE;
});
}
return false;
}
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
......
......@@ -238,13 +238,21 @@ class KernelArgsDef {
{}};
};
enum class KernelRegisteredType { FUNCTION, STRUCTURE };
class Kernel {
public:
// for map element construct
Kernel() = default;
explicit Kernel(KernelFn fn, void* variadic_fn)
: fn_(fn), variadic_fn_(variadic_fn) {}
: fn_(fn), variadic_fn_(variadic_fn) {
if (variadic_fn == nullptr) {
kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
} else {
kernel_registered_type_ = KernelRegisteredType::FUNCTION;
}
}
void operator()(KernelContext* ctx) const { fn_(ctx); }
......@@ -272,10 +280,15 @@ class Kernel {
bool IsValid() const { return fn_ != nullptr; }
KernelRegisteredType GetKernelRegisteredType() const {
return kernel_registered_type_;
}
private:
KernelFn fn_{nullptr};
void* variadic_fn_ = nullptr;
KernelArgsDef args_def_;
KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION;
};
using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
......@@ -304,6 +317,8 @@ class KernelFactory {
bool HasCompatiblePhiKernel(const std::string& op_type) const;
bool HasStructuredKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const;
......
此差异已折叠。
......@@ -35,4 +35,6 @@ KernelSignature SaveCombineOpArgumentMapping(
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(save_combine, save_combine_tensor);
PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册