Commit c7e68892 authored by Adam, committed by Tao Luo

Add template functions for Acquire primitive/primitive_desc (#19867)

* Add template functions for Acquire primitive/primitive_desc
test=develop

* Move acquire primitive descriptor to protected section
test=develop
Parent fe18cfdb
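At a high level, this change replaces the per-operator Acquire* helpers with variadic template helpers on MKLDNNHandlerT that construct a primitive (or primitive descriptor) once and reuse the cached instance on later calls. Below is a minimal, self-contained sketch of that create-once-then-cache pattern; BlobCache, FakePrimitive, and the key names are illustrative stand-ins, not the actual Paddle device context or MKL-DNN types.

// Sketch of the template Acquire pattern introduced by this commit.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Stand-in for MKLDNNDeviceContext's blob cache.
class BlobCache {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Stand-in for an MKL-DNN primitive type such as mkldnn::softmax_forward.
struct FakePrimitive {
  FakePrimitive(int pd, float alpha) : pd(pd), alpha(alpha) {}
  int pd;
  float alpha;
};

template <typename TPrimitive>
class HandlerT {
 public:
  HandlerT(BlobCache* cache, std::string key)
      : cache_(cache), key_(std::move(key)) {}

  // Construct the primitive once, then return the cached instance afterwards.
  template <typename... Args>
  std::shared_ptr<TPrimitive> AcquireForwardPrimitive(Args&&... args) {
    const std::string key_p = key_ + "@forward_p";
    auto p = std::static_pointer_cast<TPrimitive>(cache_->GetBlob(key_p));
    if (p == nullptr) {
      p = std::make_shared<TPrimitive>(std::forward<Args>(args)...);
      cache_->SetBlob(key_p, p);
    }
    return p;
  }

 private:
  BlobCache* cache_;
  std::string key_;
};

int main() {
  BlobCache cache;
  HandlerT<FakePrimitive> handler(&cache, "softmax");
  auto p1 = handler.AcquireForwardPrimitive(42, 0.5f);  // constructs
  auto p2 = handler.AcquireForwardPrimitive(42, 0.5f);  // returns cached copy
  std::cout << (p1 == p2) << "\n";                      // prints 1
}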
@@ -92,7 +92,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(y);
auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
auto activation_p =
handler.AcquireForwardPrimitive(*src_memory_p, *dst_memory_p);
// push primitive to stream and wait until it's executed
std::vector<primitive> pipeline;
@@ -131,8 +132,8 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
auto activation_backward_p = handler.AcquireActivationBackward(
diff_src_memory_p, diff_dst_memory_p, src_memory_p);
auto activation_backward_p = handler.AcquireBackwardPrimitive(
*src_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
// push primitive to stream and wait until it's executed
std::vector<primitive> pipeline;
......
@@ -45,7 +45,10 @@ class SoftmaxMKLDNNHandler
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
1 /*dim: C*/);
}
SoftmaxMKLDNNHandler(const std::vector<int>& dims,
@@ -57,100 +60,15 @@ class SoftmaxMKLDNNHandler
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
// If we are in Grad operator then update a key with BWD suffix to
// distinguish from FWD memory primitives
// Key_common will allow access to FWD_PD from the cache
this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt);
}
std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = this->key_ + "@softmax_p";
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
this->dev_ctx_.GetBlob(prim_key));
if (softmax_p == nullptr) {
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*this->fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
this->dev_ctx_.SetBlob(prim_key, softmax_p);
}
return softmax_p;
}
std::shared_ptr<mkldnn::softmax_backward> AcquireSoftmaxBackward(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = this->key_ + "@softmax_bwd_p";
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
this->dev_ctx_.GetBlob(prim_key));
if (softmax_bwd_p == nullptr) {
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*this->bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
*diff_src_memory_p);
this->dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
}
return softmax_bwd_p;
}
protected:
void AcquireSoftmaxPrimitiveDescriptor(const std::vector<int>& dims,
const mkldnn::memory::format fmt) {
// Softmax PD has to be passed to the Grad op that
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
const std::string key_softmax_pd = this->key_common_ + "@softmax_pd";
this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_softmax_pd));
if (this->fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
this->fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_softmax_pd));
if (this->fwd_pd_ == nullptr) {
// TODO(jczaja): Make it work along a chosen axis and for
// forward_training
// Normalization is performed over the innermost dimension, e.g. C out of NC
auto md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto softmax_desc =
softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
this->fwd_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, this->engine_));
this->dev_ctx_.SetBlob(key_softmax_pd, this->fwd_pd_);
}
}
}
void AcquireSoftmaxBackwardPrimitiveDescriptor(
const std::vector<int>& dims, const mkldnn::memory::format fmt,
const mkldnn::memory::format diff_fmt) {
// Fwd_PD_ has to exist in order to create BWD_PD_
PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_);
const std::string key_bwd_pd = this->key_ + "@softmax_bwd_pd";
this->bwd_pd_ =
std::static_pointer_cast<mkldnn::softmax_backward::primitive_desc>(
this->dev_ctx_.GetBlob(key_bwd_pd));
if (this->bwd_pd_ == nullptr) {
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
// TODO(jczaja): Add support for other axes
auto backward_desc = softmax_backward::desc(
diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
this->bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
backward_desc, this->engine_, *this->fwd_pd_));
this->dev_ctx_.SetBlob(key_bwd_pd, this->bwd_pd_);
}
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, 1 /*dim: C*/);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
1 /* dim: C*/);
}
};
@@ -181,11 +99,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
// Currently only NC data format is supported
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
auto softmax_dst_memory_p = handler.AcquireDstMemory(output);
auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
auto softmax_p = handler.AcquireForwardPrimitive(*softmax_src_memory_p,
*softmax_dst_memory_p);
std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
std::vector<primitive> pipeline{*softmax_p};
stream(stream::kind::eager).submit(pipeline).wait();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
@@ -242,8 +159,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
// Get primitive from device context
auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
dst_memory_p, diff_dst_memory_p, diff_src_memory_p);
auto softmax_bwd_p = handler.AcquireBackwardPrimitive(
*dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
std::vector<primitive> pipeline{*softmax_bwd_p};
stream(stream::kind::eager).submit(pipeline).wait();
......
@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "boost/optional.hpp"
#include "paddle/fluid/framework/data_layout_transform.h"
@@ -48,6 +49,32 @@ class MKLDNNHandlerT {
}
}
template <typename... Args>
std::shared_ptr<TForward> AcquireForwardPrimitive(Args&&... args) {
const std::string key_p = key_ + "@forward_p";
auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) {
forward_p =
std::make_shared<TForward>(*fwd_pd_, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_p, forward_p);
}
return forward_p;
}
template <typename... Args>
std::shared_ptr<TBackward> AcquireBackwardPrimitive(Args&&... args) {
const std::string key_p = key_ + "@backward_p";
auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
backward_p =
std::make_shared<TBackward>(*bwd_pd_, std::forward<Args>(args)...);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
@@ -87,6 +114,44 @@ class MKLDNNHandlerT {
ptr, "@diff_src_mem_p");
}
protected:
template <typename... Args>
void AcquireForwardPrimitiveDescriptor(Args&&... args) {
// Forward PD has to be passed to the Grad op that
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
const std::string key_pd = key_common_ + "@forward_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(fwd_desc,
engine_);
dev_ctx_.SetBlob(key_pd, fwd_pd_);
}
}
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
const std::string key_pd = key_ + "@backward_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_pd_);
}
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr,
const std::string& suffix) {
@@ -102,7 +167,6 @@ class MKLDNNHandlerT {
return mem_p;
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
platform::Place place_;
@@ -355,10 +419,12 @@ class ActivationMKLDNNHandler
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
unique_name)) {
AcquireActivationPrimitiveDescriptor(
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
algorithm, dims, fmt, alpha, beta);
algorithm, md, alpha, beta);
}
ActivationMKLDNNHandler(const std::vector<int>& dims,
@@ -374,10 +440,15 @@ class ActivationMKLDNNHandler
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, algorithm, fmt, alpha, beta,
unique_name)) {
AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, dims, fmt, alpha, beta);
AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt,
alpha, beta);
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta);
}
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
@@ -387,106 +458,6 @@ class ActivationMKLDNNHandler
to_void_cast<T>(input_data),
"@bwd-src_mem_p");
}
std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = this->key_ + "@eltwise_p";
auto eltwise_p = std::static_pointer_cast<mkldnn::eltwise_forward>(
this->dev_ctx_.GetBlob(prim_key));
if (eltwise_p == nullptr) {
eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
*this->fwd_pd_, *(src_memory_p), *(dst_memory_p));
this->dev_ctx_.SetBlob(prim_key, eltwise_p);
}
return eltwise_p;
}
std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
std::shared_ptr<mkldnn::memory> diff_src_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = this->key_ + "@eltwise_bwd_p";
auto eltwise_bwd_p = std::static_pointer_cast<mkldnn::eltwise_backward>(
this->dev_ctx_.GetBlob(prim_key));
if (eltwise_bwd_p == nullptr) {
eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
*this->bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
*(diff_src_memory_p));
this->dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
}
return eltwise_bwd_p;
}
protected:
void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
mkldnn::algorithm algorithm,
const std::vector<int>& dims,
const MKLDNNMemoryFormat fmt,
float alpha, float beta) {
// Activation PD has to be passed to the Grad op that
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
const std::string key_activation_pd = this->key_common_ + "@activation_pd";
this->fwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_activation_pd));
if (this->fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
this->fwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_activation_pd));
if (this->fwd_pd_ == nullptr) {
auto md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), fmt);
auto activation_desc = mkldnn::eltwise_forward::desc(
prop_kind, algorithm, md, alpha, beta);
this->fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
activation_desc, this->engine_));
this->dev_ctx_.SetBlob(key_activation_pd, this->fwd_pd_);
}
}
}
void AcquireActivationBackwardPrimitiveDescriptor(
mkldnn::algorithm algorithm, const std::vector<int>& dims,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
float alpha, float beta) {
const std::string key_activation_pd = this->key_common_ + "@activation_pd";
const std::string key_activation_bwd_pd = this->key_ + "@activation_bwd_pd";
this->bwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
this->dev_ctx_.GetBlob(key_activation_bwd_pd));
if (this->bwd_pd_ == nullptr) {
this->fwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
this->dev_ctx_.GetBlob(key_activation_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE_NOT_NULL(this->fwd_pd_,
"Eltwise MKL-DNN not found in cache!");
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto backward_desc = mkldnn::eltwise_backward::desc(
algorithm, diff_dst_md, src_md, alpha, beta);
this->bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
backward_desc, this->engine_, *this->fwd_pd_));
this->dev_ctx_.SetBlob(key_activation_bwd_pd, this->bwd_pd_);
}
}
};
class LRNMKLDNNHandler : public MKLDNNHandler {
......
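A design note on the forward primitive descriptor acquisition above: fwd_pd_ is cached under key_common_, which deliberately omits the thread id, so a backward op running on another thread can find the descriptor created by the forward op; a static mutex plus a second lookup guards the one-time creation. Below is a minimal sketch of that double-checked pattern, assuming an illustrative global blob map and Pd type rather than the actual Paddle device context and MKL-DNN descriptors.

// Sketch of the double-checked, TID-less forward-PD acquisition.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Pd { int softmax_axis; };  // stand-in for TForward::primitive_desc

// Stand-in for MKLDNNDeviceContext::GetBlob/SetBlob.
std::unordered_map<std::string, std::shared_ptr<void>> g_blobs;
std::mutex g_blobs_mutex;

std::shared_ptr<void> GetBlob(const std::string& key) {
  std::lock_guard<std::mutex> guard(g_blobs_mutex);
  auto it = g_blobs.find(key);
  return it == g_blobs.end() ? nullptr : it->second;
}

void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
  std::lock_guard<std::mutex> guard(g_blobs_mutex);
  g_blobs[key] = blob;
}

// key_common omits the thread id so a backward op on another thread can
// find the descriptor created by the forward op.
std::shared_ptr<Pd> AcquireForwardPd(const std::string& key_common) {
  const std::string key_pd = key_common + "@forward_pd";
  auto pd = std::static_pointer_cast<Pd>(GetBlob(key_pd));
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    // Re-check under the lock: another thread may have created it meanwhile.
    pd = std::static_pointer_cast<Pd>(GetBlob(key_pd));
    if (pd == nullptr) {
      pd = std::make_shared<Pd>(Pd{1});
      SetBlob(key_pd, pd);
    }
  }
  return pd;
}

int main() {
  auto pd1 = AcquireForwardPd("softmax");
  auto pd2 = AcquireForwardPd("softmax");
  return pd1 == pd2 ? 0 : 1;  // both calls yield the same cached descriptor
}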