提交 8869d7f7 编写于 作者: J Jacek Czaja 提交者: Tao Luo

Activations MKLDNN ops refactoring (#18191)

上级 b6d5c74f
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -99,20 +99,21 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -99,20 +99,21 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
auto src_format = auto src_format =
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format(); src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data =
key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
const std::string key_src_layout =
key + ctx.op().Output("Out") + "@eltwise_fwd_src_layout";
const std::string key_with_layout = key + std::to_string(src_format);
const std::string key_src_mem = key_with_layout + "@eltwise_fwd_src_mem";
const std::string key_dst_mem = key_with_layout + "@eltwise_fwd_dst_mem";
const std::string key_fwd = key_with_layout + "@eltwise_fwd";
const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
// TODO(jczaja): When adding leaky-relu , swish , elu make sure to extend key
// with alpha, beta
std::string key = platform::MKLDNNHandler::GetHash(
src_tz, std::to_string(algorithm) + ctx.op().Output("Out"));
// TODO(jczaja): Make it Thread safe
// save input data and layout to be referred in backward path // save input data and layout to be referred in backward path
const std::string key_src_data = key + "@eltwise_fwd_src_data";
const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
// Just in case some int8 models are run interchangebly
// with float models then format maybe diffrent
key += std::to_string(src_format);
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
auto p_src_data = std::make_shared<const T *>(x_data); auto p_src_data = std::make_shared<const T *>(x_data);
auto p_src_layout = std::make_shared<memory::format>(src_format); auto p_src_layout = std::make_shared<memory::format>(src_format);
if (!is_test) { if (!is_test) {
...@@ -120,65 +121,34 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -120,65 +121,34 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
dev_ctx.SetBlob(key_src_layout, p_src_layout); dev_ctx.SetBlob(key_src_layout, p_src_layout);
} }
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>( platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
dev_ctx.GetBlob(key_fwd));
auto md = platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
std::shared_ptr<memory> dst_memory; src_format);
if (p_fwd == nullptr) { auto activation_pd = handler.AcquireActivationPrimitiveDescriptor(
// create mkldnn memory for input X is_test ? mkldnn::prop_kind::forward_inference
auto src_md = platform::MKLDNNMemDesc( : mkldnn::prop_kind::forward_training,
src_tz, platform::MKLDNNGetDataType<T>(), src_format); algorithm, md, alpha, beta);
auto src_memory = std::shared_ptr<memory>(
new memory({src_md, mkldnn_engine}, to_void_cast(x_data))); auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
// save src_memory to be referred in backward path // jczaja: Workaround, src_memory_p is needed in BWD so it has
dev_ctx.SetBlob(key_src_mem, src_memory); // to be accessible under key not dependant on TID
if (!is_test) {
// create primitive descriptor for activation forward and save it dev_ctx.SetBlob(key_src_mem, src_memory_p);
auto mkldnn_forward_prop_kind = is_test
? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
auto forward_desc = mkldnn::eltwise_forward::desc(
mkldnn_forward_prop_kind, algorithm,
src_memory->get_primitive_desc().desc(), alpha, beta);
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
forward_desc, mkldnn_engine);
// save prim desc into global device context to be referred in backward path
if (!is_test) dev_ctx.SetBlob(key_fwd_pd, forward_pd);
// create mkldnn memory for output y
dst_memory =
std::make_shared<memory>(forward_pd->dst_primitive_desc(), y_data);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
// create activation primitive
p_fwd = std::make_shared<mkldnn::eltwise_forward>(*forward_pd, *src_memory,
*dst_memory);
dev_ctx.SetBlob(key_fwd, p_fwd);
} else {
// primitives already exist
auto src_memory =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
PADDLE_ENFORCE(src_memory != nullptr,
"Fail to find eltwise src_memory in device context.");
dst_memory =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
PADDLE_ENFORCE(dst_memory != nullptr,
"Fail to find eltwise dst_memory in device context.");
src_memory->set_data_handle(platform::to_void_cast(x_data));
dst_memory->set_data_handle(y_data);
} }
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(y_data));
auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
pipeline.push_back(*p_fwd); pipeline.push_back(*activation_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
y->set_layout(DataLayout::kMKLDNN); y->set_layout(DataLayout::kMKLDNN);
y->set_format(GetMKLDNNFormat(*dst_memory)); y->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
template <typename T> template <typename T>
...@@ -199,90 +169,51 @@ void eltwise_grad(const framework::ExecutionContext &ctx, ...@@ -199,90 +169,51 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
auto diff_y_format = auto diff_y_format =
diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format(); diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
const std::string key = gethash(diff_dst_tz, algorithm); auto diff_dst_md = platform::MKLDNNMemDesc(
const std::string key_src_data = diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
const std::string key_src_layout = std::string key = platform::MKLDNNHandler::GetHash(
key + ctx.op().Input("Out") + "@eltwise_fwd_src_layout"; diff_dst_tz, std::to_string(algorithm) + ctx.op().Input("Out"));
const std::string key_src_data = key + "@eltwise_fwd_src_data";
const std::string key_src_layout = key + "@eltwise_fwd_src_layout";
// Get Data from FWD op
const auto p_src_layout = const auto p_src_layout =
std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout)); std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
const std::string key_src_mem =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
const std::string key_fwd_pd =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
const std::string key_with_layouts =
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
const std::string key_diff_src_mem =
key_with_layouts + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem =
key_with_layouts + "@eltwise_diff_dst_mem";
const std::string key_grad = key_with_layouts + "@eltwise_grad";
const auto p_src_data = const auto p_src_data =
std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data)); std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
key += std::to_string(*p_src_layout);
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
auto src_memory = auto src_memory =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
PADDLE_ENFORCE(src_memory != nullptr, PADDLE_ENFORCE(src_memory != nullptr,
"Fail to find src_memory in device context"); "Fail to find src_memory in device context");
src_memory->set_data_handle(*p_src_data); src_memory->set_data_handle(*p_src_data);
std::shared_ptr<memory> diff_src_memory; platform::ActivationMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>( auto diff_dst_memory_p =
dev_ctx.GetBlob(key_grad)); handler.AcquireDiffDstMemory(diff_dst_md, to_void_cast<T>(diff_y_data));
if (p_grad == nullptr) { auto activation_backward_pd =
// create mkldnn memory for input diff_y handler.AcquireActivationBackwardPrimitiveDescriptor(
auto diff_dst_md = platform::MKLDNNMemDesc( algorithm, diff_dst_md, src_memory->get_primitive_desc().desc(),
diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format); alpha, beta);
auto diff_dst_memory = std::shared_ptr<memory>(
new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data))); auto diff_src_memory_p =
dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory); handler.AcquireDiffSrcMemoryFromPrimitive(diff_x_data);
// retrieve eltwise primitive desc from device context auto activation_backward_p = handler.AcquireActivationBackward(
auto forward_pd = diff_src_memory_p, diff_dst_memory_p, src_memory);
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx.GetBlob(key_fwd_pd));
PADDLE_ENFORCE(forward_pd != nullptr,
"Fail to find eltwise_fwd_pd in device context");
// ceate primitive descriptor for activation backward
auto backward_desc = mkldnn::eltwise_backward::desc(
algorithm, diff_dst_memory->get_primitive_desc().desc(),
src_memory->get_primitive_desc().desc(), alpha, beta);
auto backward_pd = mkldnn::eltwise_backward::primitive_desc(
backward_desc, mkldnn_engine, *forward_pd);
// create mkldnn memory for output diff_src
diff_src_memory = std::make_shared<memory>(
backward_pd.diff_src_primitive_desc(), diff_x_data);
dev_ctx.SetBlob(key_diff_src_mem, diff_src_memory);
// create activation backward primitive
p_grad = std::make_shared<mkldnn::eltwise_backward>(
backward_pd, *src_memory, *diff_dst_memory, *diff_src_memory);
dev_ctx.SetBlob(key_grad, p_grad);
} else {
// primitives already exist
diff_src_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_diff_src_mem));
auto diff_dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_diff_dst_mem));
diff_src_memory->set_data_handle(
platform::to_void_reinterpret_cast(diff_x_data));
diff_dst_memory->set_data_handle(
platform::to_void_reinterpret_cast(diff_y_data));
}
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
pipeline.push_back(*p_grad); pipeline.push_back(*activation_backward_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
diff_x->set_layout(DataLayout::kMKLDNN); diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory)); diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
} }
template <typename T, mkldnn::algorithm algorithm> template <typename T, mkldnn::algorithm algorithm>
......
...@@ -309,6 +309,121 @@ class SumMKLDNNHandler : public MKLDNNHandler { ...@@ -309,6 +309,121 @@ class SumMKLDNNHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_; std::shared_ptr<mkldnn::sum::primitive_desc> sum_pd_;
}; };
class ActivationMKLDNNHandler : public MKLDNNHandler {
public:
ActivationMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc>
AcquireActivationPrimitiveDescriptor(mkldnn::prop_kind prop_kind,
mkldnn::algorithm algorithm,
const mkldnn::memory::desc& md,
float alpha, float beta) {
// Activation PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_activation_pd = key_common_ + "@activation_pd";
activation_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_pd));
if (activation_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
activation_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_pd));
if (activation_pd_ == nullptr) {
auto activation_desc = mkldnn::eltwise_forward::desc(
prop_kind, algorithm, md, alpha, beta);
activation_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
activation_desc, engine_));
dev_ctx_.SetBlob(key_activation_pd, activation_pd_);
}
}
return activation_pd_;
}
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc>
AcquireActivationBackwardPrimitiveDescriptor(
mkldnn::algorithm algorithm, const mkldnn::memory::desc& diff_dst_md,
const mkldnn::memory::desc& src_md, float alpha, float beta) {
const std::string key_activation_pd = key_common_ + "@activation_pd";
const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
activation_bwd_pd_ =
std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_bwd_pd));
if (activation_bwd_pd_ == nullptr) {
activation_pd_ =
std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
dev_ctx_.GetBlob(key_activation_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(activation_pd_ != nullptr,
"Eltwise MKL-DNN not found in cache!");
auto backward_desc = mkldnn::eltwise_backward::desc(
algorithm, diff_dst_md, src_md, alpha, beta);
activation_bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
backward_desc, engine_, *activation_pd_));
dev_ctx_.SetBlob(key_activation_bwd_pd, activation_bwd_pd_);
}
return activation_bwd_pd_;
}
std::shared_ptr<mkldnn::eltwise_forward> AcquireActivation(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@eltwise_p";
auto eltwise_p = std::static_pointer_cast<mkldnn::eltwise_forward>(
dev_ctx_.GetBlob(prim_key));
if (eltwise_p == nullptr) {
eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
*activation_pd_, *(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, eltwise_p);
}
return eltwise_p;
}
// TODO(jczaja): Merge all AcquireDstMemoryFromPrimitive into one
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(
activation_pd_->dst_primitive_desc(), ptr, "@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(
activation_bwd_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::eltwise_backward> AcquireActivationBackward(
std::shared_ptr<mkldnn::memory> diff_src_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
/*Generate key*/
auto prim_key = key_ + "@eltwise_bwd_p";
auto eltwise_bwd_p = std::static_pointer_cast<mkldnn::eltwise_backward>(
dev_ctx_.GetBlob(prim_key));
if (eltwise_bwd_p == nullptr) {
eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
*activation_bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
*(diff_src_memory_p));
dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
}
return eltwise_bwd_p;
}
private:
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> activation_pd_;
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> activation_bwd_pd_;
};
class TransposeMKLDNNHandler : public MKLDNNHandler { class TransposeMKLDNNHandler : public MKLDNNHandler {
public: public:
TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT TransposeMKLDNNHandler(std::vector<int>& dims, // NOLINT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册