提交 c981222b 编写于 作者: J Jacek Czaja

- Conv MKLDNN grad op reuse of mkldnn primitives

上级 f0cd493c
......@@ -18,9 +18,6 @@
namespace paddle {
namespace operators {
using conv_bwd_data = mkldnn::convolution_backward_data;
using conv_bwd_weights = mkldnn::convolution_backward_weights;
using conv_fwd = mkldnn::convolution_forward;
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
......@@ -39,6 +36,72 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
conv_pd_ = conv_pd;
}
ConvMKLDNNHandler(
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
conv_bwd_data_pd,
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
conv_bwd_weights_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
conv_pd_(conv_pd),
conv_bwd_weights_pd_(conv_bwd_weights_pd),
conv_bwd_data_pd_(conv_bwd_data_pd) {
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_ += "-BWD";
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) {
auto src_pd = conv_bwd_weights_pd_->src_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p,
"@weights-src_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) {
auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@weights-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_primitive_desc(), ptr,
"@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) {
auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@data-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) {
auto weights_pd = conv_bwd_data_pd_->weights_primitive_desc();
auto user_pd = user_weights_memory_p->get_primitive_desc();
return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
"@data-weights_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_data_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr,
"@dst_mem_p");
......@@ -68,7 +131,6 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto prim_desc_key = key_ + "@conv_pd";
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
......@@ -85,6 +147,54 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
return conv_p;
}
std::shared_ptr<mkldnn::convolution_backward_weights>
AcquireConvolutionBackwardWeights(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> diff_weights_memory_p) {
auto prim_key = key_ + "@conv_bwd_weights_p";
auto conv_bwd_weights_p =
std::static_pointer_cast<mkldnn::convolution_backward_weights>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_weights_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd weights primitive in device context");
if (conv_bwd_weights_p == nullptr) {
// create backward conv primitive for weights
conv_bwd_weights_p =
std::make_shared<mkldnn::convolution_backward_weights>(
*conv_bwd_weights_pd_, *src_memory_p, *diff_dst_memory_p,
*diff_weights_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
} else {
is_reusing_ = true;
}
return conv_bwd_weights_p;
}
std::shared_ptr<mkldnn::convolution_backward_data>
AcquireConvolutionBackwardData(
std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
auto prim_key = key_ + "@conv_bwd_data_p";
auto conv_bwd_data_p =
std::static_pointer_cast<mkldnn::convolution_backward_data>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE(
(conv_bwd_data_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution bwd data primitive in device context");
if (conv_bwd_data_p == nullptr) {
conv_bwd_data_p = std::make_shared<mkldnn::convolution_backward_data>(
*conv_bwd_data_pd_, *diff_dst_memory_p, *weights_memory_p,
*diff_src_memory_p);
dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
} else {
is_reusing_ = true;
}
return conv_bwd_data_p;
}
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string GetHash(memory::dims& input_dims,
......@@ -100,6 +210,10 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
private:
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd_;
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
conv_bwd_weights_pd_;
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
conv_bwd_data_pd_;
};
template <typename T>
......@@ -174,8 +288,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<conv_fwd::primitive_desc> conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine);
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -208,21 +323,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
private:
std::unique_ptr<conv_fwd::primitive_desc> ConvFwdPrimitiveDesc(
const memory::desc& src, const memory::desc& weights,
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings, const mkldnn::engine& engine) const {
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc =
conv_fwd::desc(mkldnn::prop_kind::forward, mkldnn::convolution_direct,
src, weights, dst, stride_dims, padding_dims,
padding_dims, mkldnn::padding_kind::zero);
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd = new conv_fwd::primitive_desc(conv_desc, engine);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
return std::unique_ptr<conv_fwd::primitive_desc>(p_conv_pd);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
};
......@@ -290,147 +408,108 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
dilations, groups, ctx.op().Input("Output"));
const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline;
// create mkldnn memory from input tensors (input/weights/output_grad)
auto user_src_memory = memory(
{{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
to_void_cast(input_data));
auto user_weights_memory =
memory({{{weights_tz}, memory::data_type::f32, filter->format()},
mkldnn_engine},
to_void_cast(filter_data));
auto user_diff_dst_memory =
memory({{{dst_tz}, memory::data_type::f32, output_grad->format()},
mkldnn_engine},
to_void_cast(output_grad_data));
// Create user memory descriptors
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
auto user_diff_dst_md = platform::MKLDNNMemDesc(
{dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
/* create memory descriptor for conv backward without specified format
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
memory::format::any);
auto diff_src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
memory::format::any);
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
auto diff_src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::f32, memory::format::any);
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::f32, memory::format::any);
auto diff_dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32,
memory::format::any);
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
// Retrieve conv_pd from device context
auto conv_pd = std::static_pointer_cast<conv_fwd::primitive_desc>(
auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
PADDLE_ENFORCE(conv_pd != nullptr,
"Fail to find conv_pd in device context");
// create backward conv primitive for weights
if (filter_grad) {
// create backward convolution primitive descriptor
auto conv_bwd_weights_desc = conv_bwd_weights::desc(
// create backward convolution weights primitive descriptor
auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md,
strides, paddings, paddings, mkldnn::padding_kind::zero);
auto conv_bwd_weights_pd = conv_bwd_weights::primitive_desc(
auto conv_bwd_weights_pd =
std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
conv_bwd_weights_desc, mkldnn_engine, *conv_pd);
// create reorder primitive if the input format is not the preferred one
auto src_memory = user_src_memory;
primitive reorder_src;
bool is_src_reordered = false;
if (memory::primitive_desc(conv_bwd_weights_pd.src_primitive_desc()) !=
user_src_memory.get_primitive_desc()) {
src_memory = memory(conv_bwd_weights_pd.src_primitive_desc());
reorder_src = reorder(user_src_memory, src_memory);
is_src_reordered = true;
}
// create backward convolution data primitive descriptor
auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
strides, paddings, paddings, mkldnn::padding_kind::zero);
auto conv_bwd_data_pd =
std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
conv_bwd_data_desc, mkldnn_engine, *conv_pd);
auto diff_dst_memory_4filter = user_diff_dst_memory;
primitive reorder_diff_dst_4filter;
bool is_diff_dst_reordered_4filter = false;
if (memory::primitive_desc(
conv_bwd_weights_pd.diff_dst_primitive_desc()) !=
user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory_4filter =
memory(conv_bwd_weights_pd.diff_dst_primitive_desc());
reorder_diff_dst_4filter =
reorder(user_diff_dst_memory, diff_dst_memory_4filter);
is_diff_dst_reordered_4filter = true;
}
ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd, conv_bwd_weights_pd,
dev_ctx, mkldnn_engine, key);
// create mkldnn memory for output (i.e. diff weights)
auto diff_weights_memory =
memory(conv_bwd_weights_pd.diff_weights_primitive_desc(),
reinterpret_cast<void*>(filter_grad_data));
// create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
user_diff_dst_md, to_void_cast<T>(output_grad_data));
// create backward conv primitive for weights
auto conv_bwd_weights_prim =
conv_bwd_weights(conv_bwd_weights_pd, src_memory,
diff_dst_memory_4filter, diff_weights_memory);
if (filter_grad) {
auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
user_src_memory_p, pipeline);
// push primitive and execute it
std::vector<primitive> pipeline;
if (is_src_reordered) pipeline.push_back(reorder_src);
if (is_diff_dst_reordered_4filter)
pipeline.push_back(reorder_diff_dst_4filter);
pipeline.push_back(conv_bwd_weights_prim);
stream(stream::kind::eager).submit(pipeline).wait();
auto diff_dst_memory_4filter_p =
handler.AcquireDiffDstMemoryFromWeightsPrimitive(
user_diff_dst_memory_p, pipeline);
auto diff_weights_memory_p =
handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
reinterpret_cast<void*>(filter_grad_data));
auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights(
src_memory_p, diff_dst_memory_4filter_p, diff_weights_memory_p);
// push primitive to stream and wait until it's executed
pipeline.push_back(*conv_bwd_weights_p);
filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_format(GetMKLDNNFormat(diff_weights_memory));
filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p));
}
if (input_grad) {
// create backward convolution primitive descriptor
auto conv_bwd_data_desc = conv_bwd_data::desc(
mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
strides, paddings, paddings, mkldnn::padding_kind::zero);
auto conv_bwd_data_pd = conv_bwd_data::primitive_desc(
conv_bwd_data_desc, mkldnn_engine, *conv_pd);
// create reorder primitive if the input format is not the preferred one
auto weights_memory = user_weights_memory;
primitive reorder_weights;
bool is_weights_reordered = false;
if (memory::primitive_desc(conv_bwd_data_pd.weights_primitive_desc()) !=
user_weights_memory.get_primitive_desc()) {
weights_memory = memory(conv_bwd_data_pd.weights_primitive_desc());
reorder_weights = reorder(user_weights_memory, weights_memory);
is_weights_reordered = true;
}
auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
user_weights_memory_p, pipeline);
auto diff_dst_memory_4data = user_diff_dst_memory;
primitive reorder_diff_dst_4data;
bool is_diff_dst_reordered_4data = false;
if (memory::primitive_desc(conv_bwd_data_pd.diff_dst_primitive_desc()) !=
user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory_4data =
memory(conv_bwd_data_pd.diff_dst_primitive_desc());
reorder_diff_dst_4data =
reorder(user_diff_dst_memory, diff_dst_memory_4data);
is_diff_dst_reordered_4data = true;
}
auto diff_dst_memory_4data_p =
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
pipeline);
// create mkldnn memory for output (i.e. diff src)
auto diff_src_memory = memory(conv_bwd_data_pd.diff_src_primitive_desc(),
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
reinterpret_cast<void*>(input_grad_data));
// create backward conv primitive for data
auto conv_bwd_data_prim =
conv_bwd_data(conv_bwd_data_pd, diff_dst_memory_4data, weights_memory,
diff_src_memory);
auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData(
diff_dst_memory_4data_p, weights_memory_p, diff_src_memory_p);
// push primitive and execute it
std::vector<primitive> pipeline;
if (is_weights_reordered) pipeline.push_back(reorder_weights);
if (is_diff_dst_reordered_4data)
pipeline.push_back(reorder_diff_dst_4data);
pipeline.push_back(conv_bwd_data_prim);
stream(stream::kind::eager).submit(pipeline).wait();
pipeline.push_back(*conv_bwd_data_p);
input_grad->set_layout(DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(diff_src_memory));
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}
stream(stream::kind::eager).submit(pipeline).wait();
} // Compute()
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部