未验证 提交 ea96172e 编写于 作者: H HongyuJia 提交者: GitHub

refine PADDLE_WITH_MKLDNN code (#46053)

* refine PADDLE_WITH_MKLDNN code

* fix data_norm_op

* polish addmm_op
上级 3671d114
...@@ -39,30 +39,24 @@ class AddMMOp : public framework::OperatorWithKernel { ...@@ -39,30 +39,24 @@ class AddMMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) { int customized_type_value =
library = framework::LibraryType::kMKLDNN; framework::OpKernelType::kDefaultCustomizedTypeValue;
layout = framework::DataLayout::kMKLDNN;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() || if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) { input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8; customized_type_value = kMULMKLDNNINT8;
} }
}
#endif
return framework::OpKernelType(input_data_type, return framework::OpKernelType(input_data_type,
ctx.GetPlace(), ctx.GetPlace(),
layout, framework::DataLayout::kMKLDNN,
library, framework::LibraryType::kMKLDNN,
customized_type_value); customized_type_value);
} }
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AddMMOpMaker : public framework::OpProtoAndCheckerMaker { class AddMMOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -16,9 +16,6 @@ ...@@ -16,9 +16,6 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
......
...@@ -195,18 +195,16 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( ...@@ -195,18 +195,16 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
"Variance input should be of float type")); "Variance input should be of float type"));
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(input_data_type, ctx.GetPlace());
input_data_type, ctx.GetPlace(), layout, library);
} }
framework::OpKernelType BatchNormOp::GetKernelTypeForVar( framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
...@@ -396,19 +394,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( ...@@ -396,19 +394,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
} }
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
......
...@@ -80,9 +80,7 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( ...@@ -80,9 +80,7 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) { if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType( return framework::OpKernelType(
expected_kernel_type.data_type_, expected_kernel_type.data_type_, tensor.place(), dl);
tensor.place(),
framework::StringToDataLayout(data_format));
} }
} }
#endif #endif
......
...@@ -200,18 +200,16 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -200,18 +200,16 @@ class DataNormOp : public framework::OperatorWithKernel {
"bias input should be of float type")); "bias input should be of float type"));
} }
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(input_data_type, ctx.GetPlace());
input_data_type, ctx.GetPlace(), layout, library);
} }
}; };
...@@ -511,19 +509,18 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -511,19 +509,18 @@ class DataNormGradOp : public framework::OperatorWithKernel {
} }
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (this->CanMKLDNNBeUsed(ctx, data_type)) {
this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type,
library = framework::LibraryType::kMKLDNN; ctx.GetPlace(),
layout = framework::DataLayout::kMKLDNN; framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册