提交 4b8d4ade 编写于 作者: J jiahongyu 提交者: HongyuJia

refine mkldnn code

上级 db0ca7a5
......@@ -35,13 +35,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
auto input_input_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input");
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_input_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, input_input_type)) {
auto input_image_type = framework::TransToProtoVarType(
ctx.Input<framework::Tensor>("Image")->dtype());
int customized_type_value =
......@@ -54,13 +49,12 @@ class PriorBoxOp : public framework::OperatorWithKernel {
}
return framework::OpKernelType(input_input_type,
ctx.GetPlace(),
layout_,
library_,
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(
input_input_type, ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(input_input_type, ctx.GetPlace());
}
};
......
......@@ -152,16 +152,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
void FusionGRUOpMaker::Make() {
......
......@@ -175,16 +175,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
void FusionLSTMOpMaker::Make() {
......
......@@ -143,14 +143,11 @@ void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType MultiGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kMKLDNN;
framework::DataLayout layout = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace(),
layout,
library);
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
void MultiGRUOpMaker::Make() {
......
......@@ -700,7 +700,6 @@ class MatMulOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
#ifdef PADDLE_WITH_MKLDNN
using dnnl::memory;
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
......
......@@ -19,10 +19,6 @@
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
using DDim = framework::DDim;
......
......@@ -41,17 +41,12 @@ class MulOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
......@@ -62,15 +57,16 @@ class MulOp : public framework::OperatorWithKernel {
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
}
#endif
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout,
library,
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -140,17 +136,12 @@ class MulGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
......@@ -161,15 +152,16 @@ class MulGradOp : public framework::OperatorWithKernel {
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
}
#endif
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout,
library,
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
......
......@@ -42,8 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -88,8 +87,7 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
......@@ -23,26 +23,6 @@ namespace operators {
using Tensor = framework::Tensor;
framework::OpKernelType innerGetKernelTypeForVar(
const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) {
#ifdef PADDLE_WITH_MKLDNN
auto isOneDNNKernelChosen =
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN);
auto isNotOneDNNTensor = (tensor.layout() != framework::DataLayout::kMKLDNN);
auto isModelNHWC =
(paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC);
// All inputs (including alpha) need shape rotating
if (isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
class PReluOp : public framework::OperatorWithKernel {
public:
PReluOp(const std::string &type,
......@@ -72,7 +52,19 @@ class PReluOp : public framework::OperatorWithKernel {
const std::string &var_name,
const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return innerGetKernelTypeForVar(tensor, expected_kernel_type);
#ifdef PADDLE_WITH_MKLDNN
// All inputs (including alpha) need shape rotating
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
......@@ -151,7 +143,19 @@ class PReluGradOp : public framework::OperatorWithKernel {
const std::string &var_name,
const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return innerGetKernelTypeForVar(tensor, expected_kernel_type);
#ifdef PADDLE_WITH_MKLDNN
// All inputs (including alpha) need shape rotating
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
framework::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
......
......@@ -17,12 +17,9 @@
#include <unordered_map>
#include <vector>
#include "paddle/phi/core/ddim.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
......
......@@ -24,14 +24,11 @@ namespace operators {
framework::OpKernelType QuantOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
layout_,
library_);
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
void QuantOpMaker::Make() {
......
......@@ -24,14 +24,11 @@ namespace operators {
framework::OpKernelType ReQuantOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
layout_,
library_);
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
void ReQuantOpMaker::Make() {
......
......@@ -21,9 +21,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......
......@@ -99,19 +99,18 @@ class TransposeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
auto &data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_);
auto &data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
}
};
......@@ -203,20 +202,19 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_);
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
}
};
......@@ -249,29 +247,27 @@ class Transpose2Op : public TransposeOp {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::proto::VarType::Type data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
using framework::proto::VarType;
auto input_data_type =
framework::TransToProtoVarType(ctx.Input<Tensor>("X")->dtype());
customized_type_value = (input_data_type == VarType::INT8 ||
int customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kTransposeMKLDNNINT8
: kTransposeMKLDNNFP32;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_, customized_type_value);
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
}
};
......@@ -371,21 +367,20 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::proto::VarType::Type data_type =
OperatorWithKernel::IndicateVarDataType(ctx,
framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_);
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册