提交 681d3553 编写于 作者: L Leo Zhao 提交者: Tao Luo

Fix potential mkldnn concat/pool/conv kernel issues (#18393)

1. Some key generation methods are not aligned with PR #17965.
2. Extend the lifetime of the shared pointers so that the memory is not
   released if SetBlob fails; otherwise a core dump occurs.

test=develop
上级 052b0448
...@@ -81,6 +81,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -81,6 +81,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, platform::MKLDNNHandler::AppendKey(&key,
std::to_string(multi_input[0]->format())); std::to_string(multi_input[0]->format()));
if (platform::get_cur_thread_id() != -1) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(&key, ss.str());
}
return key; return key;
} }
......
...@@ -220,7 +220,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -220,7 +220,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test); user_weights_memory_p, pipeline, is_test);
std::shared_ptr<mkldnn::memory> dst_memory_p; std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;
if (fuse_residual_conn) { if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData"); auto residual_param = ctx.Input<Tensor>("ResidualData");
...@@ -243,7 +243,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -243,7 +243,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_residual_md = platform::MKLDNNMemDesc( auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format()); residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler.AcquireResidualDataMemory( user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data)); user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory( dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
...@@ -263,14 +263,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -263,14 +263,15 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// create convolution op primitive // create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> user_bias_memory_p, bias_memory_p;
if (bias) { if (bias) {
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc( auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x); {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p = user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data)); handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
auto bias_memory_p = bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline); handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p); bias_memory_p, dst_memory_p);
......
...@@ -48,6 +48,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -48,6 +48,13 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix); platform::MKLDNNHandler::AppendKey(&key, suffix);
if (platform::get_cur_thread_id() != -1) {
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(&key, ss.str());
}
return key; return key;
} }
...@@ -128,6 +135,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -128,6 +135,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_pool_workspace_memory = const std::string key_pool_workspace_memory =
key + "@pool_workspace_memory"; key + "@pool_workspace_memory";
std::shared_ptr<mkldnn::memory> src_memory, dst_memory;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd;
std::shared_ptr<mkldnn::memory> pool_src_memory_p, pool_dst_memory_p;
auto pool_p = auto pool_p =
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p)); std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
if (pool_p == nullptr) { if (pool_p == nullptr) {
...@@ -158,9 +169,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -158,9 +169,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// save pool_pd into global device context to be referred in backward path // save pool_pd into global device context to be referred in backward path
if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd); if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(), src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
to_void_cast<T>(input_data)); to_void_cast<T>(input_data));
auto dst_memory = dst_memory =
std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data); std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
...@@ -186,11 +197,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -186,11 +197,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
(memory::format)dst_memory->get_primitive_desc().desc().data.format; (memory::format)dst_memory->get_primitive_desc().desc().data.format;
} else { } else {
// Primitives already exist // Primitives already exist
auto pool_src_memory_p = pool_src_memory_p =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p)); std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
PADDLE_ENFORCE(pool_src_memory_p != nullptr, PADDLE_ENFORCE(pool_src_memory_p != nullptr,
"Fail to find pooling src mem_p in device context"); "Fail to find pooling src mem_p in device context");
auto pool_dst_memory_p = pool_dst_memory_p =
std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p)); std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
PADDLE_ENFORCE(pool_dst_memory_p != nullptr, PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
"Fail to find pooling dst mem_p in device context"); "Fail to find pooling dst mem_p in device context");
......
...@@ -38,6 +38,9 @@ class MKLDNNHandler { ...@@ -38,6 +38,9 @@ class MKLDNNHandler {
std::stringstream ss; std::stringstream ss;
ss << tid; ss << tid;
key_ = key_common_ + "-t:" + ss.str(); key_ = key_common_ + "-t:" + ss.str();
if (platform::get_cur_thread_id() == -1) {
key_ = key_common_;
}
} }
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册