提交 b2b9a1bb 编写于 作者: J jiahongyu 提交者: HongyuJia

refine mkldnn code

上级 db97773b
......@@ -24,14 +24,13 @@ namespace operators {
framework::OpKernelType DeQuantOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
layout_,
library_);
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
void DeQuantOpMaker::Make() {
......
......@@ -126,26 +126,21 @@ class FCOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input");
if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kFCMKLDNNINT8
: kFCMKLDNNFP32;
int customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kFCMKLDNNINT8
: kFCMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout,
library,
customized_type_value);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -58,21 +58,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.device_context(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
data_type, ctx.device_context(), layout, library);
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -18,9 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......
......@@ -35,17 +35,16 @@ class GeluOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......@@ -76,18 +75,17 @@ class GeluGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
......@@ -340,8 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
......@@ -349,12 +347,14 @@ class InterpolateOp : public framework::OperatorWithKernel {
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -444,8 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
......@@ -453,12 +451,14 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -110,21 +110,19 @@ class LayerNormOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type) &&
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
begin_norm_axis == ctx.Input<Tensor>("X")->dims().size() - 1) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
input_data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -225,19 +225,18 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......@@ -360,19 +359,18 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册