Commit 47f670d5 authored by Jacek Czaja, committed by Tao Luo

- Softmax mkl-dnn refactoring (#19615)

test=develop

- Cosmetic fixes

test=develop
Parent a65c728e
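
At its core, this refactoring moves primitive-descriptor creation into the
SoftmaxMKLDNNHandler constructor and lets the handler acquire MKL-DNN memories
directly from Tensors instead of raw void* pointers. A minimal sketch of the
resulting forward call pattern, assembled from the hunks below (the surrounding
kernel boilerplate is assumed here, not quoted from the commit):

// Inside SoftmaxMKLDNNKernel<T>::Compute(), after flattening input/output to 2-D:
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
                                ctx.GetPlace(), ctx.op().Output("Out"));
// The constructor creates (or fetches from the blob cache) fwd_pd_, so
// memories can be acquired straight from the primitive descriptor:
auto src_memory_p = handler.AcquireSrcMemory(input);   // const Tensor*
auto dst_memory_p = handler.AcquireDstMemory(output);  // framework::Tensor*
auto softmax_p = handler.AcquireSoftmax(dst_memory_p, src_memory_p);
std::vector<mkldnn::primitive> pipeline{*softmax_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();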
@@ -38,52 +38,69 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
fmt_(fmt) {}
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::GetHash(dims, uniq_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
}
SoftmaxMKLDNNHandler(const std::vector<int>& dims,
const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
fmt_(fmt),
diff_fmt_(diff_fmt) {
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandler(dev_ctx, dev_ctx.GetEngine(),
platform::GetHash(dims, uniq_name)),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
// If we are in a Grad operator then update the key with a BWD suffix to
// distinguish it from FWD memory primitives.
// Key_common will allow access to the FWD_PD from the cache
key_ += "-BWD";
this->AcquireSoftmaxPrimitiveDescriptor(dims, fmt);
this->AcquireSoftmaxBackwardPrimitiveDescriptor(dims, fmt, diff_fmt);
}
// TODO(jczaja): Once fwd_pd_ is moved to MKLDNNHandler then this function
// should be moved as well, e.g. SoftmaxMKLDNNHandler -> MKLDNNHandler<softmax_>
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
ptr, "@user_src_mem_p");
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_primitive_desc(),
to_void_cast<T>(input_data),
"@src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(), fmt_,
ptr, "@user_dst_mem_p");
// TODO(jczaja): Move to MKLDNNHandler as common code
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T* ptr = output->mutable_data<T>(place_,
fwd_pd_->dst_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
diff_fmt_, ptr, "@user_diff_dst_mem_p");
std::shared_ptr<mkldnn::memory> AcquireDstMemory(const Tensor* output) {
const T* output_data = output->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_primitive_desc(),
to_void_cast<T>(output_data),
"@bwd-dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(void* ptr) {
return this->AcquireMemory(dims_, platform::MKLDNNGetDataType<T>(),
diff_fmt_, ptr, "@user_diff_src_mem_p");
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(const Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_primitive_desc(),
to_void_cast<T>(ptr),
"@diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
this->AcquireSoftmaxPrimitiveDescriptor();
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr = diffsrc->mutable_data<T>(
place_, bwd_pd_->diff_src_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
@@ -95,7 +112,6 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
dev_ctx_.GetBlob(prim_key));
if (softmax_p == nullptr) {
this->AcquireSoftmaxPrimitiveDescriptor();
softmax_p = std::make_shared<mkldnn::softmax_forward>(
*fwd_pd_, *(static_cast<mkldnn::memory*>(src_memory_p.get())),
*(static_cast<mkldnn::memory*>(dst_memory_p.get())));
@@ -113,20 +129,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
dev_ctx_.GetBlob(prim_key));
if (softmax_bwd_p == nullptr) {
auto data_softmax_md =
mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
auto diff_softmax_md = mkldnn::memory::desc(
dims_, platform::MKLDNNGetDataType<T>(), diff_fmt_);
// TODO(jczaja): Add support for other axes
auto softmax_bwd_desc = softmax_backward::desc(
diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
this->AcquireSoftmaxPrimitiveDescriptor();
auto softmax_bwd_pd = mkldnn::softmax_backward::primitive_desc(
softmax_bwd_desc, engine_, *fwd_pd_);
softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
softmax_bwd_pd, *dst_memory_p, *diff_dst_memory_p,
*diff_src_memory_p);
*bwd_pd_, *dst_memory_p, *diff_dst_memory_p, *diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
}
@@ -134,7 +138,8 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
}
protected:
void AcquireSoftmaxPrimitiveDescriptor(void) {
void AcquireSoftmaxPrimitiveDescriptor(const std::vector<int>& dims,
const mkldnn::memory::format fmt) {
// Softmax PD has to be passed to the Grad op, which
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
@@ -153,7 +158,7 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
// forward_training
// Normalization is performed over the innermost dimension, e.g. C out of NC
auto md =
mkldnn::memory::desc(dims_, platform::MKLDNNGetDataType<T>(), fmt_);
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto softmax_desc =
softmax_forward::desc(prop_kind::forward_scoring, md, 1 /*dim: C*/);
fwd_pd_.reset(
@@ -163,11 +168,33 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
}
}
void AcquireSoftmaxBackwardPrimitiveDescriptor(
const std::vector<int>& dims, const mkldnn::memory::format fmt,
const mkldnn::memory::format diff_fmt) {
// Fwd_PD_ has to exist when creating BWD_PD_
PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
const std::string key_bwd_pd = key_ + "@softmax_bwd_pd";
bwd_pd_ =
std::static_pointer_cast<mkldnn::softmax_backward::primitive_desc>(
dev_ctx_.GetBlob(key_bwd_pd));
if (bwd_pd_ == nullptr) {
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md = mkldnn::memory::desc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
// TODO(jczaja): Add support for other axes
auto backward_desc = softmax_backward::desc(
diff_softmax_md, data_softmax_md, 1 /* dim: C*/);
bwd_pd_.reset(new mkldnn::softmax_backward::primitive_desc(
backward_desc, engine_, *fwd_pd_));
dev_ctx_.SetBlob(key_bwd_pd, bwd_pd_);
}
}
private:
std::vector<int> dims_;
MKLDNNMemoryFormat fmt_;
MKLDNNMemoryFormat diff_fmt_;
platform::Place place_;
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::softmax_backward::primitive_desc> bwd_pd_;
};
template <typename T>
@@ -177,44 +204,25 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
"The shape of softmax's input and output must be identical.");
// make sure 'output' holds memory, which will be shared by
// 'flattened_output' later.
output->mutable_data<T>(ctx.GetPlace());
// flatten input and output to 2-D matrices
auto dims = input->dims(); // input and output share the same shape
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
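// e.g. dims [2, 3, 4] flatten to [2 * 3, 4] = [6, 4], so softmax runs over
// the innermost axis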
framework::Tensor flattened_input;
framework::Tensor flattened_output;
flattened_input.ShareDataWith(*input).Resize(flattened_dims);
flattened_output.ShareDataWith(*output).Resize(flattened_dims);
const T* input_data = flattened_input.data<T>();
T* output_data = flattened_output.mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize<int>(flattened_dims);
auto dst_tz = src_tz;
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Generate keys for storing/retrieving primitives for this operator
const std::string key =
platform::GetHash(softmax_tz, ctx.op().Output("Out"));
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
mkldnn_engine, key);
ctx.GetPlace(), ctx.op().Output("Out"));
// Currently only NC data format is supported
auto softmax_src_memory_p =
handler.AcquireSrcMemory(to_void_cast<T>(input_data));
auto softmax_dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
auto softmax_dst_memory_p = handler.AcquireDstMemory(output);
auto softmax_p =
handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);
@@ -222,6 +230,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
stream(stream::kind::eager).submit(pipeline).wait();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const bool is_test = ctx.Attr<bool>("is_test");
if (!is_test) {
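// Clamp tiny outputs up to exp(-64), a small positive floor; presumably this
// keeps a downstream log() (e.g. in cross-entropy) away from exact zeros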
T threshold = exp(-64);
@@ -230,6 +239,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
output_data[i] < threshold ? threshold : output_data[i];
}
}
output->set_layout(framework::DataLayout::kMKLDNN);
// Softmax output format is the same as the input one
output->set_format(input->format());
}
};
@@ -241,7 +254,6 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* output = ctx.Input<Tensor>("Out");
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* dx =
@@ -251,52 +263,25 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
dout->dims(), dx->dims(),
"The shape of softmax_grad's input and output must be identical.");
// make sure 'dx' holds memory, which will be shared by 'flattened_dx'
// later.
dx->template mutable_data<T>(ctx.GetPlace());
auto dims = dout->dims(); // input and output share the same shape
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::Tensor flattened_output;
framework::Tensor flattened_dout;
framework::Tensor flattened_dx;
flattened_output.ShareDataWith(*output).Resize(flattened_dims);
flattened_dout.ShareDataWith(*dout).Resize(flattened_dims);
flattened_dx.ShareDataWith(*dx).Resize(flattened_dims);
const T* dst_data = flattened_output.data<T>();
const T* diff_dst_ptr = flattened_dout.template data<T>();
T* diff_src_ptr = flattened_dx.template mutable_data<T>(ctx.GetPlace());
auto dst_tz = paddle::framework::vectorize<int>(flattened_dims);
auto src_tz(dst_tz);
std::vector<int> dst_tz = paddle::framework::vectorize<int>(flattened_dims);
std::vector<int> src_tz(dst_tz);
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Currently only supports NC data format
// retrieve softmax primitive desc from device context
const std::string key =
platform::GetHash(softmax_tz, ctx.op().Input("Out"));
const std::string key_softmax_pd = key + "@softmax_pd";
auto softmax_pd =
std::static_pointer_cast<mkldnn::softmax_forward::primitive_desc>(
dev_ctx.GetBlob(key_softmax_pd));
PADDLE_ENFORCE(softmax_pd != nullptr,
"Fail to find softmax_pd in device context");
// TODO(jczaja): Add layout support when there is a need to do so
// Two dimensional softmax does support NC format
// Normalization is performed over the innermost dimension, e.g. C out of NC
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc,
MKLDNNMemoryFormat::nc, dev_ctx,
mkldnn_engine, key);
ctx.GetPlace(), ctx.op().Input("Out"));
auto dst_memory_p = handler.AcquireDstMemory(to_void_cast<T>(dst_data));
auto diff_dst_memory_p =
handler.AcquireDiffDstMemory(to_void_cast<T>(diff_dst_ptr));
auto diff_src_memory_p =
handler.AcquireDiffSrcMemory(to_void_cast<T>(diff_src_ptr));
auto dst_memory_p = handler.AcquireDstMemory(output);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
// Get primitive from device context
auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
......
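The hunk is truncated here. For illustration only, executing the acquired
backward primitive would typically mirror the forward path's MKL-DNN 0.x
pattern; the following lines are an assumed sketch (argument order inferred
from AcquireSoftmaxBackward above), not the committed code:

auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
    dst_memory_p, diff_dst_memory_p, diff_src_memory_p);
std::vector<mkldnn::primitive> pipeline{*softmax_bwd_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();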