提交 a39f5452 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2307 fix stack overflow and memset use risk

Merge pull request !2307 from baihuawei/cpulstm
......@@ -81,11 +81,11 @@ void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
dst_desc, dst_h_desc, dst_c_desc);
prim_desc_ = dnnl::lstm_forward::primitive_desc(desc, eng);
auto desc = std::make_shared<dnnl::lstm_forward::desc>(dnnl::prop_kind::forward_training, direction, src_desc,
src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
dst_h_desc, dst_c_desc);
prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
......@@ -117,7 +117,11 @@ bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
if (has_bias_) {
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
} else {
std::memset(bias_memory.get_data_handle(), 0, prim_desc_.bias_desc().get_size());
auto ret =
memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size());
if (ret != 0) {
MS_LOG(EXCEPTION) << "bias memset error";
}
}
// set handle
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
......
......@@ -79,17 +79,17 @@ void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
dnnl::lstm_forward::desc forward_desc =
dnnl::lstm_forward::desc(dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc,
dst_desc, dst_h_desc, dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(forward_desc, eng);
dnnl::lstm_backward::desc backward_desc = dnnl::lstm_backward::desc(
auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc,
dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng);
auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
dst_h_desc, dst_c_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(backward_desc, eng, prim_forward_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
......@@ -132,7 +132,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
if (has_bias_) {
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
} else {
std::memset(bias_memory.get_data_handle(), 0, prim_backward_desc_.bias_desc().get_size());
if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0,
prim_backward_desc_.bias_desc().get_size())) {
MS_LOG(EXCEPTION) << "bias memset error";
}
}
// construct bw memory
auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
......@@ -142,14 +145,29 @@ bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
std::memset(user_diff_weights_memory.get_data_handle(), 0, user_diff_weights_memory.get_desc().get_size());
std::memset(user_diff_weights_h_memory.get_data_handle(), 0, user_diff_weights_h_memory.get_desc().get_size());
if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0,
user_diff_weights_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "user weights grad memset error";
}
if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0,
user_diff_weights_h_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "user weights iter grad memset error";
}
if (has_bias_) {
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
}
std::memset(diff_bias_memory.get_data_handle(), 0, prim_backward_desc_.diff_bias_desc().get_size());
std::memset(diff_weights_memory.get_data_handle(), 0, diff_weights_memory.get_desc().get_size());
std::memset(diff_weights_h_memory.get_data_handle(), 0, diff_weights_h_memory.get_desc().get_size());
if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0,
prim_backward_desc_.diff_bias_desc().get_size())) {
MS_LOG(EXCEPTION) << "bias grad memset error";
}
if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0,
diff_weights_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "weights grad memset error";
}
if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0,
diff_weights_h_memory.get_desc().get_size())) {
MS_LOG(EXCEPTION) << "weights iter grad memset error";
}
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册