Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ad8a9cb8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ad8a9cb8
编写于
1月 05, 2020
作者:
J
Jacek Czaja
提交者:
Tao Luo
1月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MKL-DNN] Pool & LRN Grad Ops NHWC support (#21747)
上级
e1d666fb
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
124 addition
and
46 deletion
+124
-46
paddle/fluid/framework/data_layout_transform.cc
paddle/fluid/framework/data_layout_transform.cc
+2
-4
paddle/fluid/framework/data_transform.cc
paddle/fluid/framework/data_transform.cc
+2
-7
paddle/fluid/framework/data_transform.h
paddle/fluid/framework/data_transform.h
+0
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+52
-11
paddle/fluid/operators/lrn_op.cc
paddle/fluid/operators/lrn_op.cc
+22
-6
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+18
-6
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+5
-1
paddle/fluid/platform/mkldnn_helper.h
paddle/fluid/platform/mkldnn_helper.h
+23
-0
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
...paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
+0
-5
python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
...dle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
+0
-5
未找到文件。
paddle/fluid/framework/data_layout_transform.cc
浏览文件 @
ad8a9cb8
...
@@ -185,10 +185,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
...
@@ -185,10 +185,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
}
}
// For exepected NHWC data format we need to reshape the Output tensor
// For exepected NHWC data format we need to reshape the Output tensor
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
if
(
out_layout
==
DataLayout
::
kNHWC
)
{
platform
::
MatchShapeToLayout
(
out
,
in_layout
,
out_layout
);
std
::
rotate
(
out_tz
.
begin
()
+
1
,
out_tz
.
begin
()
+
2
,
out_tz
.
end
());
out
->
Resize
(
framework
::
make_ddim
(
out_tz
));
}
out
->
set_layout
(
out_layout
);
out
->
set_layout
(
out_layout
);
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
// reset format since the out tensor will be feed to non-MKLDNN OPkernel
out
->
set_format
(
MKLDNNMemoryFormat
::
undef
);
out
->
set_format
(
MKLDNNMemoryFormat
::
undef
);
...
...
paddle/fluid/framework/data_transform.cc
浏览文件 @
ad8a9cb8
...
@@ -58,13 +58,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
...
@@ -58,13 +58,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
out
.
ShareDataWith
(
input_tensor
);
out
.
ShareDataWith
(
input_tensor
);
// For NHWC data we need reshape of tensors as MKL-DNN
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
// is expecting NHWC dims description order
if
(
lin
==
DataLayout
::
kNHWC
)
{
platform
::
MatchShapeToLayout
(
&
out
,
lin
,
lout
);
auto
nchw_dims
=
paddle
::
framework
::
vectorize
<
int
>
(
out
.
dims
());
std
::
rotate
(
nchw_dims
.
begin
()
+
1
,
nchw_dims
.
end
()
-
1
,
nchw_dims
.
end
());
out
.
Resize
(
framework
::
make_ddim
(
nchw_dims
));
paddle
::
platform
::
set_cur_paddle_data_layout
(
lin
);
paddle
::
platform
::
set_cur_paddle_data_layout
(
lin
);
}
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
out
.
set_layout
(
DataLayout
::
kMKLDNN
);
out
.
set_format
(
out_format
);
out
.
set_format
(
out_format
);
}
else
{
}
else
{
...
...
paddle/fluid/framework/data_transform.h
浏览文件 @
ad8a9cb8
...
@@ -39,6 +39,5 @@ void TransformData(const OpKernelType &expected_kernel_type,
...
@@ -39,6 +39,5 @@ void TransformData(const OpKernelType &expected_kernel_type,
*/
*/
void
SetTensorToVariable
(
const
Variable
&
in_var
,
const
Tensor
&
tensor
,
void
SetTensorToVariable
(
const
Variable
&
in_var
,
const
Tensor
&
tensor
,
Variable
*
out_var
);
Variable
*
out_var
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/operator.cc
浏览文件 @
ad8a9cb8
...
@@ -33,6 +33,10 @@ limitations under the License. */
...
@@ -33,6 +33,10 @@ limitations under the License. */
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool
(
benchmark
);
DECLARE_bool
(
benchmark
);
DECLARE_bool
(
check_nan_inf
);
DECLARE_bool
(
check_nan_inf
);
DECLARE_bool
(
enable_unused_var_check
);
DECLARE_bool
(
enable_unused_var_check
);
...
@@ -1102,11 +1106,8 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -1102,11 +1106,8 @@ Scope* OperatorWithKernel::PrepareData(
}
}
for
(
auto
&
var_name_item
:
Inputs
())
{
for
(
auto
&
var_name_item
:
Inputs
())
{
if
(
no_buffer_ins
&&
no_buffer_ins
->
count
(
var_name_item
.
first
)
>
0
)
{
bool
should_skip_input
=
VLOG
(
7
)
<<
"Skip scanning input "
<<
var_name_item
.
first
no_buffer_ins
&&
no_buffer_ins
->
count
(
var_name_item
.
first
)
>
0
;
<<
" in Operator "
<<
type_
;
continue
;
}
std
::
vector
<
Variable
*>&
input_vars
=
ctx
->
inputs
[
var_name_item
.
first
];
std
::
vector
<
Variable
*>&
input_vars
=
ctx
->
inputs
[
var_name_item
.
first
];
...
@@ -1120,6 +1121,44 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -1120,6 +1121,44 @@ Scope* OperatorWithKernel::PrepareData(
}
}
auto
*
tensor_in
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
auto
*
tensor_in
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
// When no_buffer_ins then checking of Tensor::holder_ is
// not a thread safe. And for infershape scenario checks
// to be omitted are not really needed
if
(
should_skip_input
==
true
)
{
#ifdef PADDLE_WITH_MKLDNN
// Var without buffer may be needed
// for some situation like InferShape().
// In this situation We cannot skip Var analysis, as
// MKL-DNN shape of Var may differ from kNHWC Var
// In such situation corressponding resized Var
// has to be created and registered
if
((
tensor_in
->
layout
()
==
DataLayout
::
kMKLDNN
)
&&
(
var
->
IsType
<
LoDTensor
>
()
==
true
)
&&
(
expected_kernel_key
.
data_layout_
!=
DataLayout
::
kMKLDNN
)
&&
(
paddle
::
platform
::
get_cur_paddle_data_layout
()
==
DataLayout
::
kNHWC
))
{
// Mixed execution : MKL-DNN and GPU is not supported!
if
(
!
new_scope
)
{
new_scope
=
&
scope
.
NewScope
();
}
auto
*
trans_var
=
new_scope
->
Var
(
var_name
);
input_vars
[
i
]
=
trans_var
;
auto
out
=
trans_var
->
GetMutable
<
LoDTensor
>
();
out
->
Resize
(
tensor_in
->
dims
());
platform
::
MatchShapeToLayout
(
out
,
tensor_in
->
layout
(),
DataLayout
::
kNHWC
);
VLOG
(
7
)
<<
"Created reshaped dummy input based on MKL-DNN Tensor , "
"but kNHWC layout"
<<
var_name_item
.
first
<<
" in Operator "
<<
type_
;
}
else
{
VLOG
(
7
)
<<
"Skip scanning input "
<<
var_name_item
.
first
<<
" in Operator "
<<
type_
;
}
#endif
continue
;
}
if
(
!
tensor_in
->
IsInitialized
())
{
if
(
!
tensor_in
->
IsInitialized
())
{
continue
;
continue
;
}
}
...
@@ -1143,14 +1182,17 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -1143,14 +1182,17 @@ Scope* OperatorWithKernel::PrepareData(
// In the inference scenerio, the scopes will be reused across the
// In the inference scenerio, the scopes will be reused across the
// batches, so the `new_scope` here will result in GPU memroy explosion
// batches, so the `new_scope` here will result in GPU memroy explosion
// over the running of operators.
// over the running of operators.
// We use a thread_local cache to fix that issue, the key in the cache is
// We use a thread_local cache to fix that issue, the key in the cache
// is
// the combination of the `scope` argument, from_kernel_type,
// the combination of the `scope` argument, from_kernel_type,
// target_kernel_type.
// target_kernel_type.
// Have a discussion with @Superjomn or the inference developers if some
// Have a discussion with @Superjomn or the inference developers if some
// changes on this logic for this macro might not tested on the other
// changes on this logic for this macro might not tested on the other
// scenerios.
// scenerios.
// If this op is not called by an Executor or ParallelExecutor, it should
// If this op is not called by an Executor or ParallelExecutor, it
// called by a NaiveExecutor, the NaiveExecutor will cache the scopes and
// should
// called by a NaiveExecutor, the NaiveExecutor will cache the scopes
// and
// variables, that behavior a lot different.
// variables, that behavior a lot different.
//
//
// To solve issue #15032, have a discussion with @Luotao for cpu
// To solve issue #15032, have a discussion with @Luotao for cpu
...
@@ -1174,15 +1216,14 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -1174,15 +1216,14 @@ Scope* OperatorWithKernel::PrepareData(
// we will create a new cpu tensor in new scope.
// we will create a new cpu tensor in new scope.
// However, if enable_cache_runtime_context_, we get the cpu tensor each
// However, if enable_cache_runtime_context_, we get the cpu tensor each
// time, not the gpu tensor.
// time, not the gpu tensor.
// Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()` in
// Thus, we set pre_scope_ = nullptr to trigger `new RuntimeContext()`
// in
// RunImpl().
// RunImpl().
if
(
enable_cache_runtime_context_
)
{
if
(
enable_cache_runtime_context_
)
{
pre_scope_
=
nullptr
;
pre_scope_
=
nullptr
;
}
}
auto
*
trans_var
=
new_scope
->
Var
(
var_name
);
auto
*
trans_var
=
new_scope
->
Var
(
var_name
);
input_vars
[
i
]
=
trans_var
;
input_vars
[
i
]
=
trans_var
;
Tensor
out
;
Tensor
out
;
TransformData
(
expected_kernel_key
,
kernel_type_for_var
,
*
tensor_in
,
&
out
);
TransformData
(
expected_kernel_key
,
kernel_type_for_var
,
*
tensor_in
,
&
out
);
SetTensorToVariable
(
*
var
,
out
,
trans_var
);
SetTensorToVariable
(
*
var
,
out
,
trans_var
);
...
...
paddle/fluid/operators/lrn_op.cc
浏览文件 @
ad8a9cb8
...
@@ -334,12 +334,6 @@ class LRNOpGrad : public framework::OperatorWithKernel {
...
@@ -334,12 +334,6 @@ class LRNOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"LRN MKLDNN grad does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
...
@@ -348,6 +342,28 @@ class LRNOpGrad : public framework::OperatorWithKernel {
...
@@ -348,6 +342,28 @@ class LRNOpGrad : public framework::OperatorWithKernel {
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
);
}
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
#ifdef PADDLE_WITH_MKLDNN
if
((
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
)
&&
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
))
{
auto
attrs
=
Attrs
();
auto
ar
=
paddle
::
framework
::
AttrReader
(
attrs
);
const
std
::
string
data_format
=
ar
.
Get
<
std
::
string
>
(
"data_format"
);
auto
dl
=
framework
::
StringToDataLayout
(
data_format
);
// Some models may have intentionally set "AnyLayout" for lrn
// op. Treat this as NCHW (default data_format value)
if
(
dl
!=
framework
::
DataLayout
::
kAnyLayout
)
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
dl
);
}
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
ad8a9cb8
...
@@ -202,12 +202,6 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
...
@@ -202,12 +202,6 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
if
(
library_
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
// TODO(jczaja): Add support for NHWC
const
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
PADDLE_ENFORCE_NE
(
data_format
,
"NHWC"
,
platform
::
errors
::
Unimplemented
(
"Pool MKLDNN grad does not support NHWC data format yet"
));
library_
=
framework
::
LibraryType
::
kMKLDNN
;
library_
=
framework
::
LibraryType
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
...
@@ -222,6 +216,24 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
...
@@ -222,6 +216,24 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
library_
);
library_
);
}
}
framework
::
OpKernelType
PoolOpGrad
::
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
#ifdef PADDLE_WITH_MKLDNN
if
((
expected_kernel_type
.
data_layout_
==
framework
::
DataLayout
::
kMKLDNN
)
&&
(
tensor
.
layout
()
!=
framework
::
DataLayout
::
kMKLDNN
))
{
auto
attrs
=
Attrs
();
auto
ar
=
paddle
::
framework
::
AttrReader
(
attrs
);
const
std
::
string
data_format
=
ar
.
Get
<
std
::
string
>
(
"data_format"
);
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
framework
::
StringToDataLayout
(
data_format
));
}
#endif
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
void
Pool2dOpMaker
::
Make
()
{
void
Pool2dOpMaker
::
Make
()
{
AddInput
(
AddInput
(
"X"
,
"X"
,
...
...
paddle/fluid/operators/pool_op.h
浏览文件 @
ad8a9cb8
...
@@ -38,7 +38,7 @@ class PoolOp : public framework::OperatorWithKernel {
...
@@ -38,7 +38,7 @@ class PoolOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetKernelTypeForVar
(
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
;
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
;
};
};
class
PoolOpGrad
:
public
framework
::
OperatorWithKernel
{
class
PoolOpGrad
:
public
framework
::
OperatorWithKernel
{
...
@@ -50,6 +50,10 @@ class PoolOpGrad : public framework::OperatorWithKernel {
...
@@ -50,6 +50,10 @@ class PoolOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
;
};
};
class
Pool2dOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Pool2dOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/platform/mkldnn_helper.h
浏览文件 @
ad8a9cb8
...
@@ -71,6 +71,29 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
...
@@ -71,6 +71,29 @@ tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
return
tf_pd
<
Type
>
(
desc
,
e
,
p
);
return
tf_pd
<
Type
>
(
desc
,
e
,
p
);
}
}
inline
void
MatchShapeToLayout
(
framework
::
Tensor
*
tensor_in
,
framework
::
DataLayout
from
,
framework
::
DataLayout
to
)
{
switch
(
from
)
{
case
framework
:
:
DataLayout
::
kMKLDNN
:
if
(
to
==
framework
::
DataLayout
::
kNHWC
)
{
auto
dims
=
framework
::
vectorize
<
int
>
(
tensor_in
->
dims
());
std
::
rotate
(
dims
.
begin
()
+
1
,
dims
.
begin
()
+
2
,
dims
.
end
());
tensor_in
->
Resize
(
framework
::
make_ddim
(
dims
));
}
break
;
case
framework
:
:
DataLayout
::
kNHWC
:
if
(
to
==
framework
::
DataLayout
::
kMKLDNN
)
{
auto
dims
=
framework
::
vectorize
<
int
>
(
tensor_in
->
dims
());
std
::
rotate
(
dims
.
begin
()
+
1
,
dims
.
end
()
-
1
,
dims
.
end
());
tensor_in
->
Resize
(
framework
::
make_ddim
(
dims
));
}
break
;
default:
break
;
}
}
inline
mkldnn
::
memory
::
desc
MKLDNNMemDesc
(
const
std
::
vector
<
int64_t
>&
dims
,
inline
mkldnn
::
memory
::
desc
MKLDNNMemDesc
(
const
std
::
vector
<
int64_t
>&
dims
,
mkldnn
::
memory
::
data_type
data_type
,
mkldnn
::
memory
::
data_type
data_type
,
MKLDNNMemoryFormat
format
)
{
MKLDNNMemoryFormat
format
)
{
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_lrn_mkldnn_op.py
浏览文件 @
ad8a9cb8
...
@@ -59,11 +59,6 @@ class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
...
@@ -59,11 +59,6 @@ class TestLRNMKLDNNOpNHWC(TestLRNMKLDNNOp):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
self
.
data_format
=
'NHWC'
self
.
data_format
=
'NHWC'
#TODO(jczaja): Add grad support
def
test_check_grad_normal
(
self
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
self
.
check_grad
([
'X'
],
'Out'
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/mkldnn/test_pool2d_mkldnn_op.py
浏览文件 @
ad8a9cb8
...
@@ -157,11 +157,6 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
...
@@ -157,11 +157,6 @@ class TestAsymPadValidNHWC(TestAsymPadValid):
def
init_shape
(
self
):
def
init_shape
(
self
):
self
.
shape
=
[
2
,
7
,
7
,
3
]
self
.
shape
=
[
2
,
7
,
7
,
3
]
#TODO(jczaja): Add Grad NHWC support
def
test_check_grad
(
self
):
with
self
.
assertRaises
(
fluid
.
core_avx
.
EnforceNotMet
):
super
(
TestAsymPadValidNHWC
,
self
).
test_check_grad
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录