提交 9908d3cf 编写于 作者: M mozga-intel

MKLDNN layout: Support for convolution operator

上级 b7c683b8
...@@ -18,6 +18,17 @@ ...@@ -18,6 +18,17 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using conv_bwd_data = mkldnn::convolution_backward_data;
using conv_bwd_weights = mkldnn::convolution_backward_weights;
using conv_fwd = mkldnn::convolution_forward;
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
template <typename T> template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -25,6 +36,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -25,6 +36,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
// Get unique name for index
const std::string key = ctx.op().Output("Output");
const std::string key_conv_pd = key + "@conv_pd";
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -33,10 +48,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -33,10 +48,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* filter = ctx.Input<Tensor>("Filter"); auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
// Get an unique name from "argument" name of "Output" variable PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
// This name will be used as key when saving info into device context input->format() != memory::format::format_undef,
const std::string key = ctx.op().Output("Output"); "Wrong layout/format set for Input tensor");
const std::string key_conv_pd = key + "@conv_pd"; PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
...@@ -63,60 +80,86 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -63,60 +80,86 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::vectorize2int(filter->dims()); paddle::framework::vectorize2int(filter->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// TODO(pzelazko-intel): support more formats // create mkldnn memory from input tensors (data/weights)
auto src_md = platform::MKLDNNMemDesc( auto user_src_memory = memory(
src_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); {{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
auto weights_md = to_void_cast(input_data));
platform::MKLDNNMemDesc(weights_tz, mkldnn::memory::data_type::f32, auto user_weights_memory =
mkldnn::memory::format::oihw); memory({{{weights_tz}, memory::data_type::f32, filter->format()},
auto dst_md = platform::MKLDNNMemDesc( mkldnn_engine},
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); to_void_cast(filter_data));
auto src_memory = /* create memory descriptor for convolution without specified format
mkldnn::memory({src_md, mkldnn_engine}, * ('any') which lets a primitive (convolution in this case) choose
reinterpret_cast<void*>(const_cast<T*>(input_data))); * the memory format preferred for best performance
auto weights_memory = */
mkldnn::memory({weights_md, mkldnn_engine}, auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
reinterpret_cast<void*>(const_cast<T*>(filter_data))); memory::format::any);
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data); auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::f32, memory::format::any);
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd = auto dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32,
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, memory::format::any);
mkldnn_engine);
// create a conv primitive descriptor and save it for usage in backward
// save conv_pd into global device context to be referred in backward path std::shared_ptr<conv_fwd::primitive_desc> conv_pd = ConvFwdPrimitiveDesc(
dev_ctx.SetBlob(key_conv_pd, conv_pd); src_md, weights_md, dst_md, strides, paddings, mkldnn_engine);
// create reorder primitive if the input format is not the preferred one
auto src_memory = user_src_memory;
primitive reorder_src;
bool is_src_reordered = false;
if (memory::primitive_desc(conv_pd->src_primitive_desc()) !=
user_src_memory.get_primitive_desc()) {
src_memory = memory(conv_pd->src_primitive_desc());
reorder_src = reorder(user_src_memory, src_memory);
is_src_reordered = true;
}
auto weights_memory = user_weights_memory;
primitive reorder_weights;
bool is_weights_reordered = false;
if (memory::primitive_desc(conv_pd->weights_primitive_desc()) !=
user_weights_memory.get_primitive_desc()) {
weights_memory = memory(conv_pd->weights_primitive_desc());
reorder_weights = reorder(user_weights_memory, weights_memory);
is_weights_reordered = true;
}
// create memory primitive for conv dst
auto dst_memory = memory(conv_pd->dst_primitive_desc(), output_data);
// create convolution op primitive // create convolution op primitive
auto conv_prim = mkldnn::convolution_forward(*conv_pd, src_memory, auto conv_prim = conv_fwd(*conv_pd, src_memory, weights_memory, dst_memory);
weights_memory, dst_memory);
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{conv_prim}; std::vector<primitive> pipeline;
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); if (is_src_reordered) pipeline.push_back(reorder_src);
if (is_weights_reordered) pipeline.push_back(reorder_weights);
pipeline.push_back(conv_prim);
stream(stream::kind::eager).submit(pipeline).wait();
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory));
} }
private: private:
std::unique_ptr<mkldnn::convolution_forward::primitive_desc> std::unique_ptr<conv_fwd::primitive_desc> ConvFwdPrimitiveDesc(
ConvFwdPrimitiveDesc(const mkldnn::memory::desc& src, const memory::desc& src, const memory::desc& weights,
const mkldnn::memory::desc& weights, const memory::desc& dst, const std::vector<int>& strides,
const mkldnn::memory::desc& dst, const std::vector<int>& paddings, const mkldnn::engine& engine) const {
const std::vector<int>& strides, memory::dims stride_dims = {strides[0], strides[1]};
const std::vector<int>& paddings, memory::dims padding_dims = {paddings[0], paddings[1]};
const mkldnn::engine& engine) const {
mkldnn::memory::dims stride_dims = {strides[0], strides[1]}; auto conv_desc =
mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]}; conv_fwd::desc(mkldnn::prop_kind::forward, mkldnn::convolution_direct,
src, weights, dst, stride_dims, padding_dims,
auto conv_desc = mkldnn::convolution_forward::desc( padding_dims, mkldnn::padding_kind::zero);
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims, auto p_conv_pd = new conv_fwd::primitive_desc(conv_desc, engine);
mkldnn::padding_kind::zero);
return std::unique_ptr<conv_fwd::primitive_desc>(p_conv_pd);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
} }
}; };
...@@ -139,6 +182,19 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -139,6 +182,19 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input")); Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter")); Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
input->format() != memory::format::format_undef,
"Wrong layout/format set for Input tensor");
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(output->layout() == DataLayout::kMKLDNN &&
output->format() != memory::format::format_undef,
"Wrong layout/format set for Output tensor");
PADDLE_ENFORCE(output_grad->layout() == DataLayout::kMKLDNN &&
output_grad->format() != memory::format::format_undef,
"Wrong layout/format set for output_grad tensor");
if (!input_grad && !filter_grad) return; if (!input_grad && !filter_grad) return;
// Get an unique name from "argument" name of "Output" variable // Get an unique name from "argument" name of "Output" variable
...@@ -167,108 +223,147 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -167,108 +223,147 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::vectorize2int(filter->dims()); paddle::framework::vectorize2int(filter->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// TODO(pzelazko-intel): support more formats // create mkldnn memory from input tensors (input/weights/output_grad)
auto src_md = platform::MKLDNNMemDesc( auto user_src_memory = memory(
src_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); {{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
auto diff_src_md = platform::MKLDNNMemDesc( to_void_cast(input_data));
src_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); auto user_weights_memory =
auto weights_md = memory({{{weights_tz}, memory::data_type::f32, filter->format()},
platform::MKLDNNMemDesc(weights_tz, mkldnn::memory::data_type::f32, mkldnn_engine},
mkldnn::memory::format::oihw); to_void_cast(filter_data));
auto diff_weights_md = auto user_diff_dst_memory =
platform::MKLDNNMemDesc(weights_tz, mkldnn::memory::data_type::f32, memory({{{dst_tz}, memory::data_type::f32, output_grad->format()},
mkldnn::memory::format::oihw); mkldnn_engine},
auto diff_dst_md = platform::MKLDNNMemDesc( to_void_cast(output_grad_data));
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
/* create memory descriptor for conv backward without specified format
// create memory * ('any') which lets a primitive (conv backward in this case) choose
auto diff_dst_memory = mkldnn::memory( * the memory format preferred for best performance
{diff_weights_md, mkldnn_engine}, */
reinterpret_cast<void*>(const_cast<T*>(output_grad_data))); auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
memory::format::any);
auto diff_src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
memory::format::any);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::f32, memory::format::any);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::f32, memory::format::any);
auto diff_dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32,
memory::format::any);
// Retrieve conv_pd from device context // Retrieve conv_pd from device context
auto conv_pd = auto conv_pd = std::static_pointer_cast<conv_fwd::primitive_desc>(
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>( dev_ctx.GetBlob(key_conv_pd));
dev_ctx.GetBlob(key_conv_pd));
PADDLE_ENFORCE(conv_pd != nullptr, PADDLE_ENFORCE(conv_pd != nullptr,
"Fail to find conv_pd in device context"); "Fail to find conv_pd in device context");
// create backward conv primitive for weights // create backward conv primitive for weights
if (filter_grad) { if (filter_grad) {
// create primitive descriptor // create backward convolution primitive descriptor
mkldnn::convolution_backward_weights::primitive_desc conv_bwd_weights_pd = auto conv_bwd_weights_desc = conv_bwd_weights::desc(
ConvBwdWeightsPrimitiveDesc(src_md, diff_weights_md, diff_dst_md, mkldnn::convolution_direct, src_md, diff_weights_md, diff_dst_md,
strides, paddings, *conv_pd, strides, paddings, paddings, mkldnn::padding_kind::zero);
mkldnn_engine); auto conv_bwd_weights_pd = conv_bwd_weights::primitive_desc(
conv_bwd_weights_desc, mkldnn_engine, *conv_pd);
// create memory
// create reorder primitive if the input format is not the preferred one
auto src_memory = user_src_memory;
primitive reorder_src;
bool is_src_reordered = false;
if (memory::primitive_desc(conv_bwd_weights_pd.src_primitive_desc()) !=
user_src_memory.get_primitive_desc()) {
src_memory = memory(conv_bwd_weights_pd.src_primitive_desc());
reorder_src = reorder(user_src_memory, src_memory);
is_src_reordered = true;
}
auto diff_dst_memory_4filter = user_diff_dst_memory;
primitive reorder_diff_dst_4filter;
bool is_diff_dst_reordered_4filter = false;
if (memory::primitive_desc(
conv_bwd_weights_pd.diff_dst_primitive_desc()) !=
user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory_4filter =
memory(conv_bwd_weights_pd.diff_dst_primitive_desc());
reorder_diff_dst_4filter =
reorder(user_diff_dst_memory, diff_dst_memory_4filter);
is_diff_dst_reordered_4filter = true;
}
// create mkldnn memory for output (i.e. diff weights)
auto diff_weights_memory = auto diff_weights_memory =
mkldnn::memory({diff_weights_md, mkldnn_engine}, memory(conv_bwd_weights_pd.diff_weights_primitive_desc(),
reinterpret_cast<void*>(filter_grad_data)); reinterpret_cast<void*>(filter_grad_data));
auto src_memory =
mkldnn::memory({src_md, mkldnn_engine},
reinterpret_cast<void*>(const_cast<T*>(input_data)));
// create backward conv primitive for weights // create backward conv primitive for weights
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights( auto conv_bwd_weights_prim =
conv_bwd_weights_pd, src_memory, diff_dst_memory, conv_bwd_weights(conv_bwd_weights_pd, src_memory,
diff_weights_memory); diff_dst_memory_4filter, diff_weights_memory);
// push primitive and execute it // push primitive and execute it
std::vector<mkldnn::primitive> pipeline{conv_bwd_weights_prim}; std::vector<primitive> pipeline;
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); if (is_src_reordered) pipeline.push_back(reorder_src);
if (is_diff_dst_reordered_4filter)
pipeline.push_back(reorder_diff_dst_4filter);
pipeline.push_back(conv_bwd_weights_prim);
stream(stream::kind::eager).submit(pipeline).wait();
filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_format(GetMKLDNNFormat(diff_weights_memory));
} }
if (input_grad) { if (input_grad) {
// create primitive descriptor // create backward convolution primitive descriptor
mkldnn::convolution_backward_data::primitive_desc conv_bwd_data_pd = auto conv_bwd_data_desc = conv_bwd_data::desc(
ConvBwdDataPrimitiveDesc(diff_src_md, weights_md, diff_dst_md, mkldnn::convolution_direct, diff_src_md, weights_md, diff_dst_md,
strides, paddings, *conv_pd, mkldnn_engine); strides, paddings, paddings, mkldnn::padding_kind::zero);
auto conv_bwd_data_pd = conv_bwd_data::primitive_desc(
// create memory conv_bwd_data_desc, mkldnn_engine, *conv_pd);
auto diff_src_memory = mkldnn::memory(
{diff_src_md, mkldnn_engine}, // create reorder primitive if the input format is not the preferred one
reinterpret_cast<void*>(const_cast<T*>(input_grad_data))); auto weights_memory = user_weights_memory;
auto weights_memory = primitive reorder_weights;
mkldnn::memory({weights_md, mkldnn_engine}, bool is_weights_reordered = false;
reinterpret_cast<void*>(const_cast<T*>(filter_data))); if (memory::primitive_desc(conv_bwd_data_pd.weights_primitive_desc()) !=
user_weights_memory.get_primitive_desc()) {
weights_memory = memory(conv_bwd_data_pd.weights_primitive_desc());
reorder_weights = reorder(user_weights_memory, weights_memory);
is_weights_reordered = true;
}
auto diff_dst_memory_4data = user_diff_dst_memory;
primitive reorder_diff_dst_4data;
bool is_diff_dst_reordered_4data = false;
if (memory::primitive_desc(conv_bwd_data_pd.diff_dst_primitive_desc()) !=
user_diff_dst_memory.get_primitive_desc()) {
diff_dst_memory_4data =
memory(conv_bwd_data_pd.diff_dst_primitive_desc());
reorder_diff_dst_4data =
reorder(user_diff_dst_memory, diff_dst_memory_4data);
is_diff_dst_reordered_4data = true;
}
// create mkldnn memory for output (i.e. diff src)
auto diff_src_memory = memory(conv_bwd_data_pd.diff_src_primitive_desc(),
reinterpret_cast<void*>(input_grad_data));
// create backward conv primitive for data // create backward conv primitive for data
auto conv_bwd_data_prim = mkldnn::convolution_backward_data( auto conv_bwd_data_prim =
conv_bwd_data_pd, diff_dst_memory, weights_memory, diff_src_memory); conv_bwd_data(conv_bwd_data_pd, diff_dst_memory_4data, weights_memory,
diff_src_memory);
// push primitive to stream and wait until it's executed // push primitive and execute it
std::vector<mkldnn::primitive> pipeline{conv_bwd_data_prim}; std::vector<primitive> pipeline;
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); if (is_weights_reordered) pipeline.push_back(reorder_weights);
if (is_diff_dst_reordered_4data)
pipeline.push_back(reorder_diff_dst_4data);
pipeline.push_back(conv_bwd_data_prim);
stream(stream::kind::eager).submit(pipeline).wait();
input_grad->set_layout(DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(diff_src_memory));
} }
} // Compute() } // Compute()
private:
mkldnn::convolution_backward_weights::primitive_desc
ConvBwdWeightsPrimitiveDesc(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
const mkldnn::memory::desc& diff_dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::convolution_forward::primitive_desc& conv_pd,
const mkldnn::engine& engine) const {
auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
mkldnn::convolution_direct, src, diff_weights, diff_dst, strides,
paddings, paddings, mkldnn::padding_kind::zero);
return mkldnn::convolution_backward_weights::primitive_desc(
conv_bwd_weights_desc, engine, conv_pd);
}
mkldnn::convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc(
const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights,
const mkldnn::memory::desc& diff_dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::convolution_forward::primitive_desc& conv_pd,
const mkldnn::engine& engine) const {
auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
mkldnn::convolution_direct, diff_src, weights, diff_dst, strides,
paddings, paddings, mkldnn::padding_kind::zero);
return mkldnn::convolution_backward_data::primitive_desc(conv_bwd_data_desc,
engine, conv_pd);
}
}; };
} // namespace operators } // namespace operators
......
...@@ -75,9 +75,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -75,9 +75,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout = framework::StringToDataLayout(data_format); framework::DataLayout layout = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册