未验证 提交 ef1c8759 编写于 作者: H HongyuJia 提交者: GitHub

[Kernel Selection] Remove hard code of PADDLE_WITH_MKLDNN (#46606)

* remove PADDLE_WITH_MKLDNN, test white_list=abs

* fix unique_ptr

* fix op.Type()

* remove TODO in kernel_dispatch.h

* remove IndicateVarDataType function, update white_list

* remove mkldnn hard code

* add comments

* fix ==

* update mkldnn_op_list

* delete hard code of OPs

* update mkldnn_op_list

* update mkldnn_op_list, remove interp

* add error check for ExecutionContext

* update mkldnn_op_list, remove transpose2_grad

* remove interpolate mkldnn

* remove fill_constant mkldnn

* opt HasAttr in DygraphExecutionContext

* deprecated commit, test mkldnn_white_list

* deprecated commit, test mkldnn_white_list

* deprecated commit, test mkldnn_black_list

* update mkldnn_op_list, add assert error op

* solve cudnn related op

* fix error

* add mkldnn fallback in phi_utils.cc

* remove mkldnn fallback in phi_utils.cc

* opt code implementation

* polish Copyright License
上级 f246ebba
...@@ -535,7 +535,8 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -535,7 +535,8 @@ void BuildOpFuncList(const platform::Place& place,
if (op_with_kernel->PhiKernel()->IsValid()) { if (op_with_kernel->PhiKernel()->IsValid()) {
run_phi_kernel = true; run_phi_kernel = true;
} else { } else {
if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) { if (!op_with_kernel->SupportsKernelType(expected_kernel_key,
exec_ctx)) {
auto phi_cpu_kernel_key = FallBackToCpu( auto phi_cpu_kernel_key = FallBackToCpu(
expected_kernel_key, phi_kernel_key, *op_with_kernel); expected_kernel_key, phi_kernel_key, *op_with_kernel);
op_with_kernel->ResetPhiKernel( op_with_kernel->ResetPhiKernel(
......
...@@ -50,6 +50,7 @@ class DenseTensor; ...@@ -50,6 +50,7 @@ class DenseTensor;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif #endif
#ifdef PADDLE_WITH_MLU #ifdef PADDLE_WITH_MLU
...@@ -1352,7 +1353,7 @@ bool OperatorWithKernel::SupportsMKLDNN( ...@@ -1352,7 +1353,7 @@ bool OperatorWithKernel::SupportsMKLDNN(
} }
bool OperatorWithKernel::SupportsKernelType( bool OperatorWithKernel::SupportsKernelType(
const OpKernelType& kernel_type) const { const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const {
auto& all_op_kernels = AllOpKernels(); auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_); auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) return false; if (kernels_iter == all_op_kernels.end()) return false;
...@@ -1386,16 +1387,38 @@ bool OperatorWithKernel::SupportsKernelType( ...@@ -1386,16 +1387,38 @@ bool OperatorWithKernel::SupportsKernelType(
} }
#endif #endif
// NOTE(jiahongyu): If MKLDNN can be used, SupportsKernelType needs to check
// whether the current op supports an MKLDNN kernel. There are three
// statements in the if condition: the first checks whether library_type_ has
// been changed by another higher-priority backend; the second checks whether
// this op has a specific implementation; the third checks whether the MKLDNN
// kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (kernel_type.library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type.data_type_)) {
auto tmp_kernel_type = kernel_type;
tmp_kernel_type.library_type_ = framework::LibraryType::kMKLDNN;
tmp_kernel_type.data_layout_ = framework::DataLayout::kMKLDNN;
return kernels.find(tmp_kernel_type) != kernels.end();
}
#endif
return kernel_iter != kernels.end(); return kernel_iter != kernels.end();
} }
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const { proto::VarType::Type data_type) const {
// NOTE(jiahongyu): Only MKLDNN kernels need to check the "use_mkldnn"
// attribute, hence we first call SupportsMKLDNN. If we checked the
// "use_mkldnn" attribute first, it would cause errors because some code adds
// a "use_mkldnn" attribute to non-MKLDNN ops.
if (!this->SupportsMKLDNN(data_type)) {
return false;
}
const std::string use_mkldnn_attr = "use_mkldnn"; const std::string use_mkldnn_attr = "use_mkldnn";
bool use_mkldnn_ctx = ctx.HasAttr(use_mkldnn_attr) && return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr<bool>(use_mkldnn_attr) &&
ctx.Attr<bool>(use_mkldnn_attr) && platform::is_cpu_place(ctx.GetPlace());
platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
} }
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
...@@ -1544,6 +1567,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1544,6 +1567,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
} else { } else {
phi_kernel_name = kernel_signature_->name; phi_kernel_name = kernel_signature_->name;
// NOTE(jiahongyu): Registered MKLDNN kernels have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in the if condition: the first checks
// whether library_type_ has been changed by another higher-priority backend;
// the second checks whether this op has a specific implementation; the third
// checks whether the MKLDNN kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (kernel_type_->library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type_->data_type_)) {
kernel_type_->library_type_ = framework::LibraryType::kMKLDNN;
kernel_type_->data_layout_ = framework::DataLayout::kMKLDNN;
}
#endif
// NOTE(Liu-xiandong):In my ctest, this branch do not be executed, // NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
// I can't understand it, it's really confusing. // I can't understand it, it's really confusing.
// But we still need to keep this to avoid errors. // But we still need to keep this to avoid errors.
...@@ -1771,6 +1811,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1771,6 +1811,23 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto expected_kernel_key = this->GetExpectedKernelType(ctx); auto expected_kernel_key = this->GetExpectedKernelType(ctx);
// NOTE(jiahongyu): The PADDLE_WITH_MKLDNN code has been moved out of
// GetExpectedKernelType, so if MKLDNN can be used, the library_type_ and
// data_layout_ of expected_kernel_key need to be adjusted here. There are
// three statements in the if condition: the first checks whether library_type_
// has been changed by another higher-priority backend; the second checks
// whether this op has a specific implementation; the third checks whether the
// MKLDNN kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (expected_kernel_key.library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
}
#endif
if (HasAttr("op_device")) { if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") { if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
......
...@@ -323,10 +323,16 @@ class ExecutionContext { ...@@ -323,10 +323,16 @@ class ExecutionContext {
virtual const Attribute& GetAttr(const std::string& name) const { virtual const Attribute& GetAttr(const std::string& name) const {
auto iter = op_.Attrs().find(name); auto iter = op_.Attrs().find(name);
if (iter == op_.Attrs().end()) { if (iter == op_.Attrs().end()) {
return op_.RuntimeAttrs().at(name); iter = op_.RuntimeAttrs().find(name);
} else { PADDLE_ENFORCE_NE(
return iter->second; iter,
op_.RuntimeAttrs().end(),
platform::errors::NotFound("(%s) is not found in AttributeMap and "
"RuntimeAttributeMap of (%s) operator.",
name,
op_.Type()));
} }
return iter->second;
} }
virtual bool HasInput(const std::string& name) const; virtual bool HasInput(const std::string& name) const;
...@@ -621,7 +627,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -621,7 +627,8 @@ class OperatorWithKernel : public OperatorBase {
bool SupportsMKLDNN(proto::VarType::Type data_type) const; bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type) const; bool SupportsKernelType(const OpKernelType& kernel_type,
const ExecutionContext& exe_ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const; proto::VarType::Type data_type) const;
......
...@@ -102,7 +102,10 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -102,7 +102,10 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0 || default_attrs_.count(name) != 0; if (attrs_.find(name) == attrs_.end()) {
return default_attrs_.find(name) != default_attrs_.end();
}
return true;
} }
const framework::AttributeMap& Attrs() const override { return attrs_; } const framework::AttributeMap& Attrs() const override { return attrs_; }
......
...@@ -25,6 +25,9 @@ ...@@ -25,6 +25,9 @@
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h" #include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif
#include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
...@@ -185,13 +188,29 @@ PreparedOp PrepareImpl( ...@@ -185,13 +188,29 @@ PreparedOp PrepareImpl(
phi::KernelSignature kernel_signature; phi::KernelSignature kernel_signature;
phi::KernelKey phi_kernel_key; phi::KernelKey phi_kernel_key;
std::string phi_kernel_name; std::string phi_kernel_name;
// NOTE(jiahongyu): Registered MKLDNN kernels have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in the if condition: the first checks
// whether library_type_ has been changed by another higher-priority backend;
// the second checks whether this op has a specific implementation; the third
// checks whether the MKLDNN kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
if (expected_kernel_key.library_type_ == framework::LibraryType::kPlain &&
!paddle::platform::in_mkldnn_white_list(op.Type()) &&
op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
}
#endif
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport = bool is_xpu_unsupport =
paddle::platform::is_xpu_place(expected_kernel_key.place_) && paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
!paddle::platform::is_xpu_support_op(op.Type(), !paddle::platform::is_xpu_support_op(op.Type(),
expected_kernel_key) || expected_kernel_key) ||
paddle::platform::is_in_xpu_black_list(op.Type()); paddle::platform::is_in_xpu_black_list(op.Type());
#endif #endif
bool has_phi_kernel = false; bool has_phi_kernel = false;
......
...@@ -21,9 +21,6 @@ ...@@ -21,9 +21,6 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,15 +33,6 @@ class AbsOp : public framework::OperatorWithKernel { ...@@ -36,15 +33,6 @@ class AbsOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -86,15 +74,6 @@ class AbsGradOp : public framework::OperatorWithKernel { ...@@ -86,15 +74,6 @@ class AbsGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -82,27 +82,18 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -82,27 +82,18 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper, const framework::OperatorWithKernel& oper,
const std::string& name) { const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = oper.IndicateVarDataType(ctx, name); auto data_type = oper.IndicateVarDataType(ctx, name);
// FIXME(liuwei1031) temporarily disable the code to unblock users // FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind // TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096 // https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future // and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA // #ifdef PADDLE_WITH_CUDA
// auto it1 = oper.Attrs().find("use_cudnn"); // auto it1 = oper.Attrs().find("use_cudnn");
// if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) { // if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
// library = framework::LibraryType::kCUDNN; // library = framework::LibraryType::kCUDNN;
// } // }
// #endif // #endif
#ifdef PADDLE_WITH_MKLDNN return framework::OpKernelType(data_type, ctx.GetPlace());
if (library == framework::LibraryType::kPlain &&
oper.CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
} }
class ActivationOp : public framework::OperatorWithKernel { class ActivationOp : public framework::OperatorWithKernel {
......
...@@ -197,16 +197,6 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( ...@@ -197,16 +197,6 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Variance input should be of float type")); "Variance input should be of float type"));
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -396,18 +386,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( ...@@ -396,18 +386,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
platform::errors::InvalidArgument("gradient variable of Y is empty")); platform::errors::InvalidArgument("gradient variable of Y is empty"));
} }
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -30,15 +30,6 @@ class ClipOp : public framework::OperatorWithKernel { ...@@ -30,15 +30,6 @@ class ClipOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -98,15 +89,6 @@ class ClipOpGrad : public framework::OperatorWithKernel { ...@@ -98,15 +89,6 @@ class ClipOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -24,10 +24,6 @@ limitations under the License. */ ...@@ -24,10 +24,6 @@ limitations under the License. */
#include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h" #include "paddle/phi/kernels/funcs/concat_funcs.h"
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
...@@ -53,14 +49,6 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -53,14 +49,6 @@ class ConcatOp : public framework::OperatorWithKernel {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"All Inputs of Concat OP are Empty!")); "All Inputs of Concat OP are Empty!"));
} }
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -127,19 +115,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel { ...@@ -127,19 +115,6 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
// extra checking if attr "use_mkldnn" exist is needed because
// test_reverse_op is calling concat_grad kernel without setting
// "use_mkldnn" to any value
if (ctx.HasAttr("use_mkldnn") &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -49,15 +49,6 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -49,15 +49,6 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
} }
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -18,9 +18,6 @@ limitations under the License. */ ...@@ -18,9 +18,6 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
...@@ -199,15 +196,6 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -199,15 +196,6 @@ class DataNormOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"bias input should be of float type")); "bias input should be of float type"));
} }
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -508,18 +496,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -508,18 +496,7 @@ class DataNormGradOp : public framework::OperatorWithKernel {
"Y@GRAD can not be found for computation")); "Y@GRAD can not be found for computation"));
} }
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
}; };
......
...@@ -45,15 +45,6 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -45,15 +45,6 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -32,15 +32,6 @@ class ElementwiseMulOp : public ElementwiseOp { ...@@ -32,15 +32,6 @@ class ElementwiseMulOp : public ElementwiseOp {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -156,15 +156,6 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -156,15 +156,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -317,15 +308,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -317,15 +308,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -371,15 +353,6 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -371,15 +353,6 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -432,15 +405,6 @@ class ElementwiseOpDoubleGradWithoutDXDY ...@@ -432,15 +405,6 @@ class ElementwiseOpDoubleGradWithoutDXDY
input_data_type = input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY"); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
} }
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -493,15 +457,6 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel { ...@@ -493,15 +457,6 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::proto::VarType::Type input_data_type; framework::proto::VarType::Type input_data_type;
input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut"); input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "D_DDOut");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -37,15 +37,6 @@ class ExpandV2Op : public framework::OperatorWithKernel { ...@@ -37,15 +37,6 @@ class ExpandV2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -163,15 +154,6 @@ class ExpandV2GradOp : public framework::OperatorWithKernel { ...@@ -163,15 +154,6 @@ class ExpandV2GradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -104,15 +104,6 @@ class FillConstantOp : public framework::OperatorWithKernel { ...@@ -104,15 +104,6 @@ class FillConstantOp : public framework::OperatorWithKernel {
} }
} }
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return kt; return kt;
} }
}; };
......
...@@ -153,14 +153,6 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -153,14 +153,6 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -176,14 +176,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -176,14 +176,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -60,16 +60,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -60,16 +60,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.device_context(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
......
...@@ -36,14 +36,6 @@ class GeluOp : public framework::OperatorWithKernel { ...@@ -36,14 +36,6 @@ class GeluOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
}; };
...@@ -76,14 +68,6 @@ class GeluGradOp : public framework::OperatorWithKernel { ...@@ -76,14 +68,6 @@ class GeluGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
}; };
......
...@@ -340,20 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -340,20 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// TODO(danqing): support other interp_method
// (https://github.com/PaddlePaddle/Paddle/pull/30016/files)
// NOTE(jiahy0825): currently only support interp_method = nearest or
// interp_method = bilinear
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -444,20 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel { ...@@ -444,20 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// TODO(danqing): support other interp_method
// (https://github.com/PaddlePaddle/Paddle/pull/30016/files)
// NOTE(jiahy0825): currently only support interp_method = nearest or
// interp_method = bilinear
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -33,15 +33,6 @@ class LogSoftmaxOp : public framework::OperatorWithKernel { ...@@ -33,15 +33,6 @@ class LogSoftmaxOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -225,16 +225,6 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -225,16 +225,6 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
...@@ -359,16 +349,6 @@ class LRNOpGrad : public framework::OperatorWithKernel { ...@@ -359,16 +349,6 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -697,15 +697,6 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -697,15 +697,6 @@ class MatMulOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -889,15 +880,6 @@ class MatMulOpGrad : public framework::OperatorWithKernel { ...@@ -889,15 +880,6 @@ class MatMulOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -135,15 +135,6 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -135,15 +135,6 @@ class MatMulV2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -210,15 +201,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel { ...@@ -210,15 +201,6 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -36,15 +36,6 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -36,15 +36,6 @@ class PReluOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -127,15 +118,6 @@ class PReluGradOp : public framework::OperatorWithKernel { ...@@ -127,15 +118,6 @@ class PReluGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -31,15 +31,6 @@ class ScaleOp : public framework::OperatorWithKernel { ...@@ -31,15 +31,6 @@ class ScaleOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -30,15 +30,6 @@ class ShapeOp : public framework::OperatorWithKernel { ...@@ -30,15 +30,6 @@ class ShapeOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -39,15 +39,6 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { ...@@ -39,15 +39,6 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -53,7 +53,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -53,7 +53,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU and custom place")); "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, return framework::OpKernelType(input_data_type,
...@@ -62,15 +61,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -62,15 +61,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
framework::LibraryType::kCUDNN); framework::LibraryType::kCUDNN);
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
} }
}; };
...@@ -158,15 +148,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -158,15 +148,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
framework::LibraryType::kCUDNN); framework::LibraryType::kCUDNN);
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
} }
}; };
......
...@@ -35,15 +35,6 @@ class StackOp : public framework::OperatorWithKernel { ...@@ -35,15 +35,6 @@ class StackOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -98,14 +98,6 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -98,14 +98,6 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
auto &data_format = ctx.Attr<std::string>("data_format"); auto &data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
...@@ -202,14 +194,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel { ...@@ -202,14 +194,6 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
...@@ -360,14 +344,6 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { ...@@ -360,14 +344,6 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
framework::proto::VarType::Type data_type = framework::proto::VarType::Type data_type =
OperatorWithKernel::IndicateVarDataType(ctx, OperatorWithKernel::IndicateVarDataType(ctx,
framework::GradVarName("Out")); framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_MKLDNN
#include <unordered_set>
namespace paddle {
namespace platform {
// NOTE(jiahongyu): The ops below contain op-specific PADDLE_WITH_MKLDNN hard
// code inside their GetExpectedKernelType implementations, so the generic
// kernel-selection path must special-case them through this white list and
// migrate them one-by-one in the future.
// TODO(jiahongyu): Delete mkldnn_white_list once GetExpectedKernelType fully
// supports PADDLE_WITH_MKLDNN without per-op hard code.
// NOTE: `inline` (C++17) rather than `static` so that every translation unit
// including this header shares a single definition of the set, instead of each
// TU carrying (and dynamically initializing) its own private copy.
inline const std::unordered_set<std::string> mkldnn_white_list = {
    "cast",
    "transfer_dtype",
    "layer_norm",
    "pad2d",
    "pad3d",
    "pool2d",
    "pool2d_grad",
    "slice",
    "slice_grad",
    "split",
    "sum",
    "sgd",
    // NOTE(jiahongyu): squeeze MKLDNN kernels are disabled
    // (https://github.com/PaddlePaddle/Paddle/pull/35781). If these MKLDNN
    // kernels and codes are deleted in the future, the `use_mkldnn` attribute
    // should be removed from the function declarations as well.
    "squeeze",
    "squeeze_grad",
    "squeeze2",
    "squeeze2_grad",
    // NOTE(jiahongyu): reshape and flatten have the attribute `use_mkldnn` and
    // are registered in paddle, but they do not change the ExpectedKernelType
    // of the tensor. In practice the MKLDNN kernels of squeeze, reshape, and
    // flatten should never be called.
    "reshape",
    "reshape_grad",
    "reshape2",
    "reshape2_grad",
    "flatten",
    "flatten_grad",
    "flatten2",
    "flatten2_grad",
    // NOTE(jiahongyu): After fixing GetExpectedKernelType in ReduceOp, the
    // reduce-series hard code can be deleted together.
    "reduce_max",
    "reduce_mean",
    "reduce_mean_grad",
    "reduce_min",
    "reduce_sum",
    "reduce_sum_grad",
    // NOTE(jiahongyu): The ops below register kernels with a customized
    // type value; they need to be analyzed and solved one-by-one.
    "conv2d",
    "conv2d_grad",
    "depthwise_conv2d",
    "depthwise_conv2d_grad",
    "conv3d",
    "conv3d_grad",
    "prior_box",
    "fc",
    "mul",
    "mul_grad",
    "transpose2"};

// Returns true when `op_name` still relies on op-specific MKLDNN hard code in
// its GetExpectedKernelType and must therefore be special-cased by the
// kernel-selection logic.
inline bool in_mkldnn_white_list(const std::string& op_name) {
  return mkldnn_white_list.count(op_name) != 0;
}
} // namespace platform
} // namespace paddle
#endif
...@@ -99,7 +99,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> { ...@@ -99,7 +99,6 @@ struct KernelKeyParser : ArgsIterator<KernelKeyParser> {
inline void AssignKernelKeySet(const phi::TensorBase& tensor) { inline void AssignKernelKeySet(const phi::TensorBase& tensor) {
key_set.backend_set = key_set.backend_set =
key_set.backend_set | detail::GetTensorBackendSet(tensor); key_set.backend_set | detail::GetTensorBackendSet(tensor);
// TODO(chenweihang): select multi layout and dtype
phi::DataLayout tensor_layout = tensor.layout(); phi::DataLayout tensor_layout = tensor.layout();
key_set.layout = key_set.layout =
tensor_layout > key_set.layout ? tensor_layout : key_set.layout; tensor_layout > key_set.layout ? tensor_layout : key_set.layout;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册