Unverified commit 6f5e7826 authored by HongyuJia, committed by GitHub

[Kernel Selection] Remove hard code of PADDLE_WITH_MKLDNN (Part2 add dnn_fallback flag) (#47200)

* use dnn_fallback flag to delete mkldnn hardcode

* polish code style

* fix protected error

* fix const error

* fix reduce_op fallback

* fix pool_op fallback

* add Set function of dnn_fallback_
Parent ea8e87fa
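
The core of the change is a per-operator fallback flag on OperatorWithKernel that replaces the hard-coded MKLDNN whitelist checks: an op's GetExpectedKernelType only reports when the oneDNN kernel cannot be used, and the framework consults the flag before switching the kernel key to kMKLDNN. Below is a minimal, standalone sketch of that mechanism; the accessor names mirror the diff, but OperatorStub and UseOneDNNKernel are simplified stand-ins, not Paddle code.

// Self-contained sketch of the dnn_fallback flag pattern (illustrative only).
#include <iostream>

class OperatorStub {
 public:
  // SetDnnFallback is const and writes a mutable member, so an op can flip
  // the flag from inside its const GetExpectedKernelType.
  bool DnnFallback() const { return dnn_fallback_; }
  void SetDnnFallback(bool dnn_fallback) const { dnn_fallback_ = dnn_fallback; }

 private:
  mutable bool dnn_fallback_ = false;
};

// Framework side: choose the oneDNN kernel only if the op has not requested
// a fallback and the kernel is otherwise usable (whitelist, dtype, ...).
bool UseOneDNNKernel(const OperatorStub& op, bool can_mkldnn_be_used) {
  return !op.DnnFallback() && can_mkldnn_be_used;
}

int main() {
  OperatorStub op;
  op.SetDnnFallback(true);  // the op decides oneDNN cannot handle this case
  std::cout << UseOneDNNKernel(op, /*can_mkldnn_be_used=*/true) << std::endl;  // prints 0
  return 0;
}
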
......@@ -1388,12 +1388,13 @@ bool OperatorWithKernel::SupportsKernelType(
#endif
// NOTE(jiahongyu): If MKLDNN can be used, the function SupportsKernelType needs
// to check whether current op supports MKLDNN kernel. There are two statements
// in if condition:
// 1. Whether this op has specific implementation;
// 2. Whether mkldnn kernel can be used.
// to check whether current op supports MKLDNN kernel. There are three
// statements in if condition:
// 1. Whether mkldnn kernel falls back to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (!paddle::platform::in_mkldnn_white_list(type_) &&
if (!this->DnnFallback() && !paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type.data_type_)) {
auto tmp_kernel_type = kernel_type;
tmp_kernel_type.library_type_ = framework::LibraryType::kMKLDNN;
......@@ -1569,11 +1570,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// NOTE(jiahongyu): The registered MKLDNN kernels have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are two statements in if condition:
// 1. Whether this op has specific implementation;
// 2. Whether mkldnn kernel can be used.
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel falls back to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (!paddle::platform::in_mkldnn_white_list(type_) &&
if (!this->DnnFallback() &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type_->data_type_)) {
kernel_type_->library_type_ = framework::LibraryType::kMKLDNN;
kernel_type_->data_layout_ = framework::DataLayout::kMKLDNN;
......@@ -1810,12 +1813,13 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
// NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function
// GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and
// data_layout_ of expected_kernel_key need to be adjusted. There are two
// data_layout_ of expected_kernel_key need to be adjusted. There are three
// statements in if condition:
// 1. Whether this op has specific implementation;
// 2. Whether mkldnn kernel can be used.
// 1. Whether mkldnn kernel falls back to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (!paddle::platform::in_mkldnn_white_list(type_) &&
if (!this->DnnFallback() && !paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
......
......@@ -704,6 +704,10 @@ class OperatorWithKernel : public OperatorBase {
kernel_type_.reset(kernel_type);
}
bool DnnFallback() const { return dnn_fallback_; }
void SetDnnFallback(bool dnn_fallback) const { dnn_fallback_ = dnn_fallback; }
private:
void RunImpl(const Scope& scope, const platform::Place& place) const final;
void RunImpl(const Scope& scope,
......@@ -756,6 +760,10 @@ class OperatorWithKernel : public OperatorBase {
mutable bool all_kernels_must_compute_runtime_shape_ = false;
mutable std::mutex cache_update_mutex_;
mutable bool enable_cache_transfer_scope_ = false;
// NOTE(jiahongyu): Whether to fall back to the plain kernel after calling
// GetExpectedKernelType; use this bool flag to remove the mkldnn and cudnn
// hard code
mutable bool dnn_fallback_ = false;
// NOTE(chenweihang): Similar op members are used to adapt to
// new phi kernel, if there is a better design in the future,
// we may polish the implementation here
......
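
With the flag and its accessors in place, an op no longer builds an MKLDNN-specific OpKernelType itself; inside GetExpectedKernelType it only marks the cases oneDNN cannot handle and leaves the library choice to the framework. A hedged sketch of that per-op pattern, distilled from the layer_norm and sgd hunks below (ExampleOp and OneDNNSupportsThisCase are placeholders; the snippet is a fragment, not standalone-compilable):

framework::OpKernelType ExampleOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  // Only the "can oneDNN handle this case?" check stays inside the ifdef;
  // the kernel key below carries no MKLDNN-specific library or layout.
  if (!OneDNNSupportsThisCase(ctx)) {  // placeholder predicate
    this->SetDnnFallback(true);
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace());
}
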
......@@ -192,11 +192,12 @@ PreparedOp PrepareImpl(
// NOTE(jiahongyu): The registered MKLDNN kernels have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are two statements in if condition:
// 1. Whether this op has specific implementation;
// 2. Whether mkldnn kernel can be used.
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel falls back to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (!paddle::platform::in_mkldnn_white_list(op.Type()) &&
if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) &&
op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
......
......@@ -91,31 +91,19 @@ class CastOp : public framework::OperatorWithKernel {
ctx.device_context());
}
#ifdef PADDLE_WITH_MKLDNN
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
int in_dtype = ctx.Attr<int>("in_dtype");
int out_dtype = ctx.Attr<int>("out_dtype");
auto MKLDNNSupportsCast = [&]() -> bool {
int dtype_fp32 = static_cast<int>(framework::proto::VarType::FP32);
int dtype_bf16 = static_cast<int>(framework::proto::VarType::BF16);
int dtype_fp32 = static_cast<int>(framework::proto::VarType::FP32);
int dtype_bf16 = static_cast<int>(framework::proto::VarType::BF16);
if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) ||
(out_dtype != dtype_fp32 && out_dtype != dtype_bf16))
return false;
return true;
};
if (this->CanMKLDNNBeUsed(
ctx, framework::TransToProtoVarType(tensor->dtype())) &&
MKLDNNSupportsCast()) {
return framework::OpKernelType(
framework::TransToProtoVarType(tensor->dtype()),
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
if ((in_dtype != dtype_fp32 && in_dtype != dtype_bf16) ||
(out_dtype != dtype_fp32 && out_dtype != dtype_bf16)) {
this->SetDnnFallback(true);
}
#endif
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MLU
auto src_type = static_cast<VT::Type>(ctx.Attr<int>("in_dtype"));
auto dst_type = static_cast<VT::Type>(ctx.Attr<int>("out_dtype"));
......
......@@ -111,17 +111,13 @@ class LayerNormOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
begin_norm_axis ==
ctx.Input<phi::DenseTensor>("X")->dims().size() - 1) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
if (begin_norm_axis !=
ctx.Input<phi::DenseTensor>("X")->dims().size() - 1) {
this->SetDnnFallback(true);
}
#endif
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -35,25 +35,20 @@ class SGDOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
#ifdef PADDLE_WITH_MKLDNN
using dnnl::memory;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
const auto *param_var = ctx.InputVar("Param");
const auto *grad_var = ctx.InputVar("Grad");
// supported cases
bool dense_param_sparse_grad = param_var->IsType<phi::DenseTensor>() &&
grad_var->IsType<phi::SelectedRows>();
bool dense_param_and_grad = param_var->IsType<phi::DenseTensor>() &&
grad_var->IsType<phi::DenseTensor>();
if (dense_param_sparse_grad || dense_param_and_grad)
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
const auto *param_var = ctx.InputVar("Param");
const auto *grad_var = ctx.InputVar("Grad");
// supported cases
bool dense_param_sparse_grad = param_var->IsType<phi::DenseTensor>() &&
grad_var->IsType<phi::SelectedRows>();
bool dense_param_and_grad = param_var->IsType<phi::DenseTensor>() &&
grad_var->IsType<phi::DenseTensor>();
if (!(dense_param_sparse_grad || dense_param_and_grad)) {
this->SetDnnFallback(true);
}
#endif
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(data_type, ctx.device_context());
}
......
......@@ -33,6 +33,9 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
if (ctx.Attr<bool>("adaptive") == false) return true;
// (jczaja): oneDNN supports only pool windows that do not change in size
auto src_tz = phi::vectorize(ctx.Input<phi::DenseTensor>("X")->dims());
if (!ctx.HasAttr("ksize")) {
return false;
}
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
// Fast but not exhaustive check
return ((src_tz[src_tz.size() - 1] % ksize[1] == 0) &&
......@@ -50,13 +53,10 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
library_ = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type) && CanMKLDNNSupportPool(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = phi::DataLayout::kMKLDNN;
}
#endif
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
}
......@@ -95,14 +95,10 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
library_ = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type) &&
CanMKLDNNSupportPool(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = phi::DataLayout::kMKLDNN;
}
#endif
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(
input_data_type, ctx.GetPlace(), layout_, library_);
......
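
A side effect of moving these checks out of the MKLDNN-only branches shows up in the pool hunks above and the reduce hunk below: the support helpers now guard attribute reads with HasAttr and report the case as unsupported when an attribute is missing, instead of reading it unconditionally. A sketch of that guard pattern (the helper name and the final condition are placeholders; "ksize" is the attribute from the pool hunk):

// Guarded attribute access inside a oneDNN-support helper: a missing
// attribute means "cannot decide", so treat the case as unsupported.
bool CanOneDNNHandle(const framework::ExecutionContext& ctx) {
  if (!ctx.HasAttr("ksize")) {
    return false;
  }
  std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
  return !ksize.empty();  // placeholder for the real support condition
}
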
......@@ -559,6 +559,10 @@ class ReduceOp : public framework::OperatorWithKernel {
experimental::DataType::BFLOAT16)
return true;
if (!ctx.HasAttr("dim") || !ctx.HasAttr("reduce_all")) {
return false;
}
auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
const bool reduce_all = ctx.Attr<bool>("reduce_all");
int ndims = ctx.Input<phi::DenseTensor>("X")->dims().size();
......@@ -586,18 +590,12 @@ class ReduceOp : public framework::OperatorWithKernel {
// choose cudnn kernel if the runtime supported.
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (ctx.Input<phi::DenseTensor>("X")->dims().size() > 5)
return framework::OpKernelType(input_data_type, ctx.GetPlace());
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
HasOptimizedOneDNNKernel(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
if (ctx.Input<phi::DenseTensor>("X")->dims().size() > 5 ||
!HasOptimizedOneDNNKernel(ctx)) {
this->SetDnnFallback(true);
}
#endif
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(
......@@ -674,22 +672,13 @@ class ReduceGradOp : public framework::OperatorWithKernel {
? static_cast<framework::proto::VarType::Type>(out_dtype)
: OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
auto CanMKLDNNReduceGradBeUsed = [&]() {
auto dx_dims = ctx.Input<phi::DenseTensor>("X")->dims();
if (dx_dims.size() > 5) return false; // max 5D tensor is supported
return true;
};
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
CanMKLDNNReduceGradBeUsed()) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
// max 5D tensor is supported
if (ctx.Input<phi::DenseTensor>("X")->dims().size() > 5) {
dnn_fallback_ = true;
}
#endif
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -76,22 +76,21 @@ class SumOp : public framework::OperatorWithKernel {
"Sum operator should have at least one tensor"));
auto data_type = static_cast<framework::proto::VarType::Type>(dtype);
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(data_type == framework::proto::VarType::FP32 ||
data_type == framework::proto::VarType::BF16) &&
ctx.OutputVar("Out")->IsType<phi::DenseTensor>()) {
if (std::all_of(
x_vars.begin(), x_vars.end(), [](const framework::Variable* v) {
return v->IsType<phi::DenseTensor>();
})) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
if (!((data_type == framework::proto::VarType::FP32 ||
data_type == framework::proto::VarType::BF16) &&
ctx.OutputVar("Out")->IsType<phi::DenseTensor>())) {
this->SetDnnFallback(true);
} else if (!std::all_of(x_vars.begin(),
x_vars.end(),
[](const framework::Variable* v) {
return v->IsType<phi::DenseTensor>();
})) {
this->SetDnnFallback(true);
}
#endif
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(data_type, ctx.GetPlace());
} else if (x_vars[0]->IsType<phi::SelectedRows>()) {
for (auto& var : x_vars) {
......
......@@ -27,18 +27,14 @@ namespace platform {
// TODO(jiahongyu): Delete mkldnn_white_list and fully support
// PADDLE_WITH_MKLDNN of GetExpectedKernelType.
static const std::unordered_set<std::string> mkldnn_white_list = {
"cast",
"transfer_dtype",
"layer_norm",
// NOTE(jiahongyu): Below ops use mem_desc function, which is encoded by
// PADDLE_WITH_MKLDNN in DenseTensor. The hardcodes within
// GetExpectedKernelType of these ops cannot be deleted now.
"pad2d",
"pad3d",
"pool2d",
"pool2d_grad",
"slice",
"slice_grad",
"split",
"sum",
"sgd",
// NOTE(jiahongyu): squeeze MKLDNN kernel are disabled
// (https://github.com/PaddlePaddle/Paddle/pull/35781). If these MKLDNN
// kernels and codes are deleted in the future, attributes `use_mkldnn`
......@@ -59,14 +55,6 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
"flatten_grad",
"flatten2",
"flatten2_grad",
// NOTE(jiahongyu): After fixing GetExpectedKernelType in ReduceOp, reduce
// series hard code can be deleted together.
"reduce_max",
"reduce_mean",
"reduce_mean_grad",
"reduce_min",
"reduce_sum",
"reduce_sum_grad",
// NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
// need to analyze and solve them one-by-one.
"prior_box"};
......