Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4b8d4ade
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4b8d4ade
编写于
9月 20, 2022
作者:
J
jiahongyu
提交者:
HongyuJia
9月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine mkldnn code
上级
db0ca7a5
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
106 addition
and
143 deletion
+106
-143
paddle/fluid/operators/detection/prior_box_op.cc
paddle/fluid/operators/detection/prior_box_op.cc
+4
-10
paddle/fluid/operators/fused/fusion_gru_op.cc
paddle/fluid/operators/fused/fusion_gru_op.cc
+5
-5
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
+5
-5
paddle/fluid/operators/fused/multi_gru_op.cc
paddle/fluid/operators/fused/multi_gru_op.cc
+2
-5
paddle/fluid/operators/matmul_op.cc
paddle/fluid/operators/matmul_op.cc
+0
-1
paddle/fluid/operators/matrix_rank_op.cc
paddle/fluid/operators/matrix_rank_op.cc
+0
-4
paddle/fluid/operators/mul_op.cc
paddle/fluid/operators/mul_op.cc
+20
-28
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+2
-4
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/prelu_op.cc
+26
-22
paddle/fluid/operators/qr_op.cc
paddle/fluid/operators/qr_op.cc
+1
-4
paddle/fluid/operators/quantize_op.cc
paddle/fluid/operators/quantize_op.cc
+2
-5
paddle/fluid/operators/requantize_op.cc
paddle/fluid/operators/requantize_op.cc
+2
-5
paddle/fluid/operators/svd_op.cc
paddle/fluid/operators/svd_op.cc
+0
-3
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+37
-42
未找到文件。
paddle/fluid/operators/detection/prior_box_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -35,13 +35,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
auto
input_input_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Input"
);
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kAnyLayout
;
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
input_input_type
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
input_input_type
))
{
auto
input_image_type
=
framework
::
TransToProtoVarType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"Image"
)
->
dtype
());
int
customized_type_value
=
...
...
@@ -54,13 +49,12 @@ class PriorBoxOp : public framework::OperatorWithKernel {
}
return
framework
::
OpKernelType
(
input_input_type
,
ctx
.
GetPlace
(),
layout_
,
library_
,
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
,
customized_type_value
);
}
#endif
return
framework
::
OpKernelType
(
input_input_type
,
ctx
.
GetPlace
(),
layout_
,
library_
);
return
framework
::
OpKernelType
(
input_input_type
,
ctx
.
GetPlace
());
}
};
...
...
paddle/fluid/operators/fused/fusion_gru_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -152,16 +152,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
FusionGRUOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
#endif
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
()
,
layout
,
library
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
());
}
void
FusionGRUOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/fused/fusion_lstm_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -175,16 +175,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
FusionLSTMOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
#endif
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
()
,
layout
,
library
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
());
}
void
FusionLSTMOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/fused/multi_gru_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -143,14 +143,11 @@ void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
MultiGRUOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kMKLDNN
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kMKLDNN
;
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
layout
,
library
);
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
void
MultiGRUOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/matmul_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -700,7 +700,6 @@ class MatMulOp : public framework::OperatorWithKernel {
OperatorWithKernel
::
IndicateOrPromoteVarDataTypes
(
ctx
,
"X"
,
"Y"
);
#ifdef PADDLE_WITH_MKLDNN
using
dnnl
::
memory
;
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
input_data_type
))
{
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
...
...
paddle/fluid/operators/matrix_rank_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -19,10 +19,6 @@
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
using
DDim
=
framework
::
DDim
;
...
...
paddle/fluid/operators/mul_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -41,17 +41,12 @@ class MulOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
input_data_type
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
#ifdef PADDLE_WITH_MKLDNN
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
input_data_type
))
{
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
if
(
input_data_type
==
framework
::
DataTypeTrait
<
int8_t
>::
DataType
()
||
input_data_type
==
framework
::
DataTypeTrait
<
uint8_t
>::
DataType
())
{
customized_type_value
=
kMULMKLDNNINT8
;
...
...
@@ -62,14 +57,15 @@ class MulOp : public framework::OperatorWithKernel {
framework
::
DataTypeTrait
<
float
>::
DataType
())
{
customized_type_value
=
kMULMKLDNNFP32
;
}
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
,
customized_type_value
);
}
#endif
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
,
customized_type_value
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
...
...
@@ -140,17 +136,12 @@ class MulGradOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
input_data_type
))
{
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
#ifdef PADDLE_WITH_MKLDNN
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
input_data_type
))
{
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
if
(
input_data_type
==
framework
::
DataTypeTrait
<
int8_t
>::
DataType
()
||
input_data_type
==
framework
::
DataTypeTrait
<
uint8_t
>::
DataType
())
{
customized_type_value
=
kMULMKLDNNINT8
;
...
...
@@ -161,14 +152,15 @@ class MulGradOp : public framework::OperatorWithKernel {
framework
::
DataTypeTrait
<
float
>::
DataType
())
{
customized_type_value
=
kMULMKLDNNFP32
;
}
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
,
customized_type_value
);
}
#endif
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
,
customized_type_value
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -42,8 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
framework
::
OpKernelType
PoolOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
"AnyLayout"
;
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kAnyLayout
;
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
...
@@ -88,8 +87,7 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
framework
::
OpKernelType
PoolOpGrad
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
"AnyLayout"
;
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kAnyLayout
;
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
...
paddle/fluid/operators/prelu_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -23,26 +23,6 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
framework
::
OpKernelType
innerGetKernelTypeForVar
(
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
{
#ifdef PADDLE_WITH_MKLDNN
auto
isOneDNNKernelChosen
=
(
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
);
auto
isNotOneDNNTensor
=
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
);
auto
isModelNHWC
=
(
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
()
.
get_cur_paddle_data_layout
()
==
framework
::
DataLayout
::
kNHWC
);
// All inputs (including alpha) need shape rotating
if
(
isOneDNNKernelChosen
&&
isNotOneDNNTensor
&&
isModelNHWC
)
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
framework
::
DataLayout
::
kNHWC
);
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
class
PReluOp
:
public
framework
::
OperatorWithKernel
{
public:
PReluOp
(
const
std
::
string
&
type
,
...
...
@@ -72,7 +52,19 @@ class PReluOp : public framework::OperatorWithKernel {
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
return
innerGetKernelTypeForVar
(
tensor
,
expected_kernel_type
);
#ifdef PADDLE_WITH_MKLDNN
// All inputs (including alpha) need shape rotating
if
((
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
)
&&
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
)
&&
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
()
.
get_cur_paddle_data_layout
()
==
framework
::
DataLayout
::
kNHWC
)
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
framework
::
DataLayout
::
kNHWC
);
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
...
...
@@ -151,7 +143,19 @@ class PReluGradOp : public framework::OperatorWithKernel {
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
return
innerGetKernelTypeForVar
(
tensor
,
expected_kernel_type
);
#ifdef PADDLE_WITH_MKLDNN
// All inputs (including alpha) need shape rotating
if
((
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
)
&&
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
)
&&
paddle
::
platform
::
MKLDNNDeviceContext
::
tls
()
.
get_cur_paddle_data_layout
()
==
framework
::
DataLayout
::
kNHWC
)
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
framework
::
DataLayout
::
kNHWC
);
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
...
...
paddle/fluid/operators/qr_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -17,12 +17,9 @@
#include <unordered_map>
#include <vector>
#include "paddle/phi/core/ddim.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
...
...
paddle/fluid/operators/quantize_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -24,14 +24,11 @@ namespace operators {
framework
::
OpKernelType
QuantOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library_
=
framework
::
LibraryType
::
kMKLDNN
;
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Input"
),
ctx
.
GetPlace
(),
layout_
,
library_
);
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
void
QuantOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/requantize_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -24,14 +24,11 @@ namespace operators {
framework
::
OpKernelType
ReQuantOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
framework
::
LibraryType
library_
=
framework
::
LibraryType
::
kMKLDNN
;
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Input"
),
ctx
.
GetPlace
(),
layout_
,
library_
);
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
void
ReQuantOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/svd_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -21,9 +21,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
paddle/fluid/operators/transpose_op.cc
浏览文件 @
4b8d4ade
...
...
@@ -99,19 +99,18 @@ class TransposeOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
auto
&
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
#endif
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
,
library_
);
auto
&
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
);
}
};
...
...
@@ -203,20 +202,19 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
));
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
#endif
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
,
library_
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
);
}
};
...
...
@@ -249,29 +247,27 @@ class Transpose2Op : public TransposeOp {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
int
customized_type_value
=
framework
::
OpKernelType
::
kDefaultCustomizedTypeValue
;
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
proto
::
VarType
::
Type
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
using
framework
::
proto
::
VarType
;
auto
input_data_type
=
framework
::
TransToProtoVarType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dtype
());
customized_type_value
=
(
input_data_type
==
VarType
::
INT8
||
input_data_type
==
VarType
::
UINT8
)
?
kTransposeMKLDNNINT8
:
kTransposeMKLDNNFP32
;
int
customized_type_value
=
(
input_data_type
==
VarType
::
INT8
||
input_data_type
==
VarType
::
UINT8
)
?
kTransposeMKLDNNINT8
:
kTransposeMKLDNNFP32
;
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
,
customized_type_value
);
}
#endif
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
,
library_
,
customized_type_value
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
);
}
};
...
...
@@ -371,21 +367,20 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
proto
::
VarType
::
Type
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
));
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
if
(
this
->
CanMKLDNNBeUsed
(
ctx
,
data_type
))
{
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kMKLDNN
,
framework
::
LibraryType
::
kMKLDNN
);
}
#endif
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
,
library_
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
layout_
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录