未验证 提交 ea96172e 编写于 作者: H HongyuJia 提交者: GitHub

refine PADDLE_WITH_MKLDNN code (#46053)

* refine PADDLE_WITH_MKLDNN code

* fix data_norm_op

* polish addmm_op
上级 3671d114
......@@ -39,30 +39,24 @@ class AddMMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
}
}
#endif
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout,
library,
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class AddMMOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -16,9 +16,6 @@
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
......
......@@ -195,18 +195,16 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
"Variance input should be of float type"));
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
input_data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
......@@ -396,19 +394,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
......
......@@ -80,9 +80,7 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_,
tensor.place(),
framework::StringToDataLayout(data_format));
expected_kernel_type.data_type_, tensor.place(), dl);
}
}
#endif
......
......@@ -200,18 +200,16 @@ class DataNormOp : public framework::OperatorWithKernel {
"bias input should be of float type"));
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
input_data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -511,19 +509,18 @@ class DataNormGradOp : public framework::OperatorWithKernel {
}
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册