Commit 84bb45c0 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Thread-Safety for MKL-DNN reusing Part 1 (#17965)

* Removed is_reusing_

* Added TID to keys for reusing, apart from softmax PD

* Compilation fix

* Yet another compilation fix

* Batch Norm and Conv adapted

* Fix to softmax MT

* Fixes to MT code of MKL-DNN

* Lint fixes

test=develop
Parent da9143c1
@@ -61,20 +61,25 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<batch_norm_fwd::primitive_desc>
   AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
                                       const mkldnn::engine &engine) {
-    const std::string key_batch_norm_fwd_pd = key_ + "@bn_fwd_pd";
-    auto batch_norm_pd =
-        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-            dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
-    if (batch_norm_pd == nullptr) {
-      batch_norm_pd_.reset(
-          new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
-      dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
-    } else {
-      batch_norm_pd_ = batch_norm_pd;
-      is_reusing_ = true;
+    // BatchNorm PD has to be passed to Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain TID
+    const std::string key_batch_norm_fwd_pd = key_common_ + "@bn_fwd_pd";
+    batch_norm_pd_ = std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+        dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+    if (batch_norm_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      batch_norm_pd_ =
+          std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+              dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+      if (batch_norm_pd_ == nullptr) {
+        batch_norm_pd_.reset(
+            new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
+        dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
+      }
     }
     return batch_norm_pd_;
   }
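Editor's note: every Acquire*PrimitiveDescriptor method touched by this patch (batch norm above, softmax and conv below) follows the same double-checked locking idiom: a lock-free fast path for the common cache-hit case, then a mutex-guarded slow path that re-checks the cache before creating, so exactly one thread builds and publishes the shared descriptor. The sketch below is a minimal, self-contained restatement of that idiom, not code from the patch: `BlobMap` is a hypothetical stand-in for `MKLDNNDeviceContext::GetBlob`/`SetBlob` (assumed to be internally synchronized, as the fast path requires), and `AcquireShared`/`create` are illustrative names.

```cpp
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the device context's blob cache; Get/Set are
// synchronized so the lock-free fast path below is race-free.
class BlobMap {
 public:
  std::shared_ptr<void> Get(const std::string& key) {
    std::lock_guard<std::mutex> g(mtx_);
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : it->second;
  }
  void Set(const std::string& key, std::shared_ptr<void> value) {
    std::lock_guard<std::mutex> g(mtx_);
    map_[key] = std::move(value);
  }

 private:
  std::mutex mtx_;
  std::unordered_map<std::string, std::shared_ptr<void>> map_;
};

// The double-checked locking shape used by the Acquire methods in this diff.
template <typename PD>
std::shared_ptr<PD> AcquireShared(
    BlobMap* blobs, const std::string& key,
    const std::function<std::shared_ptr<PD>()>& create) {
  // Fast path: after the first call the PD is already cached; no lock taken.
  auto pd = std::static_pointer_cast<PD>(blobs->Get(key));
  if (pd == nullptr) {
    // Slow path: serialize creation so only one thread builds the PD.
    // (One static mutex per instantiation, mirroring the per-method
    // acquire_barrier in the patch.)
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    // Re-check under the lock: another thread may have published the PD
    // while we were waiting.
    pd = std::static_pointer_cast<PD>(blobs->Get(key));
    if (pd == nullptr) {
      pd = create();
      blobs->Set(key, pd);
    }
  }
  return pd;
}
```

The descriptor is deliberately cached under `key_common_` (no thread id): the grad op, which may run on a different thread, has to find the PD that the forward op created.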
@@ -87,9 +92,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
     auto batch_norm_p =
         std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_,
-                   "Fail to find batch norm primitive in device context");
     if (batch_norm_p == nullptr) {
       if (is_test) {
         batch_norm_p = std::make_shared<batch_norm_fwd>(
@@ -104,8 +106,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
       }
       dev_ctx_.SetBlob(prim_key, batch_norm_p);
-    } else {
-      is_reusing_ = true;
     }
     return batch_norm_p;
......
@@ -54,18 +54,24 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<softmax_forward::primitive_desc>
   AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
                                     const mkldnn::engine& engine) {
-    const std::string key_softmax_pd = key_ + "@softmax_pd";
-    auto softmax_pd = std::static_pointer_cast<softmax_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_softmax_pd));
-    if (softmax_pd == nullptr) {
-      softmax_pd_.reset(
-          new softmax_forward::primitive_desc(softmax_desc, engine));
-      dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
-    } else {
-      softmax_pd_ = softmax_pd;
-      is_reusing_ = true;
+    // Softmax PD has to be passed to Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain TID
+    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+    softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_softmax_pd));
+    if (softmax_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_softmax_pd));
+      if (softmax_pd_ == nullptr) {
+        softmax_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, engine));
+        dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
+      }
     }
     return softmax_pd_;
@@ -79,15 +85,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find softmax primitive in device context");
     if (softmax_p == nullptr) {
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
           *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
          *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
       dev_ctx_.SetBlob(prim_key, softmax_p);
-    } else {
-      is_reusing_ = true;
     }
     return softmax_p;
@@ -100,15 +102,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto prim_key = key_ + "@softmax_bwd_p";
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find softmax backward primitive in device context");
     if (softmax_bwd_p == nullptr) {
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
           *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
-    } else {
-      is_reusing_ = true;
     }
     return softmax_bwd_p;
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <memory>
+#include <sstream>
 #include <string>
 #include <vector>
 #include "boost/optional.hpp"
@@ -31,10 +32,13 @@ class MKLDNNHandler {
  public:
   MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
                 const std::string& base_key)
-      : dev_ctx_(dev_ctx),
-        engine_(engine),
-        key_(base_key),
-        is_reusing_(false) {}
+      : dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
+    // TODO(jczaja): Make it faster
+    auto tid = std::this_thread::get_id();
+    std::stringstream ss;
+    ss << tid;
+    key_ = key_common_ + "-t:" + ss.str();
+  }
 
   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const mkldnn::memory::desc& md, void* ptr) {
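Editor's note: a minimal sketch (with an illustrative helper name, not from the patch) of the key derivation the constructor above now performs, plus a two-thread demo. Primitives cached under the TID-suffixed `key_` are private to the creating thread, which is what makes the removed `is_reusing_` consistency flag unnecessary; only blobs that must cross threads, such as the forward PDs consumed by grad ops, stay under the TID-free `key_common_`.

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <thread>

// Mirrors the constructor body: stringify the current thread id and append
// it to the common base key. The patch marks the stringstream round-trip
// with TODO(jczaja): Make it faster.
std::string AppendThreadId(const std::string& key_common) {
  std::stringstream ss;
  ss << std::this_thread::get_id();
  return key_common + "-t:" + ss.str();
}

int main() {
  auto worker = [] {
    // Each thread derives a distinct cache key, so per-thread primitive
    // caches never collide.
    std::cout << AppendThreadId("conv2d-k3x3") + "@conv_p\n";
  };
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();
  return 0;
}
```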
@@ -73,16 +77,11 @@ class MKLDNNHandler {
     auto local_key = key_ + suffix;
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find mem primitive in device context");
     if (mem_p == nullptr) {
       mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -96,8 +95,6 @@ class MKLDNNHandler {
     auto local_key = key_ + suffix;
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find mem primitive in device context");
     if (mem_p == nullptr) {
       // Call custom reorder/preprocessing func if available
       if (custom_func) {
@@ -111,9 +108,6 @@ class MKLDNNHandler {
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -155,8 +149,6 @@ class MKLDNNHandler {
     auto target_memory_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((target_memory_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find mem primitive in device context");
     if (target_memory_p == nullptr) {
       target_memory_p = user_memory_p;
       std::shared_ptr<mkldnn::primitive> reorder_p;
@@ -187,7 +179,6 @@ class MKLDNNHandler {
       if (reorder_p != nullptr) {
         pipeline.push_back(*reorder_p);
       }
-      is_reusing_ = true;
     }
     return target_memory_p;
   }
@@ -268,7 +259,7 @@ class MKLDNNHandler {
   const MKLDNNDeviceContext& dev_ctx_;
   mkldnn::engine engine_;
   std::string key_;
-  bool is_reusing_;
+  std::string key_common_;
 
  public:
   static constexpr int MaxKeyLength = 256;
@@ -290,8 +281,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_src_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       // Make memory descriptor using input format, unless it
       // cannot be trusted (nchw) then make up memory fmt manually
@@ -307,9 +296,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -319,8 +305,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_dst_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       auto dst_mdp = mkldnn::memory::primitive_desc{
           Axis2MemoryDesc(dims_, axis_), engine_};
@@ -332,9 +316,6 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     } else {
       auto dst_data = output->mutable_data<float>(place);
       mem_p->set_data_handle(dst_data);
-      // Mark that reusing happenned. All primitives from operator instance
-      // should be reused or none of them. So we check consistency
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -345,14 +326,10 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@transpose_p";
     auto transpose_p =
         std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (transpose_p == nullptr) {
       transpose_p =
           std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
       dev_ctx_.SetBlob(prim_key, transpose_p);
-    } else {
-      is_reusing_ = true;
     }
     return transpose_p;
   }
@@ -416,8 +393,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_src_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
       mem_p = std::make_shared<mkldnn::memory>(
@@ -425,7 +400,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -436,8 +410,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto local_key = key_ + "@user_dst_mem_p";
     auto mem_p =
         std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
-                   " find mem primitive in device context");
     if (mem_p == nullptr) {
       auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
       auto dst_mdp = mkldnn::memory::primitive_desc{dst_md, engine_};
@@ -449,7 +421,6 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     } else {
       auto dst_data = output->mutable_data(place, vtype_);
       mem_p->set_data_handle(dst_data);
-      is_reusing_ = true;
     }
     return mem_p;
   }
@@ -460,14 +431,10 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@reorder_p";
     auto reorder_p =
         std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((reorder_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (reorder_p == nullptr) {
       reorder_p =
           std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
       dev_ctx_.SetBlob(prim_key, reorder_p);
-    } else {
-      is_reusing_ = true;
     }
     return reorder_p;
   }
@@ -695,35 +662,43 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
       const bool fuse_relu, const bool fuse_residual_conn,
       const bool fuse_brelu, const float fuse_brelu_threshold,
       mkldnn::prop_kind fwd_prop_kind) {
-    const std::string key_conv_pd = key_ + "@conv_pd";
-    auto conv_pd = std::static_pointer_cast<typename forward_t::primitive_desc>(
-        dev_ctx_.GetBlob(key_conv_pd));
-    if (conv_pd == nullptr) {
-      mkldnn::memory::dims stride_dims = strides;
-      mkldnn::memory::dims padding_dims = paddings;
-      auto conv_desc =
-          bias ? typename forward_t::desc(
-                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
-                     weights, *bias, dst, stride_dims, padding_dims,
-                     padding_dims, mkldnn::padding_kind::zero)
-              : typename forward_t::desc(
-                    fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
-                    weights, dst, stride_dims, padding_dims, padding_dims,
-                    mkldnn::padding_kind::zero);
-      mkldnn::primitive_attr conv_attr = CreatePostOps(
-          fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
-      conv_pd_.reset(
-          new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
-      // Save conv_pd/src_memory/weights_memory for backward pass
-      dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
-    } else {
-      conv_pd_ = conv_pd;
-      is_reusing_ = true;
+    // Conv PD has to be passed to Grad op that
+    // may be executed by a different thread, hence
+    // for that one we use a key that does not contain TID
+    const std::string key_conv_pd = key_common_ + "@conv_pd";
+    conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
+        dev_ctx_.GetBlob(key_conv_pd));
+    if (conv_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
+          dev_ctx_.GetBlob(key_conv_pd));
+      if (conv_pd_ == nullptr) {
+        mkldnn::memory::dims stride_dims = strides;
+        mkldnn::memory::dims padding_dims = paddings;
+        auto conv_desc =
+            bias ? typename forward_t::desc(
+                       fwd_prop_kind, convolutional_algorithm<forward_t>::T,
+                       src, weights, *bias, dst, stride_dims, padding_dims,
+                       padding_dims, mkldnn::padding_kind::zero)
+                : typename forward_t::desc(
+                      fwd_prop_kind, convolutional_algorithm<forward_t>::T,
+                      src, weights, dst, stride_dims, padding_dims,
+                      padding_dims, mkldnn::padding_kind::zero);
+        mkldnn::primitive_attr conv_attr = CreatePostOps(
+            fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
+        conv_pd_.reset(new typename forward_t::primitive_desc(
+            conv_desc, conv_attr, engine));
+        // Save conv_pd/src_memory/weights_memory for backward pass
+        dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
+      }
     }
     return conv_pd_;
auto prim_key = key_ + "@conv_p"; auto prim_key = key_ + "@conv_p";
auto conv_p = auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key)); std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) { if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p, conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
*weights_memory_p, *dst_memory_p); *weights_memory_p, *dst_memory_p);
dev_ctx_.SetBlob(prim_key, conv_p); dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
} }
return conv_p; return conv_p;
} }
@@ -757,16 +728,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_p";
     auto conv_p =
         std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
-                   "Fail to find convolution primitive in device context");
     if (conv_p == nullptr) {
       conv_p = std::make_shared<forward_t>(*conv_pd_, *src_memory_p,
                                            *weights_memory_p, *bias_memory_p,
                                            *dst_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_p;
   }
@@ -778,17 +745,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_bwd_weights_p";
     auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>(
         dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE(
-        (conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
-        "Fail to find convolution bwd weights primitive in device context");
     if (conv_bwd_weights_p == nullptr) {
       // create backward conv primitive for weights
       conv_bwd_weights_p = std::make_shared<backward_weights_t>(
           *conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
           *diff_weights_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_bwd_weights_p;
   }
@@ -800,16 +762,11 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     auto prim_key = key_ + "@conv_bwd_data_p";
     auto conv_bwd_data_p =
         std::static_pointer_cast<backward_data_t>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE(
-        (conv_bwd_data_p != nullptr) || (is_reusing_ == false),
-        "Fail to find convolution bwd data primitive in device context");
     if (conv_bwd_data_p == nullptr) {
       conv_bwd_data_p = std::make_shared<backward_data_t>(
           *conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
          *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
-    } else {
-      is_reusing_ = true;
     }
     return conv_bwd_data_p;
   }
......
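Editor's note: a hypothetical stress test for the `AcquireShared` sketch given after the batch-norm hunk (it reuses `BlobMap` and `AcquireShared` from there; `FakePD` and the key string are illustrative). It mimics several operator instances racing to acquire the same shared PD, the scenario the mutex in this patch guards against.

```cpp
#include <cassert>
#include <memory>
#include <thread>
#include <vector>

struct FakePD {
  int id;
};

int main() {
  BlobMap blobs;  // from the earlier sketch
  std::vector<std::thread> threads;
  std::shared_ptr<FakePD> seen[8];
  for (int i = 0; i < 8; ++i) {
    threads.emplace_back([&blobs, &seen, i] {
      seen[i] = AcquireShared<FakePD>(&blobs, "conv2d-k3x3@conv_pd", [] {
        return std::make_shared<FakePD>(FakePD{42});
      });
    });
  }
  for (auto& t : threads) t.join();
  // Double-checked locking guarantees every thread observes the same
  // published instance.
  for (int i = 1; i < 8; ++i) assert(seen[i] == seen[0]);
  return 0;
}
```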