Unverified · Commit 5bf25d1e authored by arlesniak, committed by GitHub

More precise mkldnn kernel rules in GetExpectedKernelType (#29840)

* More precise mkldnn kernel choice in GetExpectedKernelType

* Fixes after review

* Refresh develop for CI

* CI experiment

* Get back from CI experiment
Parent a28a2026
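The gist of the change: `SupportsMKLDNN()` and `CanMKLDNNBeUsed()` now take the operator's data type, so an MKL-DNN kernel is chosen only when a kernel registered for that exact data type actually exists. As a minimal sketch (not part of the diff), this is the call pattern each operator's `GetExpectedKernelType` follows after this PR; `MyOp` and the input name `"X"` are placeholders:

```cpp
// Sketch only: the post-change pattern in GetExpectedKernelType.
// "MyOp" and input name "X" are hypothetical, not part of this PR.
framework::OpKernelType MyOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  // Resolve the data type first, so the MKLDNN check can use it.
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  // Pick MKLDNN only if a kernel for this exact data type is registered.
  if (library == framework::LibraryType::kPlain &&
      this->CanMKLDNNBeUsed(ctx, data_type)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
```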
@@ -1040,21 +1040,23 @@ static void CheckTensorNANOrInf(const std::string& op_type,
                         op_type, name));
 }
 
-bool OperatorWithKernel::SupportsMKLDNN() const {
+bool OperatorWithKernel::SupportsMKLDNN(
+    const proto::VarType::Type data_type) const {
   auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
   return std::any_of(op_kernels.begin(), op_kernels.end(),
-                     [](OpKernelMap::const_reference kern_pair) {
+                     [data_type](OpKernelMap::const_reference kern_pair) {
                        return platform::is_cpu_place(kern_pair.first.place_) &&
                               kern_pair.first.library_type_ ==
-                                  LibraryType::kMKLDNN;
+                                  LibraryType::kMKLDNN &&
+                              kern_pair.first.data_type_ == data_type;
                      });
 }
 
-bool OperatorWithKernel::CanMKLDNNBeUsed(
-    const framework::ExecutionContext& ctx) const {
+bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
+                                         proto::VarType::Type data_type) const {
   bool use_mkldnn_ctx =
       ctx.Attr<bool>("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace());
-  return use_mkldnn_ctx && this->SupportsMKLDNN();
+  return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
 }
 
 void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
......
@@ -156,8 +156,6 @@ class OperatorBase {
   virtual bool SupportGPU() const { return false; }
-  virtual bool SupportsMKLDNN() const { return false; }
-
   const std::string& Type() const { return type_; }
 
   bool HasAttr(const std::string& name) const { return attrs_.count(name); }
@@ -492,9 +490,10 @@ class OperatorWithKernel : public OperatorBase {
                          return platform::is_gpu_place(kern_pair.first.place_);
                        });
   }
 
-  bool SupportsMKLDNN() const override;
+  bool SupportsMKLDNN(proto::VarType::Type data_type) const;
 
-  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const;
+  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
+                       proto::VarType::Type data_type) const;
 
   virtual void InferShape(InferShapeContext* ctx) const = 0;
......
@@ -93,6 +93,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                       const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  auto data_type = oper.IndicateVarDataType(ctx, name);
   // FIXME(liuwei1031) temporarily disable the code to unblock users
   // TODO(liuwei1031) figure out the reason behind
   // https://github.com/PaddlePaddle/Paddle/issues/16096
@@ -106,13 +107,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
 #ifdef PADDLE_WITH_MKLDNN
   auto it = oper.Attrs().find("use_mkldnn");
   if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
-      oper.CanMKLDNNBeUsed(ctx)) {
+      oper.CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
 #endif
-  return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
-                                 ctx.GetPlace(), layout, library);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
 }
 
 class ActivationOp : public framework::OperatorWithKernel {
......
@@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
......
@@ -157,7 +157,8 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
+  if (library == framework::LibraryType::kPlain &&
+      this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -524,17 +525,17 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
+  if (library == framework::LibraryType::kPlain &&
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
 #endif
 
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
-      library);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
 }
 
 framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
......
@@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel {
           "All Inputs of Concat OP are Empty!"));
     }
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
......
@@ -155,7 +155,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 #endif
 #ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
+  if (library == framework::LibraryType::kPlain &&
+      this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
     customized_type_value =
@@ -556,6 +557,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   std::string data_format = "AnyLayout";
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 
 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
@@ -564,7 +566,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     const std::string data_format = ctx.Attr<std::string>("data_format");
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
@@ -572,9 +574,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   }
 #endif
 
-  auto type = framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
-      layout_, library_, customized_type_value);
+  auto type = framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                      library_, customized_type_value);
   return type;
 }
......
@@ -182,6 +182,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
   use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(ctx.GetPlace())) {
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@@ -193,15 +194,13 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
   }
 #endif
 
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
 }
 
 framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
......
@@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel {
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
@@ -483,18 +483,17 @@ class DataNormGradOp : public framework::OperatorWithKernel {
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::LibraryType library = framework::LibraryType::kPlain;
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif
 
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 };
......
@@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_input_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
       auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
......
@@ -207,7 +207,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
 
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
......
@@ -34,7 +34,7 @@ class ElementwiseMulOp : public ElementwiseOp {
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
 
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
......
@@ -110,7 +110,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
 
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
@@ -280,7 +280,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
       return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
     };
 
-    if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
+        (ctx.Type() != "elementwise_add_grad" ||
          CanMKLDNNElementwiseAddGradBeUsed())) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
@@ -331,7 +332,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
 
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
@@ -384,7 +385,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
     }
 
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
......
@@ -133,15 +133,14 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-  if (this->CanMKLDNNBeUsed(ctx)) {
+  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
 #endif
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
-      library);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
 }
 
 void FusionGRUOpMaker::Make() {
......
@@ -112,18 +112,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout{framework::DataLayout::kAnyLayout};
+    auto data_type =
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
 
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif
 
-    return framework::OpKernelType(
-        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
-        ctx.device_context(), layout, library);
+    return framework::OpKernelType(data_type, ctx.device_context(), layout,
+                                   library);
   }
 
   framework::OpKernelType GetKernelTypeForVar(
......
@@ -46,17 +46,16 @@ class GeluOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     auto it = this->Attrs().find("use_mkldnn");
     if (library == framework::LibraryType::kPlain &&
-        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
+        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 };
@@ -86,17 +85,16 @@ class GeluGradOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     auto it = this->Attrs().find("use_mkldnn");
     if (library == framework::LibraryType::kPlain &&
-        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
+        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 };
......
@@ -322,20 +322,19 @@ class InterpolateOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
     framework::LibraryType library = framework::LibraryType::kPlain;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_MKLDNN
     auto interp_method = ctx.Attr<std::string>("interp_method");
     // TODO(danqing): support other interp_method
-    if (this->CanMKLDNNBeUsed(ctx) &&
+    if (this->CanMKLDNNBeUsed(ctx, data_type) &&
         (interp_method == "nearest" || interp_method == "bilinear")) {
       layout = framework::DataLayout::kMKLDNN;
       library = framework::LibraryType::kMKLDNN;
     }
 #endif
 
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 
   framework::OpKernelType GetKernelTypeForVar(
......
@@ -124,7 +124,7 @@ class LayerNormOp : public framework::OperatorWithKernel {
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
......
@@ -199,16 +199,16 @@ class LRNOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 
   framework::OpKernelType GetKernelTypeForVar(
@@ -339,16 +339,16 @@ class LRNOpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 
   framework::OpKernelType GetKernelTypeForVar(
......
@@ -661,7 +661,7 @@ class MatMulOp : public framework::OperatorWithKernel {
 
 #ifdef PADDLE_WITH_MKLDNN
     using mkldnn::memory;
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
......
@@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
......
@@ -149,6 +149,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   framework::LibraryType library_{framework::LibraryType::kPlain};
   std::string data_format = "AnyLayout";
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
@@ -157,15 +158,13 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
    library_ = framework::LibraryType::kMKLDNN;
    layout_ = framework::DataLayout::kMKLDNN;
   }
 #endif
 
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
 }
 
 framework::OpKernelType PoolOp::GetKernelTypeForVar(
@@ -205,6 +204,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   framework::LibraryType library_{framework::LibraryType::kPlain};
   std::string data_format = "AnyLayout";
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
@@ -213,14 +213,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
   }
 #endif
-
-  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
                                  library_);
 }
......
@@ -64,6 +64,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_CUDA
     if (platform::CanCUDNNBeUsed(ctx)) {
@@ -72,13 +73,12 @@ class SoftmaxOp : public framework::OperatorWithKernel {
 #endif
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      library_ = framework::LibraryType::kMKLDNN;
      layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
 
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                         platform::errors::InvalidArgument(
@@ -188,7 +188,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_CUDA
     if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
@@ -196,13 +197,11 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 #endif
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
 
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
-        ctx, framework::GradVarName("Out"));
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                         platform::errors::InvalidArgument(
......
@@ -145,29 +145,26 @@ class SumOp : public framework::OperatorWithKernel {
                         platform::errors::InvalidArgument(
                             "Sum operator should have at least one tensor"));
 
+      auto data_type = static_cast<framework::proto::VarType::Type>(dtype);
 #ifdef PADDLE_WITH_MKLDNN
       if (library == framework::LibraryType::kPlain &&
-          this->CanMKLDNNBeUsed(ctx) &&
-          (static_cast<framework::proto::VarType::Type>(dtype) ==
-               framework::proto::VarType::FP32 ||
-           static_cast<framework::proto::VarType::Type>(dtype) ==
-               framework::proto::VarType::BF16) &&
+          this->CanMKLDNNBeUsed(ctx, data_type) &&
+          (data_type == framework::proto::VarType::FP32 ||
+           data_type == framework::proto::VarType::BF16) &&
           ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
         if (std::all_of(x_vars.begin(), x_vars.end(),
                         [](const framework::Variable* v) {
                           return v->IsType<framework::LoDTensor>();
                         })) {
-          return framework::OpKernelType(
-              static_cast<framework::proto::VarType::Type>(dtype),
-              ctx.GetPlace(), framework::DataLayout::kMKLDNN,
+          return framework::OpKernelType(data_type, ctx.GetPlace(),
+                                         framework::DataLayout::kMKLDNN,
                                          framework::LibraryType::kMKLDNN);
         }
       }
 #endif
-      return framework::OpKernelType(
-          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
-          layout, library);
+      return framework::OpKernelType(data_type, ctx.GetPlace(), layout,
+                                     library);
     } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
       for (auto& var : x_vars) {
         auto& value = var->Get<framework::SelectedRows>().value();
......
@@ -86,16 +86,16 @@ class TransposeOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
      layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 };
@@ -184,16 +184,17 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 };
@@ -231,9 +232,11 @@ class Transpose2Op : public TransposeOp {
     int customized_type_value =
         framework::OpKernelType::kDefaultCustomizedTypeValue;
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    framework::proto::VarType::Type data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
       using framework::proto::VarType;
@@ -244,9 +247,8 @@ class Transpose2Op : public TransposeOp {
                               : kTransposeMKLDNNFP32;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_, customized_type_value);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_,
+                                   customized_type_value);
   }
 };
@@ -310,16 +312,18 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    framework::proto::VarType::Type data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx,
+                                                framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
      library_ = framework::LibraryType::kMKLDNN;
      layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
  }
 };
......