提交 b2b9a1bb 编写于 作者: J jiahongyu 提交者: HongyuJia

refine mkldnn code

上级 db97773b
...@@ -24,14 +24,13 @@ namespace operators { ...@@ -24,14 +24,13 @@ namespace operators {
framework::OpKernelType DeQuantOp::GetExpectedKernelType( framework::OpKernelType DeQuantOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_ = framework::LibraryType::kMKLDNN; auto input_data_type =
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
return framework::OpKernelType( return framework::OpKernelType(input_data_type,
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
ctx.GetPlace(), framework::DataLayout::kMKLDNN,
layout_, framework::LibraryType::kMKLDNN);
library_);
} }
void DeQuantOpMaker::Make() { void DeQuantOpMaker::Make() {
......
...@@ -126,26 +126,21 @@ class FCOp : public framework::OperatorWithKernel { ...@@ -126,26 +126,21 @@ class FCOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input"); OperatorWithKernel::IndicateVarDataType(ctx, "Input");
if (ctx.Attr<bool>("use_mkldnn")) { if (ctx.Attr<bool>("use_mkldnn")) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
using framework::proto::VarType; using framework::proto::VarType;
customized_type_value = (input_data_type == VarType::INT8 || int customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8) input_data_type == VarType::UINT8)
? kFCMKLDNNINT8 ? kFCMKLDNNINT8
: kFCMKLDNNFP32; : kFCMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
} }
return framework::OpKernelType(input_data_type, return framework::OpKernelType(input_data_type, ctx.GetPlace());
ctx.GetPlace(),
layout,
library,
customized_type_value);
} }
}; };
......
...@@ -58,21 +58,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -58,21 +58,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
auto data_type = auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.device_context(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(data_type, ctx.device_context());
data_type, ctx.device_context(), layout, library);
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -18,9 +18,6 @@ limitations under the License. */ ...@@ -18,9 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -35,17 +35,16 @@ class GeluOp : public framework::OperatorWithKernel { ...@@ -35,17 +35,16 @@ class GeluOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
}; };
...@@ -76,18 +75,17 @@ class GeluGradOp : public framework::OperatorWithKernel { ...@@ -76,18 +75,17 @@ class GeluGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn"); auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
}; };
......
...@@ -340,8 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -340,8 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -349,12 +347,14 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -349,12 +347,14 @@ class InterpolateOp : public framework::OperatorWithKernel {
// TODO(danqing): support other interp_method // TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) && if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) { (interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN; return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -444,8 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel { ...@@ -444,8 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -453,12 +451,14 @@ class InterpolateV2Op : public framework::OperatorWithKernel { ...@@ -453,12 +451,14 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
// TODO(danqing): support other interp_method // TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) && if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) { (interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN; return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -110,21 +110,19 @@ class LayerNormOp : public framework::OperatorWithKernel { ...@@ -110,21 +110,19 @@ class LayerNormOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
int begin_norm_axis = ctx.Attr<int>("begin_norm_axis"); int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
this->CanMKLDNNBeUsed(ctx, input_data_type) &&
begin_norm_axis == ctx.Input<Tensor>("X")->dims().size() - 1) { begin_norm_axis == ctx.Input<Tensor>("X")->dims().size() - 1) {
library = framework::LibraryType::kMKLDNN; return framework::OpKernelType(input_data_type,
layout = framework::DataLayout::kMKLDNN; ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(input_data_type, ctx.GetPlace());
input_data_type, ctx.GetPlace(), layout, library);
} }
}; };
......
...@@ -225,19 +225,18 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -225,19 +225,18 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library_ = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout_ = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(data_type, ctx.GetPlace());
data_type, ctx.GetPlace(), layout_, library_);
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
...@@ -360,19 +359,18 @@ class LRNOpGrad : public framework::OperatorWithKernel { ...@@ -360,19 +359,18 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library_ = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout_ = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(data_type, ctx.GetPlace());
data_type, ctx.GetPlace(), layout_, library_);
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册