未验证 提交 ef1c8759 编写于 作者: H HongyuJia 提交者: GitHub

[Kernel Selection] Remove hard code of PADDLE_WITH_MKLDNN (#46606)

* remove PADDLE_WITH_MKLDNN, test white_list=abs

* fix unique_ptr

* fix op.Type()

* remove TODO in kernel_dispatch.h

* remove IndicateVarDataType function, update white_list

* remove mkldnn hard code

* add comments

* fix ==

* update mkldnn_op_list

* delete hard code of OPs

* update mkldnn_op_list

* update mkldnn_op_list, remove interp

* add error check for ExecutionContext

* update mkldnn_op_list, remove transpose2_grad

* remove interpolate mkldnn

* remove fill_constant mkldnn

* opt HasAttr in DygraphExecutionContext

* deprecated commit, test mkldnn_white_list

* deprecated commit, test mkldnn_white_list

* deprecated commit, test mkldnn_black_list

* update mkldnn_op_list, add assert error op

* solve cudnn related op

* fix error

* add mkldnn fallback in phi_utils.cc

* remove mkldnn fallback in phi_utils.cc

* opt code implementation

* polish Copyright License
上级 f246ebba
......@@ -535,7 +535,8 @@ void BuildOpFuncList(const platform::Place& place,
if (op_with_kernel->PhiKernel()->IsValid()) {
run_phi_kernel = true;
} else {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key,
exec_ctx)) {
auto phi_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, phi_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel(
......
......@@ -50,6 +50,7 @@ class DenseTensor;
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif
#ifdef PADDLE_WITH_MLU
......@@ -1352,7 +1353,7 @@ bool OperatorWithKernel::SupportsMKLDNN(
}
bool OperatorWithKernel::SupportsKernelType(
const OpKernelType& kernel_type) const {
const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const {
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) return false;
......@@ -1386,16 +1387,38 @@ bool OperatorWithKernel::SupportsKernelType(
}
#endif
// NOTE(jiahongyu): If MKLDNN can be used, the function SupportsKernelType needs
// to check whether current op supports MKLDNN kernel. There are three
// statements in if condition: The first statement checks whether library_type_
// are changed by other high priority backends; the second checks whether this
// op has specific implementation; the third checks whether mkldnn kernel can be
// used.
#ifdef PADDLE_WITH_MKLDNN
if (kernel_type.library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type.data_type_)) {
auto tmp_kernel_type = kernel_type;
tmp_kernel_type.library_type_ = framework::LibraryType::kMKLDNN;
tmp_kernel_type.data_layout_ = framework::DataLayout::kMKLDNN;
return kernels.find(tmp_kernel_type) != kernels.end();
}
#endif
return kernel_iter != kernels.end();
}
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
// NOTE(jiahongyu): Only mkldnn kernels need to check "use_mkldnn" attribute,
// hence we first call function SupportsMKLDNN. If we check "use_mkldnn"
// attribute first, it will cause error because some codes add "use_mkldnn"
// attribute to non-mkldnn ops.
if (!this->SupportsMKLDNN(data_type)) {
return false;
}
const std::string use_mkldnn_attr = "use_mkldnn";
bool use_mkldnn_ctx = ctx.HasAttr(use_mkldnn_attr) &&
ctx.Attr<bool>(use_mkldnn_attr) &&
platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr<bool>(use_mkldnn_attr) &&
platform::is_cpu_place(ctx.GetPlace());
}
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
......@@ -1544,6 +1567,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
} else {
phi_kernel_name = kernel_signature_->name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in if condition: The first statement checks
// whether library_type_ are changed by other high priority backends; the second
// checks whether this op has specific implementation; the third checks whether
// mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (kernel_type_->library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type_->data_type_)) {
kernel_type_->library_type_ = framework::LibraryType::kMKLDNN;
kernel_type_->data_layout_ = framework::DataLayout::kMKLDNN;
}
#endif
// NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
// I can't understand it, it's really confusing.
// But we still need to keep this to avoid errors.
......@@ -1771,6 +1811,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const {
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
// NOTE(jiahongyu): PADDLE_WITH_MKLDNN codes are moved outside function
// GetExpectedKernelType, so that if MKLDNN can be used, the library_type_ and
// data_layout_ of expected_kernel_key need to be adjusted. There are three
// statements in if condition: The first statement checks whether library_type_
// are changed by other high priority backends; the second checks whether this
// op has specific implementation; the third checks whether mkldnn kernel can be
// used.
#ifdef PADDLE_WITH_MKLDNN
if (expected_kernel_key.library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
}
#endif
if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace();
......
......@@ -323,10 +323,16 @@ class ExecutionContext {
virtual const Attribute& GetAttr(const std::string& name) const {
auto iter = op_.Attrs().find(name);
if (iter == op_.Attrs().end()) {
return op_.RuntimeAttrs().at(name);
} else {
return iter->second;
iter = op_.RuntimeAttrs().find(name);
PADDLE_ENFORCE_NE(
iter,
op_.RuntimeAttrs().end(),
platform::errors::NotFound("(%s) is not found in AttributeMap and "
"RuntimeAttributeMap of (%s) operator.",
name,
op_.Type()));
}
return iter->second;
}
virtual bool HasInput(const std::string& name) const;
......@@ -621,7 +627,8 @@ class OperatorWithKernel : public OperatorBase {
bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type,
const ExecutionContext& exe_ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
......
......@@ -102,7 +102,10 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
if (attrs_.find(name) == attrs_.end()) {
return default_attrs_.find(name) != default_attrs_.end();
}
return true;
}
const framework::AttributeMap& Attrs() const override { return attrs_; }
......
......@@ -25,6 +25,9 @@
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif
#include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
......@@ -185,13 +188,29 @@ PreparedOp PrepareImpl(
phi::KernelSignature kernel_signature;
phi::KernelKey phi_kernel_key;
std::string phi_kernel_name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in if condition: The first statement checks
// whether library_type_ are changed by other high priority backends; the second
// checks whether this op has specific implementation; the third checks whether
// mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (expected_kernel_key.library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(op.Type()) &&
op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
}
#endif
#if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type());
#endif
bool has_phi_kernel = false;
......
......@@ -21,9 +21,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -36,15 +33,6 @@ class AbsOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -86,15 +74,6 @@ class AbsGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -82,27 +82,18 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = oper.IndicateVarDataType(ctx, name);
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
// auto it1 = oper.Attrs().find("use_cudnn");
// if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
// library = framework::LibraryType::kCUDNN;
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
oper.CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
// auto it1 = oper.Attrs().find("use_cudnn");
// if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
// library = framework::LibraryType::kCUDNN;
// }
// #endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
class ActivationOp : public framework::OperatorWithKernel {
......
......@@ -197,16 +197,6 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
platform::errors::InvalidArgument(
"Variance input should be of float type"));
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -396,18 +386,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
platform::errors::InvalidArgument("gradient variable of Y is empty"));
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -30,15 +30,6 @@ class ClipOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -98,15 +89,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -24,10 +24,6 @@ limitations under the License. */
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
......@@ -53,14 +49,6 @@ class ConcatOp : public framework::OperatorWithKernel {
PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Concat OP are Empty!"));
}
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -127,19 +115,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
// extra checking if attr "use_mkldnn" exist is needed because
// test_reverse_op is calling concat_grad kernel without setting
// "use_mkldnn" to any value
if (ctx.HasAttr("use_mkldnn") &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -49,15 +49,6 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
}
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -18,9 +18,6 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -199,15 +196,6 @@ class DataNormOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"bias input should be of float type"));
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -508,18 +496,7 @@ class DataNormGradOp : public framework::OperatorWithKernel {
"Y@GRAD can not be found for computation"));
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
......@@ -45,15 +45,6 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -32,15 +32,6 @@ class ElementwiseMulOp : public ElementwiseOp {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -156,15 +156,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -317,15 +308,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -371,15 +353,6 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -432,15 +405,6 @@ class ElementwiseOpDoubleGradWithoutDXDY
input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
}
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -493,15 +457,6 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
framework::proto::VarType::Type input_data_type;
input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -37,15 +37,6 @@ class ExpandV2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -163,15 +154,6 @@ class ExpandV2GradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -104,15 +104,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
}
}
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return kt;
}
};
......
......@@ -153,14 +153,6 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -176,14 +176,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -60,16 +60,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.device_context(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.device_context());
}
......
......@@ -36,14 +36,6 @@ class GeluOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......@@ -76,14 +68,6 @@ class GeluGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
......@@ -340,20 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// TODO(danqing): support other interp_method
// (https://github.com/PaddlePaddle/Paddle/pull/30016/files)
// NOTE(jiahy0825): currently only support interp_method = nearest or
// interp_method = bilinear
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -444,20 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// TODO(danqing): support other interp_method
// (https://github.com/PaddlePaddle/Paddle/pull/30016/files)
// NOTE(jiahy0825): currently only support interp_method = nearest or
// interp_method = bilinear
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -33,15 +33,6 @@ class LogSoftmaxOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -225,16 +225,6 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......@@ -359,16 +349,6 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -697,15 +697,6 @@ class MatMulOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -889,15 +880,6 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -135,15 +135,6 @@ class MatMulV2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -210,15 +201,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -36,15 +36,6 @@ class PReluOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -127,15 +118,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -31,15 +31,6 @@ class ScaleOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -30,15 +30,6 @@ class ShapeOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -39,15 +39,6 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -53,7 +53,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
......@@ -62,15 +61,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
framework::LibraryType::kCUDNN);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......@@ -158,15 +148,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
framework::LibraryType::kCUDNN);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......
......@@ -35,15 +35,6 @@ class StackOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -98,14 +98,6 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
auto &data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
......@@ -202,14 +194,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
......@@ -360,14 +344,6 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
framework::proto::VarType::Type data_type =
OperatorWithKernel::IndicateVarDataType(ctx,
framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_MKLDNN
#include <unordered_set>
namespace paddle {
namespace platform {
// NOTE(jiahongyu): Below ops have specific PADDLE_WITH_MKLDNN hard codes within
// the function GetExpectedKernelType, so we need to handle them through
// mkldnn_white_list and solve them one-by-one in the future.
// TODO(jiahongyu): Delete mkldnn_white_list and fully support
// PADDLE_WITH_MKLDNN of GetExpectedKernelType.
static const std::unordered_set<std::string> mkldnn_white_list = {
"cast",
"transfer_dtype",
"layer_norm",
"pad2d",
"pad3d",
"pool2d",
"pool2d_grad",
"slice",
"slice_grad",
"split",
"sum",
"sgd",
// NOTE(jiahongyu): squeeze MKLDNN kernel are disabled
// (https://github.com/PaddlePaddle/Paddle/pull/35781). If these MKLDNN
// kernels and codes are deleted in the future, attributes `use_mkldnn`
// should be removed from function declaration
"squeeze",
"squeeze_grad",
"squeeze2",
"squeeze2_grad",
// NOTE(jiahongyu): reshape and flatten have attribute use_mkldnn and they
// are registered in paddle, but they didn't change the ExpectedKernelType
// of tensor. Actually, mkldnn kernel of squeeze, reshape, and flatten
// should never be called.
"reshape",
"reshape_grad",
"reshape2",
"reshape2_grad",
"flatten",
"flatten_grad",
"flatten2",
"flatten2_grad",
// NOTE(jiahongyu): After fixing GetExpectedKernelType in ReduceOp, reduce
// series hard code can be deleted together.
"reduce_max",
"reduce_mean",
"reduce_mean_grad",
"reduce_min",
"reduce_sum",
"reduce_sum_grad",
// NOTE(jiahongyu): Below ops register kernel with customized_type_value, we
// need to analysis and solve them one-by-one.
"conv2d",
"conv2d_grad",
"depthwise_conv2d",
"depthwise_conv2d_grad",
"conv3d",
"conv3d_grad",
"prior_box",
"fc",
"mul",
"mul_grad",
"transpose2"};
inline bool in_mkldnn_white_list(const std::string& op_name) {
return mkldnn_white_list.find(op_name) != mkldnn_white_list.end();
}
} // namespace platform
} // namespace paddle
#endif
......@@ -99,7 +99,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
inline void AssignKernelKeySet(const phi::TensorBase& tensor) {
key_set.backend_set =
key_set.backend_set | detail::GetTensorBackendSet(tensor);
// TODO(chenweihang): select multi layout and dtype
phi::DataLayout tensor_layout = tensor.layout();
key_set.layout =
tensor_layout > key_set.layout ? tensor_layout : key_set.layout;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册