Commit 9e4c9585 authored by Jacek Czaja, committed by Tao Luo

Refactoring activation mkldnn op (#19748)

test=develop

- fix to BWD

test=develop
Parent 12542320
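For orientation, here is a condensed sketch of the new call sites this refactor introduces, taken from the hunks below. The handler is now templated on the data type, hashes its own cache key, and builds the primitive descriptors in its constructors, so the kernels only acquire memories. The surrounding kernel code (shape/format deduction, stream execution) is elided and assumed unchanged.

  // Forward kernel (eltwise_forward): the handler creates the forward
  // primitive descriptor internally and hands back memories bound to it.
  platform::ActivationMKLDNNHandler<T> handler(
      src_tz, algorithm, alpha, beta, src_format, is_test, dev_ctx,
      ctx.GetPlace(), ctx.op().Input("X"));
  auto src_memory_p = handler.AcquireSrcMemory(x);
  auto dst_memory_p = handler.AcquireDstMemory(y);
  auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);

  // Gradient kernel (eltwise_grad): the second constructor additionally takes
  // the diff-dst format and prepares the backward primitive descriptor.
  platform::ActivationMKLDNNHandler<T> handler(
      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
      ctx.GetPlace(), ctx.op().Input("X"));
  auto bwd_src_memory_p = handler.AcquireBackwardSrcMemory(x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);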
@@ -83,13 +83,10 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                  "It must use CPUPlace.");
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
   auto *y = ctx.Output<Tensor>("Out");
 
-  const T *x_data = x->data<T>();
-
   const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
   const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
@@ -103,23 +100,12 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   bool is_test = ctx.Attr<bool>("is_test");
 
-  std::string key = platform::ActivationMKLDNNHandler::GetHash(
-      src_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X"));
-
-  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-  auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
-                                    src_format);
-
-  auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
-      is_test ? mkldnn::prop_kind::forward_inference
-              : mkldnn::prop_kind::forward_training,
-      algorithm, md, alpha, beta);
-
-  auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
-
-  auto dst_memory_p =
-      handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
+  platform::ActivationMKLDNNHandler<T> handler(
+      src_tz, algorithm, alpha, beta, src_format, is_test, dev_ctx,
+      ctx.GetPlace(), ctx.op().Input("X"));
+
+  auto src_memory_p = handler.AcquireSrcMemory(x);
+  auto dst_memory_p = handler.AcquireDstMemory(y);
 
   auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
 
   // push primitive to stream and wait until it's executed
@@ -135,17 +121,11 @@ template <typename T>
 void eltwise_grad(const framework::ExecutionContext &ctx,
                   mkldnn::algorithm algorithm) {
   auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  const T *x_data = x->data<T>();
-
   const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
   auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-  const T *diff_y_data = diff_y->data<T>();
-  T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
 
   const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
   const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
@@ -158,32 +138,13 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   auto diff_y_format =
       diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : diff_y->format();
 
-  auto diff_dst_md = platform::MKLDNNMemDesc(
-      diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
-
-  std::string key = platform::ActivationMKLDNNHandler::GetHash(
-      diff_dst_tz, algorithm, src_format, alpha, beta, ctx.op().Input("X"));
-
-  const std::string key_src_data = key + "@eltwise_fwd_src_data";
-
-  auto src_md = platform::MKLDNNMemDesc(
-      diff_dst_tz, platform::MKLDNNGetDataType<T>(), src_format);
-
-  platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-  auto src_memory_p = handler.AcquireSrcMemory(src_md, to_void_cast<T>(x_data));
-
-  auto diff_dst_memory_p =
-      handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));
-
-  auto activation_backward_pd =
-      handler.AcquireActivationBackwardPrimitiveDescriptor(
-          algorithm, diff_dst_md, src_memory_p->get_primitive_desc().desc(),
-          alpha, beta);
-
-  auto diff_src_memory_p =
-      handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);
+  platform::ActivationMKLDNNHandler<T> handler(
+      diff_dst_tz, algorithm, alpha, beta, src_format, diff_y_format, dev_ctx,
+      ctx.GetPlace(), ctx.op().Input("X"));
+
+  auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
 
   auto activation_backward_p = handler.AcquireActivationBackward(
       diff_src_memory_p, diff_dst_memory_p, src_memory_p);
...
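One reason the gradient-path constructor takes both src_format and diff_y_format: the MKL-DNN eltwise backward primitive descriptor is built from both the source and diff-dst memory descriptors and is validated against the forward primitive descriptor as a hint. A condensed view of that dependency, lifted from the backward acquisition code in the header diff below (dims, fmt, diff_fmt, alpha and beta are the handler's constructor arguments):

  auto src_md =
      platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
  auto diff_dst_md =
      platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
  auto backward_desc = mkldnn::eltwise_backward::desc(algorithm, diff_dst_md,
                                                      src_md, alpha, beta);
  // The backward PD needs the cached forward PD (*fwd_pd_) as a hint, which is
  // why the backward constructor acquires the forward descriptor as well.
  mkldnn::eltwise_backward::primitive_desc bwd_pd(backward_desc, engine_,
                                                  *fwd_pd_);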
@@ -257,65 +257,94 @@ class SumMKLDNNHandler : public MKLDNNHandler {
   std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
 };
 
+template <typename T>
 class ActivationMKLDNNHandler : public MKLDNNHandler {
  public:
-  ActivationMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
-                          mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
-
-  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc>
-  AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
-                                       mkldnn::algorithm algorithm,
-                                       const mkldnn::memory::desc& md,
-                                       float alpha, float beta) {
-    // Activation PD has to be passed to Grad op that
-    // may be executed by diffrent thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_activation_pd = key_common_ + "@activation_pd";
-    fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_activation_pd));
-    if (fwd_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      fwd_pd_ =
-          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              dev_ctx_.GetBlob(key_activation_pd));
-      if (fwd_pd_ == nullptr) {
-        auto activation_desc = mkldnn::eltwise_forward::desc(
-            prop_kind, algorithm, md, alpha, beta);
-
-        fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
-            activation_desc, engine_));
-        dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
-      }
-    }
-    return fwd_pd_;
-  }
-
-  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc>
-  AcquireActivationBackwardPrimitiveDescriptor(
-      mkldnn::algorithm algorithm, const mkldnn::memory::desc& diff_dst_md,
-      const mkldnn::memory::desc& src_md, float alpha, float beta) {
-    const std::string key_activation_pd = key_common_ + "@activation_pd";
-    const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
-    bwd_pd_ =
-        std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
-            dev_ctx_.GetBlob(key_activation_bwd_pd));
-    if (bwd_pd_ == nullptr) {
-      fwd_pd_ =
-          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-              dev_ctx_.GetBlob(key_activation_pd));
-      // PD from FWD op has to exist.
-      PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
-      auto backward_desc = mkldnn::eltwise_backward::desc(
-          algorithm, diff_dst_md, src_md, alpha, beta);
-      bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
-          backward_desc, engine_, *fwd_pd_));
-      dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
-    }
-    return bwd_pd_;
-  }
+  ActivationMKLDNNHandler(const std::vector<int>& dims,
+                          mkldnn::algorithm algorithm, float alpha, float beta,
+                          const MKLDNNMemoryFormat fmt, bool is_test,
+                          const platform::MKLDNNDeviceContext& dev_ctx,
+                          platform::Place cpu_place,
+                          const std::string& unique_name)
+      : platform::MKLDNNHandler(
+            dev_ctx, dev_ctx.GetEngine(),
+            platform::ActivationMKLDNNHandler<T>::GetHash(
+                dims, algorithm, fmt, alpha, beta, unique_name)),
+        place_(cpu_place),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
+    AcquireActivationPrimitiveDescriptor(
+        is_test ? mkldnn::prop_kind::forward_inference
+                : mkldnn::prop_kind::forward_training,
+        algorithm, dims, fmt, alpha, beta);
+  }
+
+  ActivationMKLDNNHandler(const std::vector<int>& dims,
+                          mkldnn::algorithm algorithm, float alpha, float beta,
+                          const MKLDNNMemoryFormat fmt,
+                          const MKLDNNMemoryFormat diff_fmt,
+                          const platform::MKLDNNDeviceContext& dev_ctx,
+                          platform::Place cpu_place,
+                          const std::string& unique_name)
+      : platform::MKLDNNHandler(
+            dev_ctx, dev_ctx.GetEngine(),
+            platform::ActivationMKLDNNHandler<T>::GetHash(
+                dims, algorithm, fmt, alpha, beta, unique_name)),
+        place_(cpu_place),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
+    AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
+                                         algorithm, dims, fmt, alpha, beta);
+    AcquireActivationBackwardPrimitiveDescriptor(algorithm, dims, fmt, diff_fmt,
+                                                 alpha, beta);
+  }
+
+  // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this
+  // function
+  // should be moved as well eg. ActivationMKLDNNHandler ->
+  // MKLDNNHandler<activation_>
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
+                                            to_void_cast<T>(input_data),
+                                            "@src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->src_primitive_desc(),
+                                            to_void_cast<T>(input_data),
+                                            "@bwd-src_mem_p");
+  }
+
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
+    T* ptr = output->mutable_data<T>(place_,
+                                     fwd_pd_->dst_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
+  }
+
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
+      const framework::Tensor* diffdst) {
+    const T* ptr = diffdst->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
+                                            to_void_cast<T>(ptr),
+                                            "@diff_dst_mem_p");
+  }
+
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
+      framework::Tensor* diffsrc) {
+    T* ptr = diffsrc->mutable_data<T>(
+        place_, bwd_pd_->diff_src_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
+                                            ptr, "@diff_src_mem_p");
+  }
 
   std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
@@ -335,20 +364,6 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
     return eltwise_p;
   }
 
-  template <typename T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(
-      framework::Tensor* output, platform::Place place) {
-    T* ptr = output->mutable_data<T>(place,
-                                     fwd_pd_->dst_primitive_desc().get_size());
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
-                                            "@dst_mem_p");
-  }
-
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(void* ptr) {
-    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
-                                            ptr, "@diff_src_mem_p");
-  }
-
   std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
       std::shared_ptr<mkldnn::memory> diff_src_memory_p,
       std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
@@ -383,7 +398,70 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
     return key;
   }
 
+ protected:
+  void AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
+                                            mkldnn::algorithm algorithm,
+                                            const std::vector<int>& dims,
+                                            const MKLDNNMemoryFormat fmt,
+                                            float alpha, float beta) {
+    // Activation PD has to be passed to Grad op that
+    // may be executed by diffrent thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_activation_pd = key_common_ + "@activation_pd";
+    fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_activation_pd));
+    if (fwd_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      fwd_pd_ =
+          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+              dev_ctx_.GetBlob(key_activation_pd));
+      if (fwd_pd_ == nullptr) {
+        auto md = platform::MKLDNNMemDesc(
+            dims, platform::MKLDNNGetDataType<T>(), fmt);
+        auto activation_desc = mkldnn::eltwise_forward::desc(
+            prop_kind, algorithm, md, alpha, beta);
+        fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
+            activation_desc, engine_));
+        dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
+      }
+    }
+  }
+
+  void AcquireActivationBackwardPrimitiveDescriptor(
+      mkldnn::algorithm algorithm, const std::vector<int>& dims,
+      const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_fmt,
+      float alpha, float beta) {
+    const std::string key_activation_pd = key_common_ + "@activation_pd";
+    const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
+    bwd_pd_ =
+        std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
+            dev_ctx_.GetBlob(key_activation_bwd_pd));
+    if (bwd_pd_ == nullptr) {
+      fwd_pd_ =
+          std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+              dev_ctx_.GetBlob(key_activation_pd));
+      // PD from FWD op has to exist.
+      PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
+      auto diff_dst_md = platform::MKLDNNMemDesc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      auto src_md =
+          platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+      auto backward_desc = mkldnn::eltwise_backward::desc(
+          algorithm, diff_dst_md, src_md, alpha, beta);
+      bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
+          backward_desc, engine_, *fwd_pd_));
+      dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
+    }
+  }
+
  private:
+  platform::Place place_;
   std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd_;
   std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd_;
 };
...
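The primitive-descriptor acquisition above caches the forward PD under a key derived from key_common_ (which contains no thread ID) and guards creation with double-checked locking, so a grad op that may run on a different thread can still find the forward PD its backward PD requires. Below is a standalone sketch of that caching pattern; BlobCache with get/set are hypothetical stand-ins for MKLDNNDeviceContext::GetBlob/SetBlob, and PrimitiveDesc stands in for the real mkldnn descriptor type.

#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical stand-in for the device context's internally synchronized
// blob map (MKLDNNDeviceContext::GetBlob/SetBlob in the real code).
class BlobCache {
 public:
  std::shared_ptr<void> get(const std::string& key) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = blobs_.find(key);
    return it == blobs_.end() ? std::shared_ptr<void>() : it->second;
  }
  void set(const std::string& key, std::shared_ptr<void> value) {
    std::lock_guard<std::mutex> guard(mutex_);
    blobs_[key] = value;
  }

 private:
  std::mutex mutex_;
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

struct PrimitiveDesc {};  // stand-in for mkldnn::eltwise_forward::primitive_desc

std::shared_ptr<PrimitiveDesc> AcquirePD(BlobCache* cache,
                                         const std::string& key_common) {
  const std::string key_pd = key_common + "@activation_pd";
  // First lookup: the common case once some thread has already cached the PD.
  auto pd = std::static_pointer_cast<PrimitiveDesc>(cache->get(key_pd));
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> block_threads(acquire_barrier);
    // Second lookup under the lock: another thread may have won the race.
    pd = std::static_pointer_cast<PrimitiveDesc>(cache->get(key_pd));
    if (pd == nullptr) {
      pd = std::make_shared<PrimitiveDesc>();  // build the real descriptor here
      cache->set(key_pd, pd);
    }
  }
  return pd;
}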