未验证 提交 bc902044 编写于 作者: A arlesniak 提交者: GitHub

Fixes mkldnn dygraph learning rate scheduler crashes (#28988)

上级 0fca8cdf
......@@ -1007,6 +1007,23 @@ static void CheckTensorNANOrInf(const std::string& op_type,
op_type, name));
}
bool OperatorWithKernel::SupportsMKLDNN() const {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_cpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ ==
LibraryType::kMKLDNN;
});
}
bool OperatorWithKernel::CanMKLDNNBeUsed(
const framework::ExecutionContext& ctx) const {
bool use_mkldnn_ctx =
ctx.Attr<bool>("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN();
}
void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {
......
......@@ -156,6 +156,8 @@ class OperatorBase {
virtual bool SupportGPU() const { return false; }
virtual bool SupportsMKLDNN() const { return false; }
const std::string& Type() const { return type_; }
bool HasAttr(const std::string& name) const { return attrs_.count(name); }
......@@ -490,6 +492,9 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
bool SupportsMKLDNN() const override;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const;
virtual void InferShape(InferShapeContext* ctx) const = 0;
......
......@@ -106,7 +106,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
platform::CanMKLDNNBeUsed(ctx)) {
oper.CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
......
......@@ -157,8 +157,7 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......@@ -527,8 +526,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel {
"All Inputs of Concat OP are Empty!"));
}
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
......@@ -155,8 +155,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
customized_type_value =
......@@ -565,7 +564,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
const std::string data_format = ctx.Attr<std::string>("data_format");
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
......
......@@ -193,7 +193,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......@@ -486,7 +486,7 @@ class DataNormGradOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
......
......@@ -163,7 +163,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
......@@ -33,7 +33,7 @@ class ElementwiseMulOp : public ElementwiseOp {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
......@@ -108,7 +108,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......@@ -265,9 +265,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
};
if (platform::CanMKLDNNBeUsed(ctx) &&
(ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......@@ -304,7 +303,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......@@ -343,7 +342,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
}
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
......@@ -133,7 +133,7 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -115,7 +115,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -49,7 +49,7 @@ class GeluOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......@@ -89,7 +89,7 @@ class GeluGradOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -104,7 +104,7 @@ class LayerNormOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
......
......@@ -201,7 +201,7 @@ class LRNOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -341,7 +341,7 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -659,7 +659,7 @@ class MatMulOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
using mkldnn::memory;
if (platform::CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
......
......@@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
......
......@@ -157,7 +157,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -213,7 +213,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -72,7 +72,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -196,7 +196,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -147,7 +147,7 @@ class SumOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx) &&
this->CanMKLDNNBeUsed(ctx) &&
(static_cast<framework::proto::VarType::Type>(dtype) ==
framework::proto::VarType::FP32 ||
static_cast<framework::proto::VarType::Type>(dtype) ==
......
......@@ -88,7 +88,7 @@ class TransposeOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -186,7 +186,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......@@ -233,7 +233,7 @@ class Transpose2Op : public TransposeOp {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
......@@ -298,7 +298,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
......
......@@ -134,11 +134,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
return mkldnn::memory::desc({dims}, data_type, format);
}
inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
}
inline void ClearMKLDNNCache(const platform::Place& place) {
// Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册