Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a39f5452
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a39f5452
编写于
6月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2307 fix stack overflow and memset use risk
Merge pull request !2307 from baihuawei/cpulstm
上级
2d6d3cc8
576a73cd
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
19 deletion
+41
-19
mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc
+10
-6
mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc
mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc
+31
-13
未找到文件。
mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc
浏览文件 @
a39f5452
...
...
@@ -81,11 +81,11 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl
::
memory
::
desc
dst_desc
=
formatted_md
(
dst_dims
,
tag
::
tnc
);
dnnl
::
memory
::
desc
dst_h_desc
=
formatted_md
(
dst_h_dims
,
tag
::
ldnc
);
dnnl
::
memory
::
desc
dst_c_desc
=
formatted_md
(
dst_c_dims
,
tag
::
ldnc
);
dnnl
::
lstm_forward
::
desc
desc
=
dnnl
::
lstm_forward
::
desc
(
dnnl
::
prop_kind
::
forward_training
,
direction
,
src_desc
,
src_h_desc
,
src_c_desc
,
formatted_md
(
weights_dims_
,
tag
::
any
),
formatted_md
(
weights_h_dims_
,
tag
::
any
),
bias
_desc
,
dst_desc
,
dst_h_desc
,
dst_c_desc
);
prim_desc_
=
dnnl
::
lstm_forward
::
primitive_desc
(
desc
,
eng
);
auto
desc
=
std
::
make_shared
<
dnnl
::
lstm_forward
::
desc
>
(
dnnl
::
prop_kind
::
forward_training
,
direction
,
src_desc
,
src_h_desc
,
src_c_desc
,
formatted_md
(
weights_dims_
,
tag
::
any
)
,
formatted_md
(
weights_h_dims_
,
tag
::
any
),
bias_desc
,
dst
_desc
,
dst_h_desc
,
dst_c_desc
);
prim_desc_
=
dnnl
::
lstm_forward
::
primitive_desc
(
*
desc
,
eng
);
primitive_
=
std
::
make_shared
<
dnnl
::
lstm_forward
>
(
prim_desc_
);
AddArgument
(
DNNL_ARG_SRC_LAYER
,
src_desc
);
AddArgument
(
DNNL_ARG_SRC_ITER
,
src_h_desc
);
...
...
@@ -117,7 +117,11 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
if
(
has_bias_
)
{
bias_memory
.
set_data_handle
(
reinterpret_cast
<
float
*>
(
inputs
[
3
]
->
addr
)
+
weight_size_
+
weight_h_size_
);
}
else
{
std
::
memset
(
bias_memory
.
get_data_handle
(),
0
,
prim_desc_
.
bias_desc
().
get_size
());
auto
ret
=
memset_s
(
bias_memory
.
get_data_handle
(),
prim_desc_
.
bias_desc
().
get_size
(),
0
,
prim_desc_
.
bias_desc
().
get_size
());
if
(
ret
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"bias memset error"
;
}
}
// set handle
SetArgumentHandle
(
DNNL_ARG_SRC_LAYER
,
inputs
[
0
]
->
addr
);
...
...
mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc
浏览文件 @
a39f5452
...
...
@@ -79,17 +79,17 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl
::
memory
::
desc
dst_desc
=
formatted_md
(
dst_dims
,
tag
::
tnc
);
dnnl
::
memory
::
desc
dst_h_desc
=
formatted_md
(
dst_h_dims
,
tag
::
ldnc
);
dnnl
::
memory
::
desc
dst_c_desc
=
formatted_md
(
dst_c_dims
,
tag
::
ldnc
);
dnnl
::
lstm_forward
::
desc
forward_desc
=
dnnl
::
lstm_forward
::
desc
(
dnnl
::
prop_kind
::
forward_training
,
direction
,
src_desc
,
src_h_desc
,
src_c_desc
,
formatted_md
(
weights_dims_
,
tag
::
any
),
formatted_md
(
weights_h_dims_
,
tag
::
any
),
bias
_desc
,
dst_desc
,
dst_h_desc
,
dst_c_desc
);
auto
prim_forward_desc
=
dnnl
::
lstm_forward
::
primitive_desc
(
forward_desc
,
eng
);
dnnl
::
lstm_backward
::
desc
backward_desc
=
dnnl
::
lstm_backward
::
desc
(
auto
forward_desc
=
std
::
make_shared
<
dnnl
::
lstm_forward
::
desc
>
(
dnnl
::
prop_kind
::
forward_training
,
direction
,
src_desc
,
src_h_desc
,
src_c_desc
,
formatted_md
(
weights_dims_
,
tag
::
any
),
formatted_md
(
weights_h_dims_
,
tag
::
any
),
bias_desc
,
dst_desc
,
dst_h
_desc
,
dst_c_desc
);
auto
prim_forward_desc
=
dnnl
::
lstm_forward
::
primitive_desc
(
*
forward_desc
,
eng
);
auto
backward_desc
=
std
::
make_shared
<
dnnl
::
lstm_backward
::
desc
>
(
dnnl
::
prop_kind
::
backward
,
direction
,
src_desc
,
src_h_desc
,
src_c_desc
,
formatted_md
(
weights_dims_
,
tag
::
any
),
formatted_md
(
weights_h_dims_
,
tag
::
any
),
bias_desc
,
dst_desc
,
dst_h_desc
,
dst_c_desc
,
src_desc
,
src_h_desc
,
src_c_desc
,
formatted_md
(
weights_dims_
,
tag
::
any
),
formatted_md
(
weights_h_dims_
,
tag
::
any
),
bias_desc
,
dst_desc
,
dst_h_desc
,
dst_c_desc
);
prim_backward_desc_
=
dnnl
::
lstm_backward
::
primitive_desc
(
backward_desc
,
eng
,
prim_forward_desc
);
prim_backward_desc_
=
dnnl
::
lstm_backward
::
primitive_desc
(
*
backward_desc
,
eng
,
prim_forward_desc
);
primitive_
=
std
::
make_shared
<
dnnl
::
lstm_backward
>
(
prim_backward_desc_
);
AddArgument
(
DNNL_ARG_SRC_LAYER
,
src_desc
);
...
...
@@ -132,7 +132,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
if
(
has_bias_
)
{
bias_memory
.
set_data_handle
(
reinterpret_cast
<
float
*>
(
inputs
[
3
]
->
addr
)
+
weight_size_
+
weight_h_size_
);
}
else
{
std
::
memset
(
bias_memory
.
get_data_handle
(),
0
,
prim_backward_desc_
.
bias_desc
().
get_size
());
if
(
memset_s
(
bias_memory
.
get_data_handle
(),
prim_backward_desc_
.
bias_desc
().
get_size
(),
0
,
prim_backward_desc_
.
bias_desc
().
get_size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"bias memset error"
;
}
}
// construct bw memory
auto
diff_weights_memory
=
dnnl
::
memory
(
prim_backward_desc_
.
diff_weights_layer_desc
(),
eng
);
...
...
@@ -142,14 +145,29 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto
user_diff_weights_h_memory
=
dnnl
::
memory
(
dnnl
::
memory
::
desc
{{
weights_h_dims_
},
dt
::
f32
,
tag
::
ldgoi
},
eng
);
user_diff_weights_memory
.
set_data_handle
(
outputs
[
3
]
->
addr
);
user_diff_weights_h_memory
.
set_data_handle
(
reinterpret_cast
<
float
*>
(
outputs
[
3
]
->
addr
)
+
weight_size_
);
std
::
memset
(
user_diff_weights_memory
.
get_data_handle
(),
0
,
user_diff_weights_memory
.
get_desc
().
get_size
());
std
::
memset
(
user_diff_weights_h_memory
.
get_data_handle
(),
0
,
user_diff_weights_h_memory
.
get_desc
().
get_size
());
if
(
memset_s
(
user_diff_weights_memory
.
get_data_handle
(),
user_diff_weights_memory
.
get_desc
().
get_size
(),
0
,
user_diff_weights_memory
.
get_desc
().
get_size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"user weights grad memset error"
;
}
if
(
memset_s
(
user_diff_weights_h_memory
.
get_data_handle
(),
user_diff_weights_h_memory
.
get_desc
().
get_size
(),
0
,
user_diff_weights_h_memory
.
get_desc
().
get_size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"user weights iter grad memset error"
;
}
if
(
has_bias_
)
{
diff_bias_memory
.
set_data_handle
(
reinterpret_cast
<
float
*>
(
outputs
[
3
]
->
addr
)
+
weight_size_
+
weight_h_size_
);
}
std
::
memset
(
diff_bias_memory
.
get_data_handle
(),
0
,
prim_backward_desc_
.
diff_bias_desc
().
get_size
());
std
::
memset
(
diff_weights_memory
.
get_data_handle
(),
0
,
diff_weights_memory
.
get_desc
().
get_size
());
std
::
memset
(
diff_weights_h_memory
.
get_data_handle
(),
0
,
diff_weights_h_memory
.
get_desc
().
get_size
());
if
(
memset_s
(
diff_bias_memory
.
get_data_handle
(),
prim_backward_desc_
.
diff_bias_desc
().
get_size
(),
0
,
prim_backward_desc_
.
diff_bias_desc
().
get_size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"bias grad memset error"
;
}
if
(
memset_s
(
diff_weights_memory
.
get_data_handle
(),
diff_weights_memory
.
get_desc
().
get_size
(),
0
,
diff_weights_memory
.
get_desc
().
get_size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"weights grad memset error"
;
}
if
(
memset_s
(
diff_weights_h_memory
.
get_data_handle
(),
diff_weights_h_memory
.
get_desc
().
get_size
(),
0
,
diff_weights_h_memory
.
get_desc
().
get_size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"weights iter grad memset error"
;
}
SetArgumentHandle
(
DNNL_ARG_SRC_LAYER
,
inputs
[
0
]
->
addr
);
SetArgumentHandle
(
DNNL_ARG_SRC_ITER
,
inputs
[
1
]
->
addr
);
SetArgumentHandle
(
DNNL_ARG_SRC_ITER_C
,
inputs
[
2
]
->
addr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录