Commit cef95ee3 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Refactoring Softmax (#19312)

* - First set of modifications

- Compilation fixes

- compilation fix

- Another compilation fix

- Moved AcquireSoftmaxPrimitiveDescriptor call into handler

- MKL-DNN Softmax PD refactor

test=develop

- Compilation fix

test=develop

- another compilation fix

- cosmetics

test=develop

- Compilation fix

- Fix to crash when softmax backward is created

* - Fixes after review of softmax refactoring

test=develop
Parent 0a73f720
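The refactor replaces the old handler constructors (which took ready-made primitive descriptors) with constructors that take the tensor dims and MKL-DNN memory format, and moves memory-descriptor and primitive-descriptor creation inside the handler. A minimal sketch of the new forward-pass call sequence, based on the SoftmaxMKLDNNKernel hunk below; dev_ctx, mkldnn_engine, key, softmax_tz and the data pointers are assumed to be provided by the surrounding kernel, and the variable names are illustrative:

```cpp
// Sketch only: forward-pass usage of the refactored SoftmaxMKLDNNHandler.
SoftmaxMKLDNNHandler<float> handler(softmax_tz, mkldnn::memory::format::nc,
                                    dev_ctx, mkldnn_engine, key);

// The handler now builds (or fetches from the per-op cache) the user memories;
// the kernel no longer constructs mkldnn::memory::desc objects itself.
auto src_mem = handler.AcquireSrcMemory(to_void_cast<float>(input_data));
auto dst_mem =
    handler.AcquireDstMemoryFromPrimitive(to_void_cast<float>(output_data));

// The forward primitive descriptor is acquired lazily inside the handler.
auto softmax_p = handler.AcquireSoftmax(dst_mem, src_mem);

std::vector<mkldnn::primitive> pipeline{*softmax_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
```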
@@ -32,49 +32,58 @@ using mkldnn::softmax_forward;
using mkldnn::stream;
using platform::to_void_cast;
template <typename T>
class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
public:
SoftmaxMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const mkldnn::memory::format fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
fmt_(fmt) {}
SoftmaxMKLDNNHandler(
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const mkldnn::memory::format fmt,
const mkldnn::memory::format diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
softmax_pd_(softmax_pd),
softmax_bwd_pd_(softmax_bwd_pd) {
dims_(dims),
fmt_(fmt),
diff_fmt_(diff_fmt) {
// If we are in the Grad operator then update the key with a BWD suffix to
// distinguish it from FWD memory primitives.
// key_common_ allows the FWD PD to be fetched from the cache.
key_ += "-BWD";
}
std::shared_ptr<softmax_forward::primitive_desc>
AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
const mkldnn::engine& engine) {
// Softmax PD has to be passed to Grad op that
- may be executed by a different thread, hence
// for that one we use key that does not contain TID
const std::string key_softmax_pd = key_common_ + "@softmax_pd";
// TODO(jczaja): Once fwd_pd_ is moved to MKLDNNHandler then this function
// should be moved as well, e.g. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
ptr, "@user_src_mem_p");
}
softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (softmax_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (softmax_pd_ == nullptr) {
softmax_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, engine));
dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
ptr, "@user_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
diff_fmt_, ptr, "@user_diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
diff_fmt_, ptr, "@user_diff_src_mem_p");
}
return softmax_pd_;
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
this->AcquireSoftmaxPrimitiveDescriptor();
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
@@ -86,8 +95,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
dev_ctx_.GetBlob(prim_key));
if (softmax_p == nullptr) {
this->AcquireSoftmaxPrimitiveDescriptor();
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
dev_ctx_.SetBlob(prim_key, softmax_p);
}
@@ -103,8 +113,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
dev_ctx_.GetBlob(prim_key));
if (softmax_bwd_p == nullptr) {
auto data_softmax_md =
mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
auto diff_softmax_md = mkldnn::memory::desc(
dims_, platform::MKLDNNGetDataType<T>(), diff_fmt_);
// TODO(jczaja): Add support for other axes
auto softmax_bwd_desc = softmax_backward::desc(
diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
this->AcquireSoftmaxPrimitiveDescriptor();
auto softmax_bwd_pd = mkldnn::softmax_backward::primitive_desc(
softmax_bwd_desc, engine_, *fwd_pd_);
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
*softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
softmax_bwd_pd, *dst_memory_p, *diff_dst_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
}
@@ -112,9 +133,41 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
return softmax_bwd_p;
}
protected:
void AcquireSoftmaxPrimitiveDescriptor(void) {
// Softmax PD has to be passed to Grad op that
// may be executed by a different thread, hence
// for that one we use key that does not contain TID
const std::string key_softmax_pd = key_common_ + "@softmax_pd";
fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
dev_ctx_.GetBlob(key_softmax_pd));
if (fwd_pd_ == nullptr) {
// TODO(jczaja): Make it work along a chosen axis and for
// forward_training
// Normalization is made after innermost dimension eg. C out of NC
auto md =
mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
auto softmax_desc =
softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
fwd_pd_.reset(
new softmax_forward::primitive_desc(softmax_desc, engine_));
dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
}
}
}
private:
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
std::vector<int> dims_;
mkldnn::memory::format fmt_;
mkldnn::memory::format diff_fmt_;
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
};
template <typename T>
@@ -154,21 +207,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
SoftmaxMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
// Currently only NC data format is supported
auto softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
auto softmax_pd =
handler.AcquireSoftmaxPrimitiveDescriptor(softmax_desc, mkldnn_engine);
SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
dev_ctx, mkldnn_engine, key);
// Currently only NC data format is supported
auto softmax_src_memory_p =
handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
handler.AcquireSrcMemory(to_void_cast<T>(input_data));
auto softmax_dst_memory_p =
handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
@@ -241,25 +287,16 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
// TODO(jczaja): Add layout support when there is a need to do so
// Two dimensional softmax does support NC format
auto data_softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
auto diff_softmax_md = MKLDNNMemDesc(
{softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_bwd_desc =
softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
auto softmax_bwd_pd =
std::make_shared<mkldnn::softmax_backward::primitive_desc>(
softmax_bwd_desc, mkldnn_engine, *softmax_pd);
SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
mkldnn::memory::format::nc, dev_ctx,
mkldnn_engine, key);
auto dst_memory_p =
handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
diff_softmax_md, to_void_cast<T>(diff_src_ptr));
auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
auto diff_dst_memory_p =
handler.AcquireDiffDstMemory(to_void_cast<T>(diff_dst_ptr));
auto diff_src_memory_p =
handler.AcquireDiffSrcMemory(to_void_cast<T>(diff_src_ptr));
// Get primitive from device context
auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
......
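For the grad op the handler is constructed with an additional diff memory format and hands out the dst/diff_dst/diff_src memories as well. A sketch of the backward call sequence, following the SoftmaxMKLDNNGradKernel hunk above; the argument order of AcquireSoftmaxBackward (dst, diff_dst, diff_src) is assumed from the mkldnn::softmax_backward construction shown earlier, and the other names come from the surrounding kernel:

```cpp
// Sketch only: backward-pass usage of the refactored SoftmaxMKLDNNHandler.
SoftmaxMKLDNNHandler<float> handler(softmax_tz, mkldnn::memory::format::nc,
                                    mkldnn::memory::format::nc, dev_ctx,
                                    mkldnn_engine, key);

auto dst_mem = handler.AcquireDstMemory(to_void_cast<float>(dst_data));
auto diff_dst_mem =
    handler.AcquireDiffDstMemory(to_void_cast<float>(diff_dst_ptr));
auto diff_src_mem =
    handler.AcquireDiffSrcMemory(to_void_cast<float>(diff_src_ptr));

// Assumed argument order: (dst, diff_dst, diff_src).
auto softmax_bwd_p =
    handler.AcquireSoftmaxBackward(dst_mem, diff_dst_mem, diff_src_mem);

std::vector<mkldnn::primitive> pipeline{*softmax_bwd_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
```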
@@ -119,6 +119,25 @@ class MKLDNNHandler {
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::vector<int>& dims, const mkldnn::memory::data_type dtype,
const mkldnn::memory::format& fmt, void* ptr, const std::string& suffix) {
/*Generate key*/
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto md = mkldnn::memory::desc(dims, dtype, fmt);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const mkldnn::memory::primitive_desc& mpd, const std::string& suffix) {
auto local_key = key_ + suffix;
@@ -949,18 +968,7 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
......
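AcquireSoftmaxPrimitiveDescriptor() caches the forward primitive descriptor under a key derived from key_common_ (no thread ID), so the Grad op, which may run on a different thread, can reuse it, and it guards creation with double-checked locking. A simplified, standalone sketch of that caching pattern; AcquireCachedPD, Cache and Factory are hypothetical names, not Paddle or MKL-DNN APIs:

```cpp
#include <memory>
#include <mutex>
#include <string>

// Double-checked locking around a blob cache: check once without the lock,
// re-check after taking it, so the descriptor is created exactly once and
// then shared with later callers (e.g. the Grad op on another thread).
template <typename PD, typename Cache, typename Factory>
std::shared_ptr<PD> AcquireCachedPD(Cache* cache, const std::string& key,
                                    Factory make_pd) {
  auto pd = std::static_pointer_cast<PD>(cache->GetBlob(key));
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    // Re-check: another thread may have created it while we waited.
    pd = std::static_pointer_cast<PD>(cache->GetBlob(key));
    if (pd == nullptr) {
      pd = make_pd();           // build the primitive descriptor once
      cache->SetBlob(key, pd);  // publish it for later lookups
    }
  }
  return pd;
}
```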