Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cd43c444
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cd43c444
编写于
11月 29, 2019
作者:
J
Jacek Czaja
提交者:
Tao Luo
11月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MKL-DNN] LRN and Pool2d (FWD) NHWC support (#21375)
上级
add62acf
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
111 addition
and
39 deletion
+111
-39
paddle/fluid/framework/data_layout_transform.cc
paddle/fluid/framework/data_layout_transform.cc
+13
-3
paddle/fluid/framework/data_layout_transform.h
paddle/fluid/framework/data_layout_transform.h
+3
-4
paddle/fluid/framework/data_transform.cc
paddle/fluid/framework/data_transform.cc
+10
-1
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+1
-0
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+5
-0
paddle/fluid/operators/controlflow/fetch_op.cc
paddle/fluid/operators/controlflow/fetch_op.cc
+6
-2
paddle/fluid/operators/lrn_op.cc
paddle/fluid/operators/lrn_op.cc
+22
-6
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
+1
-0
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
+2
-0
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+26
-7
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+4
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+12
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+3
-0
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+1
-1
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
...paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
+1
-6
python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
...dle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
+1
-9
未找到文件。
paddle/fluid/framework/data_layout_transform.cc
浏览文件 @
cd43c444
...
@@ -127,13 +127,17 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
...
@@ -127,13 +127,17 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"
);
"non-MKLDNN"
);
innerTransDataLayoutFromMKLDNN
(
in_layout
,
out_layout
,
in
,
out
,
place
);
#ifdef PADDLE_WITH_MKLDNN
innerTransDataLayoutFromMKLDNN
(
in_layout
,
paddle
::
platform
::
get_cur_paddle_data_layout
(),
in
,
out
,
place
);
#endif
}
}
#ifdef PADDLE_WITH_MKLDNN
void
innerTransDataLayoutFromMKLDNN
(
DataLayout
in_layout
,
DataLayout
out_layout
,
void
innerTransDataLayoutFromMKLDNN
(
DataLayout
in_layout
,
DataLayout
out_layout
,
const
Tensor
&
in
,
Tensor
*
out
,
const
Tensor
&
in
,
Tensor
*
out
,
platform
::
Place
place
)
{
platform
::
Place
place
)
{
#ifdef PADDLE_WITH_MKLDNN
PADDLE_ENFORCE_NE
(
in
.
format
(),
MKLDNNMemoryFormat
::
format_undef
,
PADDLE_ENFORCE_NE
(
in
.
format
(),
MKLDNNMemoryFormat
::
format_undef
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Input tensor format is invalid. Input tensor should "
"Input tensor format is invalid. Input tensor should "
...
@@ -185,11 +189,17 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
...
@@ -185,11 +189,17 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
}
else
{
}
else
{
out
->
ShareDataWith
(
in
);
out
->
ShareDataWith
(
in
);
}
}
// For exepected NHWC data format we need to reshape the Output tensor
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
if
(
out_layout
==
DataLayout
::
kNHWC
)
{
std
::
rotate
(
out_tz
.
begin
()
+
1
,
out_tz
.
begin
()
+
2
,
out_tz
.
end
());
out
->
Resize
(
framework
::
make_ddim
(
out_tz
));
}
out
->
set_layout
(
out_layout
);
out
->
set_layout
(
out_layout
);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out
->
set_format
(
MKLDNNMemoryFormat
::
format_undef
);
out
->
set_format
(
MKLDNNMemoryFormat
::
format_undef
);
#endif
}
}
#endif
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/data_layout_transform.h
浏览文件 @
cd43c444
...
@@ -66,16 +66,15 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
...
@@ -66,16 +66,15 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
return
MKLDNNDataType
::
data_undef
;
return
MKLDNNDataType
::
data_undef
;
}
}
void
innerTransDataLayoutFromMKLDNN
(
DataLayout
in_layout
,
DataLayout
out_layout
,
const
Tensor
&
in
,
Tensor
*
out
,
platform
::
Place
place
);
#endif
#endif
void
TransDataLayoutFromMKLDNN
(
const
OpKernelType
&
kernel_type_for_var
,
void
TransDataLayoutFromMKLDNN
(
const
OpKernelType
&
kernel_type_for_var
,
const
OpKernelType
&
expected_kernel_type
,
const
OpKernelType
&
expected_kernel_type
,
const
Tensor
&
in
,
Tensor
*
out
);
const
Tensor
&
in
,
Tensor
*
out
);
void
innerTransDataLayoutFromMKLDNN
(
DataLayout
in_layout
,
DataLayout
out_layout
,
const
Tensor
&
in
,
Tensor
*
out
,
platform
::
Place
place
);
std
::
vector
<
int
>
GetAxis
(
const
DataLayout
&
from
,
const
DataLayout
&
to
);
std
::
vector
<
int
>
GetAxis
(
const
DataLayout
&
from
,
const
DataLayout
&
to
);
void
TransDataLayout
(
const
OpKernelType
&
kernel_type_for_var
,
void
TransDataLayout
(
const
OpKernelType
&
kernel_type_for_var
,
...
...
paddle/fluid/framework/data_transform.cc
浏览文件 @
cd43c444
...
@@ -19,6 +19,7 @@ limitations under the License. */
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/data_type_transform.h"
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
#include <algorithm>
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#endif
...
@@ -54,8 +55,16 @@ void TransformData(const OpKernelType &expected_kernel_type,
...
@@ -54,8 +55,16 @@ void TransformData(const OpKernelType &expected_kernel_type,
auto
out_format
=
platform
::
MKLDNNFormatForSize
(
in
.
dims
().
size
(),
auto
out_format
=
platform
::
MKLDNNFormatForSize
(
in
.
dims
().
size
(),
ToMKLDNNFormat
(
lin
));
ToMKLDNNFormat
(
lin
));
out
.
ShareDataWith
(
input_tensor
);
out
.
ShareDataWith
(
input_tensor
);
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
if
(
lin
==
DataLayout
::
kNHWC
)
{
auto
nchw_dims
=
paddle
::
framework
::
vectorize
<
int
>
(
out
.
dims
());
std
::
rotate
(
nchw_dims
.
begin
()
+
1
,
nchw_dims
.
end
()
-
1
,
nchw_dims
.
end
());
out
.
Resize
(
framework
::
make_ddim
(
nchw_dims
));
paddle
::
platform
::
set_cur_paddle_data_layout
(
lin
);
}
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
out
.
set_format
(
out_format
);
out
.
set_format
(
out_format
);
#endif
#endif
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
cd43c444
...
@@ -103,6 +103,7 @@ Executor::~Executor() {
...
@@ -103,6 +103,7 @@ Executor::~Executor() {
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
platform
::
MKLDNNDeviceContext
*
dev_ctx
=
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place_
);
(
platform
::
MKLDNNDeviceContext
*
)
pool
.
Get
(
place_
);
dev_ctx
->
ResetBlobMap
();
dev_ctx
->
ResetBlobMap
();
platform
::
set_cur_paddle_data_layout
(
paddle
::
framework
::
DataLayout
::
kNCHW
);
}
}
#endif
#endif
}
}
...
...
paddle/fluid/framework/operator.h
浏览文件 @
cd43c444
...
@@ -470,6 +470,11 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -470,6 +470,11 @@ class OperatorWithKernel : public OperatorBase {
return
g_all_op_kernels
;
return
g_all_op_kernels
;
}
}
bool
IsMKLDNNType
()
const
{
return
((
this
->
kernel_type_
)
&&
(
this
->
kernel_type_
->
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
));
}
bool
SupportGPU
()
const
override
{
bool
SupportGPU
()
const
override
{
auto
&
op_kernels
=
OperatorWithKernel
::
AllOpKernels
().
at
(
type_
);
auto
&
op_kernels
=
OperatorWithKernel
::
AllOpKernels
().
at
(
type_
);
return
std
::
any_of
(
op_kernels
.
begin
(),
op_kernels
.
end
(),
return
std
::
any_of
(
op_kernels
.
begin
(),
op_kernels
.
end
(),
...
...
paddle/fluid/operators/controlflow/fetch_op.cc
浏览文件 @
cd43c444
...
@@ -56,16 +56,20 @@ class FetchOp : public framework::OperatorBase {
...
@@ -56,16 +56,20 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
// CPU outputs?
if
(
src_item
.
IsInitialized
()
&&
src_item
.
numel
()
>
0
)
{
if
(
src_item
.
IsInitialized
()
&&
src_item
.
numel
()
>
0
)
{
#ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle
// Conversion from MKL-DNN to Paddle
if
(
src_item
.
layout
()
==
framework
::
DataLayout
::
kMKLDNN
)
{
if
(
src_item
.
layout
()
==
framework
::
DataLayout
::
kMKLDNN
)
{
framework
::
Tensor
out
;
framework
::
Tensor
out
;
framework
::
innerTransDataLayoutFromMKLDNN
(
framework
::
innerTransDataLayoutFromMKLDNN
(
src_item
.
layout
(),
framework
::
DataLayout
::
kNCHW
,
src_item
,
&
out
,
src_item
.
layout
(),
paddle
::
platform
::
get_cur_paddle_data_layout
()
,
platform
::
CPUPlace
());
src_item
,
&
out
,
platform
::
CPUPlace
());
TensorCopySync
(
out
,
platform
::
CPUPlace
(),
&
dst_item
);
TensorCopySync
(
out
,
platform
::
CPUPlace
(),
&
dst_item
);
}
else
{
}
else
{
TensorCopySync
(
src_item
,
platform
::
CPUPlace
(),
&
dst_item
);
TensorCopySync
(
src_item
,
platform
::
CPUPlace
(),
&
dst_item
);
}
}
#else
TensorCopySync
(
src_item
,
platform
::
CPUPlace
(),
&
dst_item
);
#endif
}
else
{
}
else
{
// Not copy, if the src tensor is empty.
// Not copy, if the src tensor is empty.
dst_item
.
clear
();
dst_item
.
clear
();
...
...
paddle/fluid/operators/lrn_op.cc
浏览文件 @
cd43c444
...
@@ -193,12 +193,6 @@ class LRNOp : public framework::OperatorWithKernel {
...
@@ -193,12 +193,6 @@ class LRNOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"LRN MKLDNN does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
...
@@ -207,6 +201,28 @@ class LRNOp : public framework::OperatorWithKernel {
...
@@ -207,6 +201,28 @@ class LRNOp : public framework::OperatorWithKernel {
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
);
}
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
#ifdef PADDLE_WITH_MKLDNN
if
((
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
)
&&
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
))
{
auto
attrs
=
Attrs
();
auto
ar
=
paddle
::
framework
::
AttrReader
(
attrs
);
const
std
::
string
data_format
=
ar
.
Get
<
std
::
string
>
(
"data_format"
);
auto
dl
=
framework
::
StringToDataLayout
(
data_format
);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if
(
dl
!=
framework
::
DataLayout
::
kAnyLayout
)
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
dl
);
}
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
浏览文件 @
cd43c444
...
@@ -102,6 +102,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
...
@@ -102,6 +102,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
pipeline
.
push_back
(
*
reorder_p
);
pipeline
.
push_back
(
*
reorder_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
output
->
set_layout
(
DataLayout
::
kMKLDNN
);
output
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
output
->
set_format
(
GetMKLDNNFormat
(
*
dst_memory
));
}
}
};
};
...
...
paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc
浏览文件 @
cd43c444
...
@@ -62,6 +62,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -62,6 +62,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std
::
shared_ptr
<
mkldnn
::
lrn_forward
>
lrn_p
;
std
::
shared_ptr
<
mkldnn
::
lrn_forward
>
lrn_p
;
if
(
is_test
==
false
)
{
if
(
is_test
==
false
)
{
workspace_memory
=
handler
.
AcquireWorkspaceMemory
(
mid
);
workspace_memory
=
handler
.
AcquireWorkspaceMemory
(
mid
);
mid
->
set_layout
(
framework
::
DataLayout
::
kMKLDNN
);
mid
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
workspace_memory
));
lrn_p
=
handler
.
AcquireForwardPrimitive
(
*
src_memory
,
*
workspace_memory
,
lrn_p
=
handler
.
AcquireForwardPrimitive
(
*
src_memory
,
*
workspace_memory
,
*
dst_memory
);
*
dst_memory
);
}
else
{
}
else
{
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
cd43c444
...
@@ -88,7 +88,10 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -88,7 +88,10 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
ksize
.
size
(),
strides
.
size
(),
framework
::
make_ddim
(
ksize
),
ksize
.
size
(),
strides
.
size
(),
framework
::
make_ddim
(
ksize
),
framework
::
make_ddim
(
strides
));
framework
::
make_ddim
(
strides
));
const
bool
channel_last
=
(
data_format
==
"NHWC"
||
data_format
==
"NDHWC"
);
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const
bool
channel_last
=
(
this
->
IsMKLDNNType
()
==
false
)
&&
(
data_format
==
"NHWC"
||
data_format
==
"NDHWC"
);
// update paddings if "SAME" or global_pooling
// update paddings if "SAME" or global_pooling
framework
::
DDim
data_dims
;
framework
::
DDim
data_dims
;
...
@@ -146,12 +149,6 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
...
@@ -146,12 +149,6 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"Pool MKLDNN grad does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
...
@@ -162,6 +159,28 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
...
@@ -162,6 +159,28 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
layout_
,
library_
);
layout_
,
library_
);
}
}
framework
::
OpKernelType
PoolOp
::
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
#ifdef PADDLE_WITH_MKLDNN
if
((
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
)
&&
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
))
{
auto
attrs
=
Attrs
();
auto
ar
=
paddle
::
framework
::
AttrReader
(
attrs
);
const
std
::
string
data_format
=
ar
.
Get
<
std
::
string
>
(
"data_format"
);
auto
dl
=
framework
::
StringToDataLayout
(
data_format
);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if
(
dl
!=
framework
::
DataLayout
::
kAnyLayout
)
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
dl
);
}
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
void
PoolOpGrad
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
void
PoolOpGrad
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) must not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) must not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
true
,
...
...
paddle/fluid/operators/pool_op.h
浏览文件 @
cd43c444
...
@@ -35,6 +35,10 @@ class PoolOp : public framework::OperatorWithKernel {
...
@@ -35,6 +35,10 @@ class PoolOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
;
};
};
class
PoolOpGrad
:
public
framework
::
OperatorWithKernel
{
class
PoolOpGrad
:
public
framework
::
OperatorWithKernel
{
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
cd43c444
...
@@ -397,6 +397,10 @@ thread_local std::string cur_input_shape_str = "";
...
@@ -397,6 +397,10 @@ thread_local std::string cur_input_shape_str = "";
// the cache capacity of different input shapes for MKLDNN.
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
// Default 1 means fixed input shape, not dynamic shape.
thread_local
int
cur_input_shape_cache_capacity
=
1
;
thread_local
int
cur_input_shape_cache_capacity
=
1
;
// Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN
thread_local
paddle
::
framework
::
DataLayout
cur_paddle_data_layout
=
paddle
::
framework
::
DataLayout
::
kNCHW
;
}
// namespace
}
// namespace
void
set_cur_mkldnn_session_id
(
size_t
sid
)
{
cur_mkldnn_session_id
=
sid
;
}
void
set_cur_mkldnn_session_id
(
size_t
sid
)
{
cur_mkldnn_session_id
=
sid
;
}
...
@@ -408,6 +412,14 @@ void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
...
@@ -408,6 +412,14 @@ void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) {
cur_input_shape_cache_capacity
=
input_shape_cache_capacity
;
cur_input_shape_cache_capacity
=
input_shape_cache_capacity
;
}
}
void
set_cur_paddle_data_layout
(
framework
::
DataLayout
dl
)
{
cur_paddle_data_layout
=
dl
;
}
framework
::
DataLayout
get_cur_paddle_data_layout
(
void
)
{
return
cur_paddle_data_layout
;
}
void
MKLDNNDeviceContext
::
ResetBlobMap
()
const
{
p_blobmap_
->
clear
();
}
void
MKLDNNDeviceContext
::
ResetBlobMap
()
const
{
p_blobmap_
->
clear
();
}
size_t
MKLDNNDeviceContext
::
GetShapeBlobSize
()
const
{
size_t
MKLDNNDeviceContext
::
GetShapeBlobSize
()
const
{
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
cd43c444
...
@@ -30,6 +30,7 @@ limitations under the License. */
...
@@ -30,6 +30,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
#include "mkldnn.hpp"
#include "mkldnn.hpp"
#include "paddle/fluid/framework/data_layout.h"
#endif
#endif
#include <map>
#include <map>
...
@@ -290,6 +291,8 @@ void set_cur_mkldnn_session_id(size_t);
...
@@ -290,6 +291,8 @@ void set_cur_mkldnn_session_id(size_t);
size_t
get_cur_mkldnn_session_id
(
void
);
size_t
get_cur_mkldnn_session_id
(
void
);
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
);
void
set_cur_input_shape_str
(
std
::
string
input_shape_str
);
void
set_cur_input_shape_cache_capacity
(
int
input_shape_cache_capacity
);
void
set_cur_input_shape_cache_capacity
(
int
input_shape_cache_capacity
);
void
set_cur_paddle_data_layout
(
framework
::
DataLayout
);
framework
::
DataLayout
get_cur_paddle_data_layout
(
void
);
class
MKLDNNDeviceContext
:
public
CPUDeviceContext
{
class
MKLDNNDeviceContext
:
public
CPUDeviceContext
{
public:
public:
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
cd43c444
...
@@ -502,7 +502,7 @@ class LRNMKLDNNHandler
...
@@ -502,7 +502,7 @@ class LRNMKLDNNHandler
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWorkspaceMemory
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireWorkspaceMemory
(
framework
::
Tensor
*
workspace
)
{
framework
::
Tensor
*
workspace
)
{
T
*
ptr
=
workspace
->
mutable_data
<
T
>
(
T
*
ptr
=
workspace
->
mutable_data
<
T
>
(
this
->
place_
,
this
->
fwd_pd_
->
dst
_primitive_desc
().
get_size
());
this
->
place_
,
this
->
fwd_pd_
->
workspace
_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
workspace_primitive_desc
(),
ptr
,
"@wrk_mem_p"
);
this
->
fwd_pd_
->
workspace_primitive_desc
(),
ptr
,
"@wrk_mem_p"
);
}
}
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
浏览文件 @
cd43c444
...
@@ -55,16 +55,11 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
...
@@ -55,16 +55,11 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
self
.
assertRaises
(
AttributeError
,
check_raise_is_test
)
self
.
assertRaises
(
AttributeError
,
check_raise_is_test
)
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class
TestLRNMKLDNNOpNHWC
(
TestLRNMKLDNNOp
):
class
TestLRNMKLDNNOpNHWC
(
TestLRNMKLDNNOp
):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
self
.
data_format
=
'NHWC'
self
.
data_format
=
'NHWC'
def
test_check_output
(
self
):
#TODO(jczaja): Add grad support
pass
# Grad tests both FWD and BWD ops kernels creation
def
test_check_grad_normal
(
self
):
def
test_check_grad_normal
(
self
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.01
)
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.01
)
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
浏览文件 @
cd43c444
...
@@ -141,9 +141,6 @@ class TestAsymPadValid(TestAsymPad):
...
@@ -141,9 +141,6 @@ class TestAsymPadValid(TestAsymPad):
self
.
padding_algorithm
=
"VALID"
self
.
padding_algorithm
=
"VALID"
# Designed to Fail
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class
TestAsymPadValidNHWC
(
TestAsymPadValid
):
class
TestAsymPadValidNHWC
(
TestAsymPadValid
):
def
init_data_format
(
self
):
def
init_data_format
(
self
):
self
.
data_format
=
"NHWC"
self
.
data_format
=
"NHWC"
...
@@ -151,12 +148,7 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
...
@@ -151,12 +148,7 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
def
init_shape
(
self
):
def
init_shape
(
self
):
self
.
shape
=
[
2
,
7
,
7
,
3
]
self
.
shape
=
[
2
,
7
,
7
,
3
]
def
test_check_output
(
self
):
#TODO(jczaja): Add Grad NHWC support
pass
# Grad tests both FWD and BWD ops kernels creation
# GetExpectedKernelType should throw an exception on lack of support
# to NHWC inputs in pool mkldnn kernel
def
test_check_grad
(
self
):
def
test_check_grad
(
self
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
super
(
TestAsymPadValidNHWC
,
self
).
test_check_grad
()
super
(
TestAsymPadValidNHWC
,
self
).
test_check_grad
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录