未验证 提交 bc902044 编写于 作者: A arlesniak 提交者: GitHub

Fixes mkldnn dygraph learning rate scheduler crashes (#28988)

上级 0fca8cdf
...@@ -1007,6 +1007,23 @@ static void CheckTensorNANOrInf(const std::string& op_type, ...@@ -1007,6 +1007,23 @@ static void CheckTensorNANOrInf(const std::string& op_type,
op_type, name)); op_type, name));
} }
// Returns true when at least one MKL-DNN kernel is registered for this op
// type, i.e. a kernel whose place is CPU and whose library type is kMKLDNN.
bool OperatorWithKernel::SupportsMKLDNN() const {
  // NOTE: use find() rather than at() here. An op type with no registered
  // kernels at all would make at() throw std::out_of_range and crash; for
  // such ops the correct answer is simply "MKL-DNN is not supported".
  auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
  auto kernels_iter = all_op_kernels.find(type_);
  if (kernels_iter == all_op_kernels.end()) {
    return false;
  }
  const auto& op_kernels = kernels_iter->second;
  return std::any_of(op_kernels.begin(), op_kernels.end(),
                     [](OpKernelMap::const_reference kern_pair) {
                       return platform::is_cpu_place(kern_pair.first.place_) &&
                              kern_pair.first.library_type_ ==
                                  LibraryType::kMKLDNN;
                     });
}
// Decides whether the MKL-DNN execution path may be taken for this run:
// the op instance must carry a true "use_mkldnn" attribute, it must be
// executing on a CPU place, and an MKL-DNN kernel must actually be
// registered for this op type (SupportsMKLDNN()).
bool OperatorWithKernel::CanMKLDNNBeUsed(
    const framework::ExecutionContext& ctx) const {
  if (!ctx.Attr<bool>("use_mkldnn")) {
    return false;
  }
  if (!platform::is_cpu_place(ctx.GetPlace())) {
    return false;
  }
  return this->SupportsMKLDNN();
}
void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place, const platform::Place& place,
const RuntimeContext& ctx) const { const RuntimeContext& ctx) const {
......
...@@ -156,6 +156,8 @@ class OperatorBase { ...@@ -156,6 +156,8 @@ class OperatorBase {
virtual bool SupportGPU() const { return false; } virtual bool SupportGPU() const { return false; }
virtual bool SupportsMKLDNN() const { return false; }
const std::string& Type() const { return type_; } const std::string& Type() const { return type_; }
bool HasAttr(const std::string& name) const { return attrs_.count(name); } bool HasAttr(const std::string& name) const { return attrs_.count(name); }
...@@ -490,6 +492,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -490,6 +492,9 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_gpu_place(kern_pair.first.place_); return platform::is_gpu_place(kern_pair.first.place_);
}); });
} }
bool SupportsMKLDNN() const override;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const;
virtual void InferShape(InferShapeContext* ctx) const = 0; virtual void InferShape(InferShapeContext* ctx) const = 0;
......
...@@ -106,7 +106,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, ...@@ -106,7 +106,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn"); auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
platform::CanMKLDNNBeUsed(ctx)) { oper.CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel { ...@@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
......
...@@ -157,8 +157,7 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( ...@@ -157,8 +157,7 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
framework::LibraryType library = framework::LibraryType::kPlain; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
...@@ -527,8 +526,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( ...@@ -527,8 +526,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel { ...@@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel {
"All Inputs of Concat OP are Empty!")); "All Inputs of Concat OP are Empty!"));
} }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
......
...@@ -155,8 +155,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -155,8 +155,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
customized_type_value = customized_type_value =
...@@ -565,7 +564,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( ...@@ -565,7 +564,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
const std::string data_format = ctx.Attr<std::string>("data_format"); const std::string data_format = ctx.Attr<std::string>("data_format");
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
......
...@@ -193,7 +193,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -193,7 +193,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
...@@ -486,7 +486,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -486,7 +486,7 @@ class DataNormGradOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
auto input_image_type = ctx.Input<framework::Tensor>("Image")->type(); auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
......
...@@ -163,7 +163,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -163,7 +163,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
......
...@@ -33,7 +33,7 @@ class ElementwiseMulOp : public ElementwiseOp { ...@@ -33,7 +33,7 @@ class ElementwiseMulOp : public ElementwiseOp {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
......
...@@ -108,7 +108,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -108,7 +108,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
...@@ -265,8 +265,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -265,8 +265,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims()); return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
}; };
if (platform::CanMKLDNNBeUsed(ctx) && if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
(ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) { CanMKLDNNElementwiseAddGradBeUsed())) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
...@@ -304,7 +303,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { ...@@ -304,7 +303,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
...@@ -343,7 +342,7 @@ class ElementwiseOpDoubleGradWithoutDXDY ...@@ -343,7 +342,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
} }
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
......
...@@ -133,7 +133,7 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType( ...@@ -133,7 +133,7 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
framework::LibraryType library = framework::LibraryType::kPlain; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -115,7 +115,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -115,7 +115,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -49,7 +49,7 @@ class GeluOp : public framework::OperatorWithKernel { ...@@ -49,7 +49,7 @@ class GeluOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn"); auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) { it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
...@@ -89,7 +89,7 @@ class GeluGradOp : public framework::OperatorWithKernel { ...@@ -89,7 +89,7 @@ class GeluGradOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn"); auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) { it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -104,7 +104,7 @@ class LayerNormOp : public framework::OperatorWithKernel { ...@@ -104,7 +104,7 @@ class LayerNormOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -201,7 +201,7 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -201,7 +201,7 @@ class LRNOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -341,7 +341,7 @@ class LRNOpGrad : public framework::OperatorWithKernel { ...@@ -341,7 +341,7 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -659,7 +659,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -659,7 +659,7 @@ class MatMulOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using mkldnn::memory; using mkldnn::memory;
if (platform::CanMKLDNNBeUsed(ctx)) { if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); framework::LibraryType::kMKLDNN);
......
...@@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
......
...@@ -157,7 +157,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( ...@@ -157,7 +157,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -213,7 +213,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( ...@@ -213,7 +213,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -72,7 +72,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -72,7 +72,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -196,7 +196,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -196,7 +196,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -147,7 +147,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -147,7 +147,7 @@ class SumOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx) && this->CanMKLDNNBeUsed(ctx) &&
(static_cast<framework::proto::VarType::Type>(dtype) == (static_cast<framework::proto::VarType::Type>(dtype) ==
framework::proto::VarType::FP32 || framework::proto::VarType::FP32 ||
static_cast<framework::proto::VarType::Type>(dtype) == static_cast<framework::proto::VarType::Type>(dtype) ==
......
...@@ -88,7 +88,7 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -88,7 +88,7 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -186,7 +186,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel { ...@@ -186,7 +186,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -233,7 +233,7 @@ class Transpose2Op : public TransposeOp { ...@@ -233,7 +233,7 @@ class Transpose2Op : public TransposeOp {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
using framework::proto::VarType; using framework::proto::VarType;
...@@ -298,7 +298,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { ...@@ -298,7 +298,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
......
...@@ -134,11 +134,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims, ...@@ -134,11 +134,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
return mkldnn::memory::desc({dims}, data_type, format); return mkldnn::memory::desc({dims}, data_type, format);
} }
inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
}
inline void ClearMKLDNNCache(const platform::Place& place) { inline void ClearMKLDNNCache(const platform::Place& place) {
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册