Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
944ea436
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
944ea436
编写于
1月 17, 2022
作者:
J
jakpiase
提交者:
GitHub
1月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix for conv2D training error (#38938)
上级
05c98ec7
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
42 addition
and
2 deletion
+42
-2
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+16
-2
paddle/fluid/platform/mkldnn_reuse.h
paddle/fluid/platform/mkldnn_reuse.h
+8
-0
python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
...luid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
+18
-0
未找到文件。
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
944ea436
...
...
@@ -613,7 +613,7 @@ class ConvMKLDNNHandlerT
auto
weights_mem_p
=
this
->
AcquireMemory
(
"@weights_mem_p_target"
);
if
(
is_test
&&
weights_mem_p
)
{
return
weights_mem_p
;
}
else
{
}
else
if
(
is_test
)
{
const
K
*
filter_data
=
filter
->
data
<
K
>
();
auto
weights_tz
=
framework
::
vectorize
(
filter
->
dims
());
platform
::
GetGroupConvWeightsTz
(
weights_tz
,
groups
);
...
...
@@ -626,6 +626,19 @@ class ConvMKLDNNHandlerT
user_src_md
,
this
->
fwd_pd_
->
weights_desc
(),
platform
::
to_void_cast
<
K
>
(
filter_data
),
"@weights_mem_p"
,
is_test
,
{},
scale_data
,
mask
);
}
else
{
const
T
*
filter_data
=
filter
->
data
<
T
>
();
auto
weights_tz
=
framework
::
vectorize
(
filter
->
dims
());
platform
::
GetGroupConvWeightsTz
(
weights_tz
,
groups
);
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
weights_tz
,
platform
::
MKLDNNGetDataType
<
T
>
(),
GetWeightsFormat
(
filter
->
format
(),
groups
,
is_conv3d
));
return
this
->
AcquireMemoryWithReorder
(
user_src_md
,
this
->
fwd_pd_
->
weights_desc
(),
platform
::
to_void_cast
<
T
>
(
filter_data
),
"@weights_mem_p"
,
is_test
,
{},
scale_data
,
mask
);
}
}
...
...
@@ -1027,7 +1040,8 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
conv2d_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
BF16
,
ops
::
kConvMKLDNNFP32
,
ops
::
ConvMKLDNNGradOpKernel
<
paddle
::
platform
::
bfloat16
,
float
>
);
ops
::
ConvMKLDNNGradOpKernel
<
paddle
::
platform
::
bfloat16
,
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE
(
depthwise_conv2d
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
FP32
,
...
...
paddle/fluid/platform/mkldnn_reuse.h
浏览文件 @
944ea436
...
...
@@ -377,6 +377,14 @@ class MKLDNNHandlerT {
if
(
bwd_pd_
==
nullptr
)
{
return
false
;
}
else
{
if
(
std
::
is_same
<
TBackward_params
,
mkldnn_dummy_primitive
>::
value
==
false
)
{
const
std
::
string
key_bw_w_pd
=
key_
+
"@bwd_w_pd"
;
bwd_w_pd_
=
std
::
static_pointer_cast
<
typename
TBackward_params
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_bw_w_pd
));
}
// When BWD is cached then still we need to Get FWD PD
const
std
::
string
key_fpd
=
key_
+
"@fwd_pd"
;
fwd_pd_
=
std
::
static_pointer_cast
<
typename
TForward
::
primitive_desc
>
(
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
浏览文件 @
944ea436
...
...
@@ -50,6 +50,7 @@ class TestConv2DBF16Op(TestConv2DOp):
self
.
init_fuse_residual
()
self
.
init_data_type
()
self
.
init_force_fp32_output
()
self
.
init_infer_or_train
()
self
.
conv2d_param
=
{
'stride'
:
self
.
stride
,
...
...
@@ -83,6 +84,9 @@ class TestConv2DBF16Op(TestConv2DOp):
if
self
.
input_type
is
not
np
.
float32
:
self
.
input
=
convert_float_to_uint16
(
self
.
input
)
if
self
.
weight_type
is
not
np
.
float32
:
self
.
filter
=
convert_float_to_uint16
(
self
.
filter
)
self
.
inputs
=
{
'Input'
:
self
.
input
,
'Filter'
:
OpTest
.
np_dtype_to_fluid_dtype
(
...
...
@@ -105,6 +109,8 @@ class TestConv2DBF16Op(TestConv2DOp):
'fuse_residual_connection'
:
self
.
fuse_residual
}
self
.
init_additional_attrs
()
def
test_check_output
(
self
):
self
.
check_output_with_place
(
core
.
CPUPlace
())
...
...
@@ -141,6 +147,12 @@ class TestConv2DBF16Op(TestConv2DOp):
def
init_fuse_residual
(
self
):
self
.
fuse_residual
=
True
def
init_infer_or_train
(
self
):
self
.
weight_type
=
np
.
float32
def
init_additional_attrs
(
self
):
self
.
attrs
[
'is_test'
]
=
True
@
OpTestTool
.
skip_if_not_cpu_bf16
()
class
TestConv2DWithGradBF16Op
(
TestConv2DBF16Op
):
...
...
@@ -150,6 +162,12 @@ class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
def
init_fuse_residual
(
self
):
self
.
fuse_residual
=
None
def
init_additional_attrs
(
self
):
self
.
attrs
[
'is_test'
]
=
False
def
init_infer_or_train
(
self
):
self
.
weight_type
=
np
.
uint16
def
test_check_grad
(
self
):
dout
=
self
.
conv_output_float
x
=
self
.
inputs_fp32
[
'Input'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录