Commit cef95ee3 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Refactoring Softmax (#19312)

* - First set of modifications

- Compilation fixes

- compilation fix

- Another compilation fix

- Moved AcquireSoftmaxPrimitiveDescriptor call into handler

- MKL-DNN Softmax PD refactor

test=develop

- Compilation fix

test=develop

- another compilation fix

- cosmetics

test=develop

- Compilation fix

- Fix to crash when softmax backward is created

* - Fixes after review of softmax refactoring

test=develop
Parent 0a73f720
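The core of the refactor in the diff below is that the forward primitive descriptor is no longer built in the kernel and handed to the handler; the handler acquires it itself from the device-context blob cache under a thread-id-free key, with double-checked locking so a grad kernel running on another thread can reuse the cached descriptor. A minimal, self-contained sketch of that caching pattern follows; it uses a plain `std::unordered_map` plus a hypothetical `FakePD` type in place of the real `MKLDNNDeviceContext` blob cache and `softmax_forward::primitive_desc`, so it illustrates the locking scheme only, not the actual MKL-DNN API.

```cpp
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for dev_ctx_.GetBlob/SetBlob and the MKL-DNN
// primitive descriptor; they are not the real PaddlePaddle types.
struct FakePD { int axis; };
using Blob = std::shared_ptr<FakePD>;
std::unordered_map<std::string, Blob> g_blob_cache;
std::mutex g_cache_mutex;  // protects the cache map itself

Blob GetBlob(const std::string& key) {
  std::lock_guard<std::mutex> guard(g_cache_mutex);
  auto it = g_blob_cache.find(key);
  return it == g_blob_cache.end() ? nullptr : it->second;
}

void SetBlob(const std::string& key, Blob value) {
  std::lock_guard<std::mutex> guard(g_cache_mutex);
  g_blob_cache[key] = std::move(value);
}

// Double-checked acquire, mirroring AcquireSoftmaxPrimitiveDescriptor():
// look up without the creation barrier first, then re-check under a static
// mutex so only one thread builds and caches the descriptor.
Blob AcquirePD(const std::string& key_common) {
  const std::string key_pd = key_common + "@softmax_pd";
  auto pd = GetBlob(key_pd);
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> block_until_done(acquire_barrier);
    pd = GetBlob(key_pd);
    if (pd == nullptr) {
      pd = std::make_shared<FakePD>(FakePD{1});  // normalize along axis 1 (C of NC)
      SetBlob(key_pd, pd);
    }
  }
  return pd;
}

int main() {
  auto a = AcquirePD("softmax-nc-2x3");
  auto b = AcquirePD("softmax-nc-2x3");
  return (a == b) ? 0 : 1;  // second call hits the cache, same descriptor
}
```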
@@ -32,49 +32,58 @@ using mkldnn::softmax_forward;
 using mkldnn::stream;
 using platform::to_void_cast;
 
+template <typename T>
 class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
  public:
-  SoftmaxMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
+  SoftmaxMKLDNNHandler(const std::vector<int>& dims,
+                       const mkldnn::memory::format fmt,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
                        mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        dims_(dims),
+        fmt_(fmt) {}
 
-  SoftmaxMKLDNNHandler(
-      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
-      std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
-      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
-      const std::string& base_key)
+  SoftmaxMKLDNNHandler(const std::vector<int>& dims,
+                       const mkldnn::memory::format fmt,
+                       const mkldnn::memory::format diff_fmt,
+                       const platform::MKLDNNDeviceContext& dev_ctx,
+                       mkldnn::engine engine, const std::string& base_key)
       : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        softmax_pd_(softmax_pd),
-        softmax_bwd_pd_(softmax_bwd_pd) {
+        dims_(dims),
+        fmt_(fmt),
+        diff_fmt_(diff_fmt) {
     // If we are in Grad operator then update a key with BWD suffix to
     // distinguish from FWD memory primitives
+    // Key_common will allow to access FWD_PD from cache
     key_ += "-BWD";
   }
 
-  std::shared_ptr<softmax_forward::primitive_desc>
-  AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
-                                    const mkldnn::engine& engine) {
-    // Softmax PD has to be passed to Grad op that
-    // may be executed by different thread, hence
-    // for that one we use key that does not contain TID
-    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
-    softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-        dev_ctx_.GetBlob(key_softmax_pd));
-    if (softmax_pd_ == nullptr) {
-      static std::mutex acquire_barrier;
-      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
-          acquire_barrier);
-      softmax_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
-          dev_ctx_.GetBlob(key_softmax_pd));
-      if (softmax_pd_ == nullptr) {
-        softmax_pd_.reset(
-            new softmax_forward::primitive_desc(softmax_desc, engine));
-        dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
-      }
-    }
-    return softmax_pd_;
+  // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
+  // should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
+                               ptr, "@user_src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
+                               ptr, "@user_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
+                               diff_fmt_, ptr, "@user_diff_dst_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(void* ptr) {
+    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
+                               diff_fmt_, ptr, "@user_diff_src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
+    this->AcquireSoftmaxPrimitiveDescriptor();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
   }
 
   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
@@ -86,8 +95,9 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
        dev_ctx_.GetBlob(prim_key));
     if (softmax_p == nullptr) {
+      this->AcquireSoftmaxPrimitiveDescriptor();
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
-          *softmax_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
       dev_ctx_.SetBlob(prim_key, softmax_p);
     }
@@ -103,8 +113,19 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_bwd_p == nullptr) {
+      auto data_softmax_md =
+          mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+      auto diff_softmax_md = mkldnn::memory::desc(
+          dims_, platform::MKLDNNGetDataType<T>(), diff_fmt_);
+      // TODO(jczaja): Add support for other axes
+      auto softmax_bwd_desc = softmax_backward::desc(
+          diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
+      this->AcquireSoftmaxPrimitiveDescriptor();
+      auto softmax_bwd_pd = mkldnn::softmax_backward::primitive_desc(
+          softmax_bwd_desc, engine_, *fwd_pd_);
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
-          *softmax_bwd_pd_, *dst_memory_p, *diff_dst_memory_p,
+          softmax_bwd_pd, *dst_memory_p, *diff_dst_memory_p,
           *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
     }
@@ -112,9 +133,41 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     return softmax_bwd_p;
   }
 
+ protected:
+  void AcquireSoftmaxPrimitiveDescriptor(void) {
+    // Softmax PD has to be passed to Grad op that
+    // may be executed by different thread, hence
+    // for that one we use key that does not contain TID
+    const std::string key_softmax_pd = key_common_ + "@softmax_pd";
+    fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_softmax_pd));
+    if (fwd_pd_ == nullptr) {
+      static std::mutex acquire_barrier;
+      std::lock_guard<std::mutex> block_threads_until_finish_this_job(
+          acquire_barrier);
+      fwd_pd_ = std::static_pointer_cast<softmax_forward::primitive_desc>(
+          dev_ctx_.GetBlob(key_softmax_pd));
+      if (fwd_pd_ == nullptr) {
+        // TODO(jczaja): Make it working along chosen axis and for
+        // forward_training
+        // Normalization is made after innermost dimension eg. C out of NC
+        auto md =
+            mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+        auto softmax_desc =
+            softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
+        fwd_pd_.reset(
+            new softmax_forward::primitive_desc(softmax_desc, engine_));
+        dev_ctx_.SetBlob(key_softmax_pd, fwd_pd_);
+      }
+    }
+  }
+
  private:
-  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
-  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
+  std::vector<int> dims_;
+  mkldnn::memory::format fmt_;
+  mkldnn::memory::format diff_fmt_;
+  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
 };
 
 template <typename T>
@@ -154,21 +207,14 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     const std::string key =
         platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
 
-    SoftmaxMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-    // Currently only NC data format is supported
-    auto softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    // Normalization is made after innermost dimension eg. C out of NC
-    auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
-                                              softmax_md, 1 /*dim: C*/);
-    auto softmax_pd =
-        handler.AcquireSoftmaxPrimitiveDescriptor(softmax_desc, mkldnn_engine);
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
+                                    dev_ctx, mkldnn_engine, key);
+    // Currently only NC data format is supported
     auto softmax_src_memory_p =
-        handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
+        handler.AcquireSrcMemory(to_void_cast<T>(input_data));
     auto softmax_dst_memory_p =
-        handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
+        handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
     auto softmax_p =
         handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
@@ -241,25 +287,16 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     // TODO(jczaja): Add layouts support when there is a need to do so
     // Two dimensional softmax does support NC format
-    auto data_softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
-    auto diff_softmax_md = MKLDNNMemDesc(
-        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
     // Normalization is made after innermost dimension eg. C out of NC
-    auto softmax_bwd_desc =
-        softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
-    auto softmax_bwd_pd =
-        std::make_shared<mkldnn::softmax_backward::primitive_desc>(
-            softmax_bwd_desc, mkldnn_engine, *softmax_pd);
-
-    SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
-                                 mkldnn_engine, key);
-    auto dst_memory_p =
-        handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
-    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
-        diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
-    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
-        diff_softmax_md, to_void_cast<T>(diff_src_ptr));
+    SoftmaxMKLDNNHandler<T> handler(softmax_tz, mkldnn::memory::format::nc,
+                                    mkldnn::memory::format::nc, dev_ctx,
+                                    mkldnn_engine, key);
+
+    auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
+    auto diff_dst_memory_p =
+        handler.AcquireDiffDstMemory(to_void_cast<T>(diff_dst_ptr));
+    auto diff_src_memory_p =
+        handler.AcquireDiffSrcMemory(to_void_cast<T>(diff_src_ptr));
+
     // Get primitive from device context
     auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
...
@@ -119,6 +119,25 @@ class MKLDNNHandler {
     return mem_p;
   }
 
+  std::shared_ptr<mkldnn::memory> AcquireMemory(
+      const std::vector<int>& dims, const mkldnn::memory::data_type dtype,
+      const mkldnn::memory::format& fmt, void* ptr, const std::string& suffix) {
+    /*Generate key*/
+    auto local_key = key_ + suffix;
+    auto mem_p =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+    if (mem_p == nullptr) {
+      auto md = mkldnn::memory::desc(dims, dtype, fmt);
+      mem_p = std::make_shared<mkldnn::memory>(
+          mkldnn::memory::primitive_desc{md, engine_}, ptr);
+      dev_ctx_.SetBlob(local_key, mem_p);
+    } else {
+      mem_p->set_data_handle(ptr);
+    }
+    return mem_p;
+  }
+
   std::shared_ptr<mkldnn::memory> AcquireMemory(
       const mkldnn::memory::primitive_desc& mpd, const std::string& suffix) {
     auto local_key = key_ + suffix;
@@ -949,18 +968,7 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
   std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
       const mkldnn::memory::format& fmt, void* ptr) {
-    auto local_key = key_ + "@user_src_mem_p";
-    auto mem_p =
-        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
-    if (mem_p == nullptr) {
-      auto src_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt);
-      mem_p = std::make_shared<mkldnn::memory>(
-          mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
-      dev_ctx_.SetBlob(local_key, mem_p);
-    } else {
-      mem_p->set_data_handle(ptr);
-    }
-    return mem_p;
+    return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p");
   }
 
   std::shared_ptr<mkldnn::memory> AcquireDstMemory(
...
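The new MKLDNNHandler::AcquireMemory overload above is what lets both the softmax handler and ReorderMKLDNNHandler drop their hand-rolled caching: on a cache miss it builds the memory object and stores it under `key_ + suffix`, on a hit it only rebinds the data pointer to the caller's buffer. A minimal sketch of that get-or-create-then-rebind pattern follows, with a hypothetical `FakeMemory` type and a plain map standing in for `mkldnn::memory` and the device-context blob cache, so it shows the caching logic rather than the real MKL-DNN calls.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for mkldnn::memory: it remembers its dims and the
// user buffer it currently wraps (set_data_handle in the real API).
struct FakeMemory {
  std::vector<int> dims;
  void* handle;
  void set_data_handle(void* ptr) { handle = ptr; }
};

class FakeHandler {
 public:
  explicit FakeHandler(std::string base_key) : key_(std::move(base_key)) {}

  // Mirrors the patched AcquireMemory(dims, dtype, fmt, ptr, suffix):
  // reuse the cached memory object when present, otherwise create and cache
  // it; either way the object ends up pointing at the caller's buffer.
  std::shared_ptr<FakeMemory> AcquireMemory(const std::vector<int>& dims,
                                            void* ptr,
                                            const std::string& suffix) {
    const std::string local_key = key_ + suffix;
    auto it = blobs_.find(local_key);
    if (it == blobs_.end()) {
      auto mem = std::make_shared<FakeMemory>(FakeMemory{dims, ptr});
      blobs_[local_key] = mem;
      return mem;
    }
    it->second->set_data_handle(ptr);  // cache hit: only rebind the pointer
    return it->second;
  }

 private:
  std::string key_;
  std::unordered_map<std::string, std::shared_ptr<FakeMemory>> blobs_;
};

int main() {
  FakeHandler handler("softmax-nc-2x3");
  float batch_a[6] = {0}, batch_b[6] = {0};
  auto m1 = handler.AcquireMemory({2, 3}, batch_a, "@user_src_mem_p");
  auto m2 = handler.AcquireMemory({2, 3}, batch_b, "@user_src_mem_p");
  assert(m1 == m2);               // same cached object across iterations
  assert(m2->handle == batch_b);  // but rebound to the new input buffer
  return 0;
}
```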