Commit dfdd73cb authored by Adam, committed by Tao Luo

Add MKLDNNHandlerT templatized class (#19801)

test=develop
Parent cabb9501
@@ -33,17 +33,18 @@ using mkldnn::stream;
using platform::to_void_cast;
template <typename T>
class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
class SoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> {
public:
SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::CreateKey(dims, uniq_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
}
@@ -52,11 +53,10 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::CreateKey(dims, uniq_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
// If we are in the Grad operator then update the key with a BWD suffix to
// distinguish it from FWD memory primitives.
// Key_common will allow the FWD_PD to be accessed from the cache
@@ -64,58 +64,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt);
}
// TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
// should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
to_void_cast<T>(input_data),
"@src_mem_p");
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T* ptr = output->mutable_data<T>(place_,
fwd_pd_->dst_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(const Tensor* output) {
const T* output_data = output->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
to_void_cast<T>(output_data),
"@bwd-dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(const Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
to_void_cast<T>(ptr),
"@diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr = diffsrc->mutable_data<T>(
place_, bwd_pd_->diff_src_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@softmax_p";
auto prim_key = this->key_ + "@softmax_p";
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
dev_ctx_.GetBlob(prim_key));
this->dev_ctx_.GetBlob(prim_key));
if (softmax_p == nullptr) {
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*this->fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
dev_ctx_.SetBlob(prim_key, softmax_p);
this->dev_ctx_.SetBlob(prim_key, softmax_p);
}
return softmax_p;
@@ -125,13 +86,14 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@softmax_bwd_p";
auto prim_key = this->key_ + "@softmax_bwd_p";
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
dev_ctx_.GetBlob(prim_key));
this->dev_ctx_.GetBlob(prim_key));
if (softmax_bwd_p == nullptr) {
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*bwd_pd_, *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
*this->bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
*diff_src_memory_p);
this->dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
}
return softmax_bwd_p;
@@ -143,17 +105,17 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
// Softmax PD has to be passed to the Grad op, which
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
const std::string key_softmax_pd = key_common_ + "@softmax_pd";
const std::string key_softmax_pd = this->key_common_ + "@softmax_pd";
fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (fwd_pd_ == nullptr) {
this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_softmax_pd));
if (this->fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (fwd_pd_ == nullptr) {
this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_softmax_pd));
if (this->fwd_pd_ == nullptr) {
// TODO(jczaja): Make it work along a chosen axis and for
// forward_training
// Normalization is done over the innermost dimension, e.g. C out of NC
@@ -161,9 +123,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto softmax_desc =
softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
fwd_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, engine_));
dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
this->fwd_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, this->engine_));
this->dev_ctx_.SetBlob(key_softmax_pd, this->fwd_pd_);
}
}
}
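The lookup above is a double-checked acquire: the forward PD is first sought under the thread-agnostic key_common_, and only if it is missing does the code take a static mutex, re-check, and then create and publish the descriptor. A stand-alone sketch of that pattern, assuming only a cache type exposing GetBlob/SetBlob over std::shared_ptr<void> (as the device context does); Cache and Factory are placeholders, not Paddle types:

```cpp
#include <memory>
#include <mutex>
#include <string>

// Generic sketch of the "create once, publish to the shared cache" pattern
// used for fwd_pd_ above.
template <typename PD, typename Cache, typename Factory>
std::shared_ptr<PD> AcquireOnce(Cache* cache, const std::string& key,
                                Factory make_pd) {
  auto pd = std::static_pointer_cast<PD>(cache->GetBlob(key));
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    // Re-check: another thread may have created the PD while we waited.
    pd = std::static_pointer_cast<PD>(cache->GetBlob(key));
    if (pd == nullptr) {
      pd = make_pd();           // build the primitive descriptor exactly once
      cache->SetBlob(key, pd);  // make it visible to other threads / the grad op
    }
  }
  return pd;
}
```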
@@ -172,12 +134,12 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
const std::vector<int>& dims, const mkldnn::memory::format fmt,
const mkldnn::memory::format diff_fmt) {
// Fwd_PD_ has to exist when creating BWD_PD_
PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
const std::string key_bwd_pd = key_ + "@softmax_bwd_pd";
bwd_pd_ =
PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_);
const std::string key_bwd_pd = this->key_ + "@softmax_bwd_pd";
this->bwd_pd_ =
std::static_pointer_cast<mkldnn::softmax_backward::primitive_desc>(
dev_ctx_.GetBlob(key_bwd_pd));
if (bwd_pd_ == nullptr) {
this->dev_ctx_.GetBlob(key_bwd_pd));
if (this->bwd_pd_ == nullptr) {
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md = mkldnn::memory::desc(
@@ -185,16 +147,11 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
// TODO(jczaja): Add support for other axes
auto backward_desc = softmax_backward::desc(
diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
backward_desc, engine_, *fwd_pd_));
dev_ctx_.SetBlob(key_bwd_pd, bwd_pd_);
this->bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
backward_desc, this->engine_, *this->fwd_pd_));
this->dev_ctx_.SetBlob(key_bwd_pd, this->bwd_pd_);
}
}
private:
platform::Place place_;
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> bwd_pd_;
};
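With the per-operator place_/fwd_pd_/bwd_pd_ members removed, a forward kernel drives the handler roughly as below. This is only a sketch of a call site: softmax_tz, fmt, input, output, dev_ctx and ctx are assumed names rather than code from this patch, and the pipeline/stream usage follows the MKL-DNN 0.x primitive API used elsewhere in these files.

```cpp
// Sketch of a forward call site (names are illustrative).
SoftmaxMKLDNNHandler<T> handler(softmax_tz, fmt, dev_ctx, ctx.GetPlace(),
                                ctx.op().Output("Out"));

auto src_memory_p = handler.AcquireSrcMemory(input);   // reuses cached memory
auto dst_memory_p = handler.AcquireDstMemory(output);  // allocates output data
auto softmax_p = handler.AcquireSoftmax(dst_memory_p, src_memory_p);

std::vector<mkldnn::primitive> pipeline{*softmax_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
```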
template <typename T>
......
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
@@ -206,7 +207,7 @@ inline std::string CreateKey(ArgTypes&&... args) {
std::string key;
key.reserve(256);
using expand_type = int[];
expand_type{0, (AppendKey(&key, args), 0)...};
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
return key;
}
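The only functional change to CreateKey is that the pack is now perfectly forwarded into AppendKey. A minimal, self-contained illustration of the expand_type idiom; the AppendKey overloads below are simplified stand-ins for the real ones, which also handle formats, dims and other argument types:

```cpp
#include <string>
#include <utility>
#include <vector>

// Simplified AppendKey overloads (illustrative only).
inline void AppendKey(std::string* key, const std::string& s) { *key += s; }
inline void AppendKey(std::string* key, int v) { *key += std::to_string(v); }
inline void AppendKey(std::string* key, const std::vector<int>& dims) {
  for (int d : dims) *key += std::to_string(d);
}

template <typename... ArgTypes>
std::string CreateKeySketch(ArgTypes&&... args) {
  std::string key;
  key.reserve(256);
  using expand_type = int[];
  // Brace-initializing the dummy array evaluates AppendKey once per argument,
  // left to right; std::forward preserves each argument's value category.
  expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
  return key;
}

// e.g. CreateKeySketch(std::vector<int>{2, 3}, std::string("fwd"), 7) -> "23fwd7"
```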
......
@@ -29,6 +29,90 @@ namespace platform {
using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory;
template <typename T, typename TForward, typename TBackward>
class MKLDNNHandlerT {
public:
MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
platform::Place cpu_place, const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
place_(cpu_place),
key_common_(base_key),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
if (platform::get_cur_mkldnn_session_id() !=
platform::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
to_void_cast<T>(input_data),
"@src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T* ptr = output->mutable_data<T>(place_,
fwd_pd_->dst_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const framework::Tensor* output) {
const T* output_data = output->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
to_void_cast<T>(output_data),
"@bwd-dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
to_void_cast<T>(ptr),
"@diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr = diffsrc->mutable_data<T>(
place_, bwd_pd_->diff_src_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr,
const std::string& suffix) {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mdp, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
platform::Place place_;
std::string key_;
std::string key_common_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
};
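The intended division of labour is that a concrete handler only fills this->fwd_pd_ (and this->bwd_pd_ for grad), while memory acquisition and key/cache plumbing come from this base. A hypothetical derived handler sketched under that assumption; FooMKLDNNHandler and its body are illustrative only, and the softmax primitive types merely stand in for any forward/backward pair:

```cpp
// Illustrative only: a minimal handler built on MKLDNNHandlerT.
template <typename T>
class FooMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                      mkldnn::softmax_backward> {
 public:
  FooMKLDNNHandler(const std::vector<int>& dims,
                   const platform::MKLDNNDeviceContext& dev_ctx,
                   platform::Place cpu_place, const std::string& uniq_name)
      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                 mkldnn::softmax_backward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
            platform::CreateKey(dims, uniq_name)) {
    // The only operator-specific duty left: populate this->fwd_pd_
    // (and this->bwd_pd_ in a grad constructor), typically via the
    // double-checked cache lookup shown for softmax above.
  }
  // AcquireSrcMemory/AcquireDstMemory/AcquireDiff*Memory are inherited as-is.
};
```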
// TODO(grygielski) this class will be deleted later.
class MKLDNNHandler {
public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
@@ -255,7 +339,9 @@ class SumMKLDNNHandler : public MKLDNNHandler {
};
template <typename T>
class ActivationMKLDNNHandler : public MKLDNNHandler {
class ActivationMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward> {
public:
ActivationMKLDNNHandler(const std::vector<int>& dims,
mkldnn::algorithm algorithm, float alpha, float beta,
@@ -264,12 +350,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
platform::Place cpu_place,
const std::string& unique_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::CreateKey(dims, algorithm, fmt, alpha,
beta, unique_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
unique_name)) {
AcquireActivationPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
@@ -284,76 +369,37 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
platform::Place cpu_place,
const std::string& unique_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::CreateKey(dims, algorithm, fmt, alpha,
beta, unique_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
unique_name)) {
AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, dims, fmt, alpha, beta);
AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt,
alpha, beta);
}
// TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this
// function
// should be moved as well eg. ActivationMKLDNNHandler ->
// MKLDNNHandler<activation_>
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
to_void_cast<T>(input_data),
"@src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->src_primitive_desc(),
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_primitive_desc(),
to_void_cast<T>(input_data),
"@bwd-src_mem_p");
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T* ptr = output->mutable_data<T>(place_,
fwd_pd_->dst_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
to_void_cast<T>(ptr),
"@diff_dst_mem_p");
}
// TODO(jczaja): Move to MKLDNNHandler as common code
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr = diffsrc->mutable_data<T>(
place_, bwd_pd_->diff_src_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@eltwise_p";
auto prim_key = this->key_ + "@eltwise_p";
auto eltwise_p = std::static_pointer_cast<mkldnn::eltwise_forward>(
dev_ctx_.GetBlob(prim_key));
this->dev_ctx_.GetBlob(prim_key));
if (eltwise_p == nullptr) {
eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
*fwd_pd_, *(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, eltwise_p);
*this->fwd_pd_, *(src_memory_p), *(dst_memory_p));
this->dev_ctx_.SetBlob(prim_key, eltwise_p);
}
return eltwise_p;
@@ -364,15 +410,15 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@eltwise_bwd_p";
auto prim_key = this->key_ + "@eltwise_bwd_p";
auto eltwise_bwd_p = std::static_pointer_cast<mkldnn::eltwise_backward>(
dev_ctx_.GetBlob(prim_key));
this->dev_ctx_.GetBlob(prim_key));
if (eltwise_bwd_p == nullptr) {
eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
*bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
*this->bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
*(diff_src_memory_p));
dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
this->dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
}
return eltwise_bwd_p;
@@ -387,26 +433,27 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
// Activation PD has to be passed to the Grad op, which
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
const std::string key_activation_pd = key_common_ + "@activation_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_pd));
if (fwd_pd_ == nullptr) {
const std::string key_activation_pd = this->key_common_ + "@activation_pd";
this->fwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_activation_pd));
if (this->fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ =
this->fwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_pd));
if (fwd_pd_ == nullptr) {
this->dev_ctx_.GetBlob(key_activation_pd));
if (this->fwd_pd_ == nullptr) {
auto md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), fmt);
auto activation_desc = mkldnn::eltwise_forward::desc(
prop_kind, algorithm, md, alpha, beta);
fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
activation_desc, engine_));
dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
this->fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
activation_desc, this->engine_));
this->dev_ctx_.SetBlob(key_activation_pd, this->fwd_pd_);
}
}
}
@@ -415,17 +462,18 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
mkldnn::algorithm algorithm, const std::vector<int>& dims,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
float alpha, float beta) {
const std::string key_activation_pd = key_common_ + "@activation_pd";
const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
bwd_pd_ =
const std::string key_activation_pd = this->key_common_ + "@activation_pd";
const std::string key_activation_bwd_pd = this->key_ + "@activation_bwd_pd";
this->bwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_bwd_pd));
if (bwd_pd_ == nullptr) {
fwd_pd_ =
this->dev_ctx_.GetBlob(key_activation_bwd_pd));
if (this->bwd_pd_ == nullptr) {
this->fwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_pd));
this->dev_ctx_.GetBlob(key_activation_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_,
"Eltwise MKL-DNN not found in cache!");
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
@@ -434,16 +482,11 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
auto backward_desc = mkldnn::eltwise_backward::desc(
algorithm, diff_dst_md, src_md, alpha, beta);
bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
backward_desc, engine_, *fwd_pd_));
dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
this->bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
backward_desc, this->engine_, *this->fwd_pd_));
this->dev_ctx_.SetBlob(key_activation_bwd_pd, this->bwd_pd_);
}
}
private:
platform::Place place_;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd_;
};
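For the grad kernel the inherited memory acquisition combines with the activation-specific backward primitive roughly as sketched below, once a handler has been built with the grad constructor above. x, dout and dx (input, output-gradient and input-gradient tensors) and the pipeline code are assumptions, not code from this patch.

```cpp
// Sketch of the activation-grad call sequence (variable names are illustrative).
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireActivationBackward(
    diff_src_memory_p, diff_dst_memory_p, src_memory_p);

std::vector<mkldnn::primitive> pipeline{*activation_backward_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
```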
class LRNMKLDNNHandler : public MKLDNNHandler {
......