Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f4cf028a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f4cf028a
编写于
11月 26, 2019
作者:
J
Jacek Czaja
提交者:
Tao Luo
11月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MKL-DNN] Error throwing for NHWC layout for MKL-DNN ops (#21207)
上级
ed9ceb9f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
85 addition
and
8 deletion
+85
-8
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+19
-8
paddle/fluid/operators/conv_transpose_op.cc
paddle/fluid/operators/conv_transpose_op.cc
+5
-0
paddle/fluid/operators/lrn_op.cc
paddle/fluid/operators/lrn_op.cc
+12
-0
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+12
-0
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
...paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
+16
-0
python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
...dle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
+21
-0
未找到文件。
paddle/fluid/operators/conv_op.cc
浏览文件 @
f4cf028a
...
...
@@ -151,6 +151,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"Conv MKLDNN does not support NHWC data format yet"
));
PADDLE_ENFORCE_NE
(
data_format
,
"NDHWC"
,
platform
::
errors
::
Unimplemented
(
"Conv MKLDNN does not support NDHWC data format yet"
));
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
customized_type_value
=
...
...
@@ -524,6 +533,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"Conv MKLDNN grad does not support NHWC data format yet"
));
PADDLE_ENFORCE_NE
(
data_format
,
"NDHWC"
,
platform
::
errors
::
Unimplemented
(
"Conv MKLDNN Grad does not support NDHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
customized_type_value
=
kConvMKLDNNFP32
;
...
...
@@ -706,14 +725,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
if
(
platform
::
CanCUDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
customized_type_value
=
kConvMKLDNNFP32
;
}
#endif
auto
type
=
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Input"
),
ctx
.
GetPlace
(),
...
...
paddle/fluid/operators/conv_transpose_op.cc
浏览文件 @
f4cf028a
...
...
@@ -145,6 +145,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
"Conv Transpose MKLDNN does not support NHWC data format yet"
);
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
...
...
paddle/fluid/operators/lrn_op.cc
浏览文件 @
f4cf028a
...
...
@@ -193,6 +193,12 @@ class LRNOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"LRN MKLDNN does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
...
...
@@ -311,6 +317,12 @@ class LRNOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"LRN MKLDNN grad does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
f4cf028a
...
...
@@ -146,6 +146,12 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"Pool MKLDNN grad does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
...
...
@@ -177,6 +183,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"Pool MKLDNN grad does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
浏览文件 @
f4cf028a
...
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
unittest
from
paddle.fluid.tests.unittests.test_lrn_op
import
TestLRNOp
import
paddle.fluid
as
fluid
class
TestLRNMKLDNNOp
(
TestLRNOp
):
...
...
@@ -54,5 +55,20 @@ class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
self
.
assertRaises
(
AttributeError
,
check_raise_is_test
)
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class
TestLRNMKLDNNOpNHWC
(
TestLRNMKLDNNOp
):
def
init_test_case
(
self
):
self
.
data_format
=
'NHWC'
def
test_check_output
(
self
):
pass
# Grad tests both FWD and BWD ops kernels creation
def
test_check_grad_normal
(
self
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.01
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
浏览文件 @
f4cf028a
...
...
@@ -141,5 +141,26 @@ class TestAsymPadValid(TestAsymPad):
self
.
padding_algorithm
=
"VALID"
# Designed to Fail
# TODO(jczaja): Once mkl-dnn integration support NHWC input
# then those tests should be changed to actual functional positive tests
class
TestAsymPadValidNHWC
(
TestAsymPadValid
):
def
init_data_format
(
self
):
self
.
data_format
=
"NHWC"
def
init_shape
(
self
):
self
.
shape
=
[
2
,
7
,
7
,
3
]
def
test_check_output
(
self
):
pass
# Grad tests both FWD and BWD ops kernels creation
# GetExpectedKernelType should throw an exception on lack of support
# to NHWC inputs in pool mkldnn kernel
def
test_check_grad
(
self
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
super
(
TestAsymPadValidNHWC
,
self
).
test_check_grad
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录