Commit 47f670d5 authored by Jacek Czaja, committed by Tao Luo

- Softmax mkl-dnn refactoring (#19615)

test=develop

- Cosmetic fixes

test=develop
Parent: a65c728e
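In short: SoftmaxMKLDNNHandler now creates its forward (and, for the grad op, backward) softmax primitive descriptors in its constructors and hands out MKL-DNN memory objects directly from Paddle Tensors, so the kernels no longer flatten tensors into temporaries, pass raw void* buffers, or build cache keys by hand. A condensed view of the resulting forward-kernel call pattern, taken from the new side of this diff (handler and Acquire* names are as they appear below; the pipeline/stream lines are paraphrased surrounding context, and the snippet is not standalone code — it needs the PaddlePaddle and MKL-DNN headers used by this file):

    SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
                                    ctx.GetPlace(), ctx.op().Output("Out"));
    // Tensors go in directly; the handler maps/allocates memory via its forward PD.
    auto softmax_src_memory_p = handler.AcquireSrcMemory(input);   // const Tensor*
    auto softmax_dst_memory_p = handler.AcquireDstMemory(output);  // Tensor*, allocated on place_
    auto softmax_p = handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
    std::vector<mkldnn::primitive> pipeline{*softmax_p};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();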
@@ -38,52 +38,69 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
                        const MKLDNNMemoryFormat fmt,
                        const platform::MKLDNNDeviceContext& dev_ctx,
-                       mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        dims_(dims),
-        fmt_(fmt) {}
+                       platform::Place cpu_place, const std::string& uniq_name)
+      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
+                                platform::GetHash(dims, uniq_name)),
+        place_(cpu_place),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
+    this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
+  }
 
   SoftmaxMKLDNNHandler(const std::vector<int>& dims,
                        const MKLDNNMemoryFormat fmt,
                        const MKLDNNMemoryFormat diff_fmt,
                        const platform::MKLDNNDeviceContext& dev_ctx,
-                       mkldnn::engine engine, const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        dims_(dims),
-        fmt_(fmt),
-        diff_fmt_(diff_fmt) {
+                       platform::Place cpu_place, const std::string& uniq_name)
+      : platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
+                                platform::GetHash(dims, uniq_name)),
+        place_(cpu_place),
+        fwd_pd_(nullptr),
+        bwd_pd_(nullptr) {
     // If we are in Grad operatgor then update a key with BWD suffix to
     // distinguish from FWD memory primitives
     // Key_common will allow to access FWD_PD from cache
-    key_ += "-BWD";
+    this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
+    this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt);
   }
 
   // TODO(jczaja): Once fwd_pd_ are moved to MKLDNNHandler then this function
   // should be moved as well eg. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(void* ptr) {
-    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
-                               ptr, "@user_src_mem_p");
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
+                                            to_void_cast<T>(input_data),
+                                            "@src_mem_p");
   }
 
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void* ptr) {
-    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
-                               ptr, "@user_dst_mem_p");
+  // TODO(jczaja): Move to MKLDNNHandler as common code
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
+    T* ptr = output->mutable_data<T>(place_,
+                                     fwd_pd_->dst_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
   }
 
-  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(void* ptr) {
-    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
-                               diff_fmt_, ptr, "@user_diff_dst_mem_p");
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(const Tensor* output) {
+    const T* output_data = output->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
+                                            to_void_cast<T>(output_data),
+                                            "@bwd-dst_mem_p");
   }
 
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(void* ptr) {
-    return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
-                               diff_fmt_, ptr, "@user_diff_src_mem_p");
+  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(const Tensor* diffdst) {
+    const T* ptr = diffdst->data<T>();
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
+                                            to_void_cast<T>(ptr),
+                                            "@diff_dst_mem_p");
   }
 
-  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
-    this->AcquireSoftmaxPrimitiveDescriptor();
-    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
-                                            "@dst_mem_p");
+  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
+      framework::Tensor* diffsrc) {
+    T* ptr = diffsrc->mutable_data<T>(
+        place_, bwd_pd_->diff_src_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
+                                            ptr, "@diff_src_mem_p");
   }
 
   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
@@ -95,7 +112,6 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_p == nullptr) {
-      this->AcquireSoftmaxPrimitiveDescriptor();
       softmax_p = std::make_shared<mkldnn::softmax_forward>(
           *fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
           *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
@@ -113,20 +129,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
         dev_ctx_.GetBlob(prim_key));
     if (softmax_bwd_p == nullptr) {
-      auto data_softmax_md =
-          mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
-      auto diff_softmax_md = mkldnn::memory::desc(
-          dims_, platform::MKLDNNGetDataType<T>(), diff_fmt_);
-      // TODO(jczaja): Add support for other axes
-      auto softmax_bwd_desc = softmax_backward::desc(
-          diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
-      this->AcquireSoftmaxPrimitiveDescriptor();
-      auto softmax_bwd_pd = mkldnn::softmax_backward::primitive_desc(
-          softmax_bwd_desc, engine_, *fwd_pd_);
       softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
-          softmax_bwd_pd, *dst_memory_p, *diff_dst_memory_p,
-          *diff_src_memory_p);
+          *bwd_pd_, *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
       dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
     }
@@ -134,7 +138,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
   }
 
  protected:
-  void AcquireSoftmaxPrimitiveDescriptor(void) {
+  void AcquireSoftmaxPrimitiveDescriptor(const std::vector<int>& dims,
+                                         const mkldnn::memory::format fmt) {
     // Softmax PD has to be passed to Grad op that
     // may be executed by diffrent thread, hence
     // for that one we use key that does not contain TID
@@ -153,7 +158,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
       // forward_training
       // Normalization is made after innermost dimension eg. C out of NC
       auto md =
-          mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
+          mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
       auto softmax_desc =
           softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
       fwd_pd_.reset(
@@ -163,11 +168,33 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     }
   }
 
+  void AcquireSoftmaxBackwardPrimitiveDescriptor(
+      const std::vector<int>& dims, const mkldnn::memory::format fmt,
+      const mkldnn::memory::format diff_fmt) {
+    // Fwd_PD_ has to exists when to create BWD_PD_
+    PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
+    const std::string key_bwd_pd = key_ + "@softmax_bwd_pd";
+    bwd_pd_ =
+        std::static_pointer_cast<mkldnn::softmax_backward::primitive_desc>(
+            dev_ctx_.GetBlob(key_bwd_pd));
+    if (bwd_pd_ == nullptr) {
+      auto data_softmax_md =
+          mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+      auto diff_softmax_md = mkldnn::memory::desc(
+          dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
+      // TODO(jczaja): Add support for other axes
+      auto backward_desc = softmax_backward::desc(
+          diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
+      bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
+          backward_desc, engine_, *fwd_pd_));
+      dev_ctx_.SetBlob(key_bwd_pd, bwd_pd_);
+    }
+  }
+
  private:
-  std::vector<int> dims_;
-  MKLDNNMemoryFormat fmt_;
-  MKLDNNMemoryFormat diff_fmt_;
+  platform::Place place_;
   std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
+  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> bwd_pd_;
 };
 
 template <typename T>
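The new AcquireSoftmaxBackwardPrimitiveDescriptor above uses the same cache-or-create idiom as the rest of the handler: look the descriptor up in the device context's string-keyed blob cache and construct (and store) it only on a miss, so anything that derives the same key reuses one object. A minimal standalone sketch of that idiom follows; it is an illustration only, with a plain std::map, the AcquireCached name, and an int payload standing in for the Paddle/MKL-DNN types (none of these are Paddle APIs):

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    // Stand-in for the dev_ctx_ blob map: string keys to type-erased shared_ptrs.
    static std::map<std::string, std::shared_ptr<void>> blob_cache;

    // Return the cached object for `key`, creating it with `make` only on a miss.
    template <typename T, typename Factory>
    std::shared_ptr<T> AcquireCached(const std::string& key, Factory make) {
      auto it = blob_cache.find(key);
      if (it != blob_cache.end()) {
        return std::static_pointer_cast<T>(it->second);  // hit: reuse
      }
      std::shared_ptr<T> obj = make();                    // miss: build once
      blob_cache[key] = obj;
      return obj;
    }

    int main() {
      auto first = AcquireCached<int>("softmax_bwd_pd",
                                      [] { return std::make_shared<int>(42); });
      auto second = AcquireCached<int>("softmax_bwd_pd",
                                       [] { return std::make_shared<int>(0); });
      std::cout << (first == second) << "\n";  // prints 1: the cached object was reused
      return 0;
    }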
@@ -177,44 +204,25 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    auto mkldnn_engine = dev_ctx.GetEngine();
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
     PADDLE_ENFORCE_EQ(
         input->dims(), output->dims(),
         "The shape of softmax's input and output must be identical.");
 
-    // make sure 'output' holds memory, which will be shared by
-    // 'flattened_output' later.
-    output->mutable_data<T>(ctx.GetPlace());
-
     // flatten input and output to 2-D matrixs
     auto dims = input->dims();  // input and output share the same shape
     auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::Tensor flattened_input;
-    framework::Tensor flattened_output;
-    flattened_input.ShareDataWith(*input).Resize(flattened_dims);
-    flattened_output.ShareDataWith(*output).Resize(flattened_dims);
-
-    const T* input_data = flattened_input.data<T>();
-    T* output_data = flattened_output.mutable_data<T>(ctx.GetPlace());
-
     auto src_tz = paddle::framework::vectorize<int>(flattened_dims);
     auto dst_tz = src_tz;
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
-    // Generate keys for storing/retriving primitives for this operator
-    const std::string key =
-        platform::GetHash(softmax_tz, ctx.op().Output("Out"));
 
     SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
-                                    mkldnn_engine, key);
+                                    ctx.GetPlace(), ctx.op().Output("Out"));
     // Currently only NC data format is supported
-    auto softmax_src_memory_p =
-        handler.AcquireSrcMemory(to_void_cast<T>(input_data));
-    auto softmax_dst_memory_p =
-        handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
+    auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
+    auto softmax_dst_memory_p = handler.AcquireDstMemory(output);
     auto softmax_p =
         handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
@@ -222,6 +230,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
         *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
     stream(stream::kind::eager).submit(pipeline).wait();
 
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
     const bool is_test = ctx.Attr<bool>("is_test");
     if (!is_test) {
       T threshold = exp(-64);
@@ -230,6 +239,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
             output_data[i] < threshold ? threshold : output_data[i];
       }
     }
+
+    output->set_layout(framework::DataLayout::kMKLDNN);
+    // Softmax output format is the same as input one
+    output->set_format(input->format());
   }
 };
@@ -241,7 +254,6 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
                    "It must use CPUPlace.");
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    auto mkldnn_engine = dev_ctx.GetEngine();
     const Tensor* output = ctx.Input<Tensor>("Out");
     auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
     auto* dx =
@@ -251,52 +263,25 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
         dout->dims(), dx->dims(),
         "The shape of softmax_grad's input and output must be identical.");
 
-    // make sure 'dx' holds memory, which will be shared by 'flattened_dx'
-    // later.
-    dx->template mutable_data<T>(ctx.GetPlace());
-
     auto dims = dout->dims();  // input and output share the same shape
     auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::Tensor flattened_output;
-    framework::Tensor flattened_dout;
-    framework::Tensor flattened_dx;
-    flattened_output.ShareDataWith(*output).Resize(flattened_dims);
-    flattened_dout.ShareDataWith(*dout).Resize(flattened_dims);
-    flattened_dx.ShareDataWith(*dx).Resize(flattened_dims);
-
-    const T* dst_data = flattened_output.data<T>();
-    const T* diff_dst_ptr = flattened_dout.template data<T>();
-    T* diff_src_ptr = flattened_dx.template mutable_data<T>(ctx.GetPlace());
-
-    auto dst_tz = paddle::framework::vectorize<int>(flattened_dims);
-    auto src_tz(dst_tz);
+    std::vector<int> dst_tz = paddle::framework::vectorize<int>(flattened_dims);
+    std::vector<int> src_tz(dst_tz);
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
-    // Currently only supports NC data format
-    // retrieve eltwise primitive desc from device context
-    const std::string key =
-        platform::GetHash(softmax_tz, ctx.op().Input("Out"));
-    const std::string key_softmax_pd = key + "@softmax_pd";
-
-    auto softmax_pd =
-        std::static_pointer_cast<mkldnn::softmax_forward::primitive_desc>(
-            dev_ctx.GetBlob(key_softmax_pd));
-    PADDLE_ENFORCE(softmax_pd != nullptr,
-                   "Fail to find softmax_pd in device context");
 
     // TODO(jczaja): Add layouts support when there is a need to do so
     // Two dimensional softmax does support NC format
     // Normalization is made after innermost dimension eg. C out of NC
     SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc,
                                     MKLDNNMemoryFormat::nc, dev_ctx,
-                                    mkldnn_engine, key);
-    auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
-    auto diff_dst_memory_p =
-        handler.AcquireDiffDstMemory(to_void_cast<T>(diff_dst_ptr));
-    auto diff_src_memory_p =
-        handler.AcquireDiffSrcMemory(to_void_cast<T>(diff_src_ptr));
+                                    ctx.GetPlace(), ctx.op().Input("Out"));
+    auto dst_memory_p = handler.AcquireDstMemory(output);
+    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
 
     // Get primitve from device context
     auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
......
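The grad kernel follows the same pattern as the forward one: the handler is constructed with both the data and the diff memory formats, and the Out, dOut and dX tensors are handed straight to the Acquire* helpers. Condensed from the new side of this diff (same caveats as the forward sketch above):

    SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc,
                                    MKLDNNMemoryFormat::nc, dev_ctx,
                                    ctx.GetPlace(), ctx.op().Input("Out"));
    auto dst_memory_p = handler.AcquireDstMemory(output);         // forward result "Out"
    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);  // gradient w.r.t. "Out"
    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);    // gradient w.r.t. "X" (written)

The backward primitive is then acquired from the handler and submitted to an MKL-DNN stream; that part of the hunk is truncated in this excerpt.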