提交 681d3553 编写于 作者: L Leo Zhao 提交者: Tao Luo

Fix potential mkldnn concat/pool/conv kernel issues (#18393)

1. Some key-generation methods are not aligned with PR #17965.
2. Extend the lifetime of the pointers so that the memory is not released
   if SetBlob fails; otherwise a core dump will occur.

test=develop
上级 052b0448
......@@ -81,6 +81,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key,
std::to_string(multi_input[0]->format()));
if (platform::get_cur_thread_id() != -1) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(&key, ss.str());
}
return key;
}
......
......@@ -220,7 +220,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test);
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
......@@ -243,7 +243,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
......@@ -263,14 +263,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> user_bias_memory_p, bias_memory_p;
if (bias) {
const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p =
user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
auto bias_memory_p =
bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p);
......
......@@ -48,6 +48,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
if (platform::get_cur_thread_id() != -1) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(&key, ss.str());
}
return key;
}
......@@ -128,6 +135,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory";
std::shared_ptr<mkldnn::memory> src_memory, dst_memory;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd;
std::shared_ptr<mkldnn::memory> pool_src_memory_p, pool_dst_memory_p;
auto pool_p =
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
if (pool_p == nullptr) {
......@@ -158,9 +169,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_pd into global device context to be referred in backward path
if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
to_void_cast<T>(input_data));
auto dst_memory =
dst_memory =
std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
......@@ -186,11 +197,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
(memory::format)dst_memory->get_primitive_desc().desc().data.format;
} else {
// Primitives already exist
auto pool_src_memory_p =
pool_src_memory_p =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
PADDLE_ENFORCE(pool_src_memory_p != nullptr,
"Fail to find pooling src mem_p in device context");
auto pool_dst_memory_p =
pool_dst_memory_p =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
"Fail to find pooling dst mem_p in device context");
......
......@@ -38,6 +38,9 @@ class MKLDNNHandler {
std::stringstream ss;
ss << tid;
key_ = key_common_ + "-t:" + ss.str();
if (platform::get_cur_thread_id() == -1) {
key_ = key_common_;
}
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册