提交 84bb45c0 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Thread-Safety for MKL-DNN reusing Part 1 (#17965)

* - removed is_reusing_

* - Added TID to keys for reusing apart from softmax PD

* - compilation fix

* - Yet another compilation fix

* - Batch Norm and Conv adapted

* - Fix to softmax MT

* - Fixes to MT code of MKL-DNN

* - Lint fixes

test=develop
上级 da9143c1
......@@ -61,20 +61,25 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<batch_norm_fwd::primitive_desc>
AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
const mkldnn::engine &engine) {
const std::string key_batch_norm_fwd_pd = key_ + "@bn_fwd_pd";
auto batch_norm_pd =
std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
if (batch_norm_pd == nullptr) {
batch_norm_pd_.reset(
new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
} else {
batch_norm_pd_ = batch_norm_pd;
is_reusing_ = true;
// BatchNorm PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_batch_norm_fwd_pd = key_common_ + "@bn_fwd_pd";
batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
if (batch_norm_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
if (batch_norm_pd_ == nullptr) {
batch_norm_pd_.reset(
new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
}
}
return batch_norm_pd_;
}
......@@ -87,9 +92,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
auto batch_norm_p =
std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_,
"Fail to find batch norm primitive in device context");
if (batch_norm_p == nullptr) {
if (is_test) {
batch_norm_p = std::make_shared<batch_norm_fwd>(
......@@ -104,8 +106,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
}
dev_ctx_.SetBlob(prim_key, batch_norm_p);
} else {
is_reusing_ = true;
}
return batch_norm_p;
......
......@@ -54,18 +54,24 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<softmax_forward::primitive_desc>
AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
const mkldnn::engine& engine) {
const std::string key_softmax_pd = key_ + "@softmax_pd";
// Softmax PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_softmax_pd = key_common_ + "@softmax_pd";
auto softmax_pd = std::static_pointer_cast<softmax_forward::primitive_desc>(
softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (softmax_pd == nullptr) {
softmax_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, engine));
dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
} else {
softmax_pd_ = softmax_pd;
is_reusing_ = true;
if (softmax_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (softmax_pd_ == nullptr) {
softmax_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, engine));
dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
}
}
return softmax_pd_;
......@@ -79,15 +85,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
"Fail to find softmax primitive in device context");
if (softmax_p == nullptr) {
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
dev_ctx_.SetBlob(prim_key, softmax_p);
} else {
is_reusing_ = true;
}
return softmax_p;
......@@ -100,15 +102,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
auto prim_key = key_ + "@softmax_bwd_p";
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
"Fail to find softmax backward primitive in device context");
if (softmax_bwd_p == nullptr) {
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
} else {
is_reusing_ = true;
}
return softmax_bwd_p;
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "boost/optional.hpp"
......@@ -31,10 +32,13 @@ class MKLDNNHandler {
public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
key_(base_key),
is_reusing_(false) {}
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
// TODO(jczaja): Make it faster
auto tid = std::this_thread::get_id();
std::stringstream ss;
ss << tid;
key_ = key_common_ + "-t:" + ss.str();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
......@@ -73,16 +77,11 @@ class MKLDNNHandler {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
......@@ -96,8 +95,6 @@ class MKLDNNHandler {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (mem_p == nullptr) {
// Call custom reorder/preprocessing func if available
if (custom_func) {
......@@ -111,9 +108,6 @@ class MKLDNNHandler {
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
......@@ -155,8 +149,6 @@ class MKLDNNHandler {
auto target_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((target_memory_p != nullptr) || (is_reusing_ == false),
"Fail to find mem primitive in device context");
if (target_memory_p == nullptr) {
target_memory_p = user_memory_p;
std::shared_ptr<mkldnn::primitive> reorder_p;
......@@ -187,7 +179,6 @@ class MKLDNNHandler {
if (reorder_p != nullptr) {
pipeline.push_back(*reorder_p);
}
is_reusing_ = true;
}
return target_memory_p;
}
......@@ -268,7 +259,7 @@ class MKLDNNHandler {
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_;
bool is_reusing_;
std::string key_common_;
public:
static constexpr int MaxKeyLength = 256;
......@@ -290,8 +281,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
......@@ -307,9 +296,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
......@@ -319,8 +305,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
auto local_key = key_ + "@user_dst_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto dst_mdp = mkldnn::memory::primitive_desc{
Axis2MemoryDesc(dims_, axis_), engine_};
......@@ -332,9 +316,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
} else {
auto dst_data = output->mutable_data<float>(place);
mem_p->set_data_handle(dst_data);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
......@@ -345,14 +326,10 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
auto prim_key = key_ + "@transpose_p";
auto transpose_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (transpose_p == nullptr) {
transpose_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, transpose_p);
} else {
is_reusing_ = true;
}
return transpose_p;
}
......@@ -416,8 +393,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
mem_p = std::make_shared<mkldnn::memory>(
......@@ -425,7 +400,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
is_reusing_ = true;
}
return mem_p;
}
......@@ -436,8 +410,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
auto local_key = key_ + "@user_dst_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
auto dst_mdp = mkldnn::memory::primitive_desc{dst_md, engine_};
......@@ -449,7 +421,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
} else {
auto dst_data = output->mutable_data(place, vtype_);
mem_p->set_data_handle(dst_data);
is_reusing_ = true;
}
return mem_p;
}
......@@ -460,14 +431,10 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
auto prim_key = key_ + "@reorder_p";
auto reorder_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((reorder_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, reorder_p);
} else {
is_reusing_ = true;
}
return reorder_p;
}
......@@ -695,35 +662,43 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
const bool fuse_relu, const bool fuse_residual_conn,
const bool fuse_brelu, const float fuse_brelu_threshold,
mkldnn::prop_kind fwd_prop_kind) {
const std::string key_conv_pd = key_ + "@conv_pd";
// Conv PD has to be passed to Grad op that
// may be exxecuted by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_conv_pd = key_common_ + "@conv_pd";
auto conv_pd = std::static_pointer_cast<typename forward_t::primitive_desc>(
conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));
if (conv_pd == nullptr) {
mkldnn::memory::dims stride_dims = strides;
mkldnn::memory::dims padding_dims = paddings;
auto conv_desc =
bias ? typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, *bias, dst, stride_dims, padding_dims,
padding_dims, mkldnn::padding_kind::zero)
: typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
conv_pd_.reset(
new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
} else {
conv_pd_ = conv_pd;
is_reusing_ = true;
if (conv_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));
if (conv_pd_ == nullptr) {
mkldnn::memory::dims stride_dims = strides;
mkldnn::memory::dims padding_dims = paddings;
auto conv_desc =
bias ? typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T,
src, weights, *bias, dst, stride_dims, padding_dims,
padding_dims, mkldnn::padding_kind::zero)
: typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T,
src, weights, dst, stride_dims, padding_dims,
padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
conv_pd_.reset(new typename forward_t::primitive_desc(
conv_desc, conv_attr, engine));
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
}
}
return conv_pd_;
......@@ -736,15 +711,11 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
auto prim_key = key_ + "@conv_p";
auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
*weights_memory_p, *dst_memory_p);
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
......@@ -757,16 +728,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
auto prim_key = key_ + "@conv_p";
auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
*weights_memory_p, *bias_memory_p,
*dst_memory_p);
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
......@@ -778,17 +745,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
auto prim_key = key_ + "@conv_bwd_weights_p";
auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd weights primitive in device context");
if (conv_bwd_weights_p == nullptr) {
// create backward conv primitive for weights
conv_bwd_weights_p = std::make_shared<backward_weights_t>(
*conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
*diff_weights_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
} else {
is_reusing_ = true;
}
return conv_bwd_weights_p;
}
......@@ -800,16 +762,11 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
auto prim_key = key_ + "@conv_bwd_data_p";
auto conv_bwd_data_p =
std::static_pointer_cast<backward_data_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_data_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd data primitive in device context");
if (conv_bwd_data_p == nullptr) {
conv_bwd_data_p = std::make_shared<backward_data_t>(
*conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
} else {
is_reusing_ = true;
}
return conv_bwd_data_p;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册