Commit 4b8d4ade, authored by jiahongyu, committed by HongyuJia

refine mkldnn code

Parent db0ca7a5
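This commit applies one pattern throughout the touched operators: instead of declaring mutable library/layout locals, flipping them inside #ifdef PADDLE_WITH_MKLDNN, and funneling everything into a single OpKernelType constructor at the end, each GetExpectedKernelType now returns the MKLDNN kernel type directly inside the guarded branch and falls through to the plain kernel type otherwise. A minimal before/after sketch of that shape follows; SomeOp is a hypothetical placeholder used only for illustration, not an operator in the diff, and the snippet assumes the usual PaddlePaddle framework headers and class declaration around it.

// Before: mutate locals, single return at the end.
framework::OpKernelType SomeOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (library == framework::LibraryType::kPlain &&
      this->CanMKLDNNBeUsed(ctx, data_type)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

// After: return the MKLDNN kernel type directly, fall through to the default.
framework::OpKernelType SomeOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
    return framework::OpKernelType(data_type,
                                   ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace());
}

Operators that need a customized_type_value (prior_box, mul, transpose2) compute it inside the MKLDNN branch and pass it to the five-argument constructor there; MKLDNN-only operators (multi_gru, quantize, requantize) simply inline the layout and library constants into their single return.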
@@ -35,13 +35,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
     auto input_input_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Input");
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_input_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, input_input_type)) {
       auto input_image_type = framework::TransToProtoVarType(
           ctx.Input<framework::Tensor>("Image")->dtype());
       int customized_type_value =
@@ -54,13 +49,12 @@ class PriorBoxOp : public framework::OperatorWithKernel {
       }
       return framework::OpKernelType(input_input_type,
                                      ctx.GetPlace(),
-                                     layout_,
-                                     library_,
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
                                      customized_type_value);
     }
 #endif
-    return framework::OpKernelType(
-        input_input_type, ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(input_input_type, ctx.GetPlace());
   }
 };
...
@@ -152,16 +152,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
   if (this->CanMKLDNNBeUsed(ctx, data_type)) {
-    library = framework::LibraryType::kMKLDNN;
-    layout = framework::DataLayout::kMKLDNN;
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kMKLDNN,
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
-  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 void FusionGRUOpMaker::Make() {
...
@@ -175,16 +175,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
   if (this->CanMKLDNNBeUsed(ctx, data_type)) {
-    library = framework::LibraryType::kMKLDNN;
-    layout = framework::DataLayout::kMKLDNN;
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kMKLDNN,
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
-  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 void FusionLSTMOpMaker::Make() {
...
@@ -143,14 +143,11 @@ void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const {
 framework::OpKernelType MultiGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout = framework::DataLayout::kMKLDNN;
   return framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "X"),
       ctx.GetPlace(),
-      layout,
-      library);
+      framework::DataLayout::kMKLDNN,
+      framework::LibraryType::kMKLDNN);
 }
 void MultiGRUOpMaker::Make() {
...
@@ -700,7 +700,6 @@ class MatMulOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
 #ifdef PADDLE_WITH_MKLDNN
-    using dnnl::memory;
     if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type,
                                      ctx.GetPlace(),
...
@@ -19,10 +19,6 @@
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/phi/kernels/funcs/compare_functors.h"
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 namespace paddle {
 namespace operators {
 using DDim = framework::DDim;
...
@@ -41,17 +41,12 @@ class MulOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      int customized_type_value =
+          framework::OpKernelType::kDefaultCustomizedTypeValue;
       if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
           input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
         customized_type_value = kMULMKLDNNINT8;
@@ -62,15 +57,16 @@ class MulOp : public framework::OperatorWithKernel {
                      framework::DataTypeTrait<float>::DataType()) {
         customized_type_value = kMULMKLDNNFP32;
       }
-    }
-#endif
       return framework::OpKernelType(input_data_type,
                                      ctx.GetPlace(),
-                                     layout,
-                                     library,
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
                                      customized_type_value);
     }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 class MulOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -140,17 +136,12 @@ class MulGradOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      int customized_type_value =
+          framework::OpKernelType::kDefaultCustomizedTypeValue;
       if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
           input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
         customized_type_value = kMULMKLDNNINT8;
@@ -161,15 +152,16 @@ class MulGradOp : public framework::OperatorWithKernel {
                      framework::DataTypeTrait<float>::DataType()) {
         customized_type_value = kMULMKLDNNFP32;
       }
-    }
-#endif
       return framework::OpKernelType(input_data_type,
                                      ctx.GetPlace(),
-                                     layout,
-                                     library,
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
                                      customized_type_value);
     }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 template <typename T>
...
@@ -42,8 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
 framework::OpKernelType PoolOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  std::string data_format = "AnyLayout";
-  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -88,8 +87,7 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
 framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  std::string data_format = "AnyLayout";
-  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
@@ -23,26 +23,6 @@ namespace operators {
 using Tensor = framework::Tensor;
-framework::OpKernelType innerGetKernelTypeForVar(
-    const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) {
-#ifdef PADDLE_WITH_MKLDNN
-  auto isOneDNNKernelChosen =
-      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN);
-  auto isNotOneDNNTensor = (tensor.layout() != framework::DataLayout::kMKLDNN);
-  auto isModelNHWC =
-      (paddle::platform::MKLDNNDeviceContext::tls()
-           .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC);
-  // All inputs (including alpha) need shape rotating
-  if (isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC) {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   tensor.place(),
-                                   framework::DataLayout::kNHWC);
-  }
-#endif
-  return framework::OpKernelType(
-      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-}
 class PReluOp : public framework::OperatorWithKernel {
  public:
  PReluOp(const std::string &type,
@@ -72,7 +52,19 @@ class PReluOp : public framework::OperatorWithKernel {
       const std::string &var_name,
       const Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
-    return innerGetKernelTypeForVar(tensor, expected_kernel_type);
+#ifdef PADDLE_WITH_MKLDNN
+    // All inputs (including alpha) need shape rotating
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN) &&
+        paddle::platform::MKLDNNDeviceContext::tls()
+                .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(),
+                                     framework::DataLayout::kNHWC);
+    }
+#endif
+    return framework::OpKernelType(
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
   }
 };
@@ -151,7 +143,19 @@ class PReluGradOp : public framework::OperatorWithKernel {
       const std::string &var_name,
       const Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
-    return innerGetKernelTypeForVar(tensor, expected_kernel_type);
+#ifdef PADDLE_WITH_MKLDNN
+    // All inputs (including alpha) need shape rotating
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN) &&
+        paddle::platform::MKLDNNDeviceContext::tls()
+                .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(),
+                                     framework::DataLayout::kNHWC);
+    }
+#endif
+    return framework::OpKernelType(
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
   }
 };
...
@@ -17,12 +17,9 @@
 #include <unordered_map>
 #include <vector>
-#include "paddle/phi/core/ddim.h"
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
...
@@ -24,14 +24,11 @@ namespace operators {
 framework::OpKernelType QuantOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
   return framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
       ctx.GetPlace(),
-      layout_,
-      library_);
+      framework::DataLayout::kMKLDNN,
+      framework::LibraryType::kMKLDNN);
 }
 void QuantOpMaker::Make() {
...
@@ -24,14 +24,11 @@ namespace operators {
 framework::OpKernelType ReQuantOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
   return framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
       ctx.GetPlace(),
-      layout_,
-      library_);
+      framework::DataLayout::kMKLDNN,
+      framework::LibraryType::kMKLDNN);
 }
 void ReQuantOpMaker::Make() {
...
@@ -21,9 +21,6 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/infermeta/unary.h"
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 namespace paddle {
 namespace operators {
...
@@ -99,19 +99,18 @@ class TransposeOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    auto &data_format = ctx.Attr<std::string>("data_format");
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    auto &data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -203,20 +202,19 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -249,29 +247,27 @@ class Transpose2Op : public TransposeOp {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
       using framework::proto::VarType;
       auto input_data_type =
           framework::TransToProtoVarType(ctx.Input<Tensor>("X")->dtype());
-      customized_type_value = (input_data_type == VarType::INT8 ||
-                               input_data_type == VarType::UINT8)
-                                  ? kTransposeMKLDNNINT8
-                                  : kTransposeMKLDNNFP32;
+      int customized_type_value = (input_data_type == VarType::INT8 ||
+                                   input_data_type == VarType::UINT8)
+                                      ? kTransposeMKLDNNINT8
+                                      : kTransposeMKLDNNFP32;
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
+                                     customized_type_value);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_, customized_type_value);
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -371,21 +367,20 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx,
                                                 framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
                                     ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
...