Commit 36031cb5 authored by mozga-intel

MKLDNN layout: Support for pool operator

The forward kernel now creates its destination memory descriptor with memory::format::any, letting MKLDNN pick the layout it prefers, and records the chosen layout/format on the output tensor. The source and destination memories are cached so the backward kernel can derive its gradient descriptors from them, and Out@GRAD is reordered whenever its recorded format differs from the one the pooling primitive expects.

Parent b7c683b8
@@ -18,9 +18,14 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using mkldnn::memory;  // Note: paddle also has a "memory" namespace
-using mkldnn::pooling_forward;
+using framework::DataLayout;
+using mkldnn::memory;
 using mkldnn::pooling_backward;
+using mkldnn::pooling_forward;
+using mkldnn::primitive;
+using mkldnn::reorder;
+using mkldnn::stream;
+using platform::to_void_cast;
 
 // Generate keys for storing/retrieving primitives for this operator
 // TODO(jczaja): Make the hashing function more optimal
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
 
-    // Get an unique name from "argument" name of "Out" variable
-    // This name will be used as key when saving info into device context
+    PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
+                       input->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input tensor");
 
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
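The new PADDLE_ENFORCE guard is the contract this commit introduces: any tensor entering an MKLDNN kernel must be tagged with the kMKLDNN layout and carry a concrete memory format. A minimal sketch of the same guard, assuming hypothetical stand-ins (DataLayout, MemoryFormat, TensorMeta) rather than Paddle's real types:

```cpp
// Sketch of the guard the kernel now applies. TensorMeta is a hypothetical
// stand-in for framework::Tensor's layout/format metadata; the real check
// uses PADDLE_ENFORCE, which throws on failure.
#include <stdexcept>

enum class DataLayout { kNHWC, kNCHW, kMKLDNN };
enum class MemoryFormat { format_undef, nchw, nChw8c, nChw16c };

struct TensorMeta {
  DataLayout layout;
  MemoryFormat format;
};

void EnforceMKLDNNMeta(const TensorMeta& t) {
  // An MKLDNN kernel needs both pieces of metadata: the layout tag marking
  // the tensor as MKLDNN-managed, and a concrete (non-undef) format.
  if (t.layout != DataLayout::kMKLDNN ||
      t.format == MemoryFormat::format_undef) {
    throw std::runtime_error("Wrong layout/format set for Input tensor");
  }
}
```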
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
+    auto input_format = input->format();
+    memory::format output_format{memory::format::format_undef};
+
     const std::string key = gethash(src_tz, pooling_type, ksize, strides,
                                     paddings, ctx.op().Output("Out"));
     const std::string key_pool_p = key + "@pool_p";
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto pool_p =
         std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
     if (pool_p == nullptr) {
-      // TODO(pzelazko-intel): support more formats
-
-      auto src_md =
-          platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-      auto dst_md =
-          platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-
-      std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
+      auto src_md = platform::MKLDNNMemDesc(
+          src_tz, platform::MKLDNNGetDataType<T>(), input_format);
+
+      /* create memory descriptor for pooling without specified format
+       * ('any') which lets a primitive (pooling in this case) choose
+       * the memory format preferred for best performance
+       */
+      auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
+                                            mkldnn::memory::format::any);
+
+      std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
           CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
                               pooling_type, mkldnn_engine);
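The dst_md built with memory::format::any is the heart of the layout change: instead of pinning the output to nchw, the pooling primitive descriptor is allowed to pick whatever (possibly blocked) format runs fastest, and the kernel reads the choice back afterwards. A sketch of that round trip against the MKLDNN 0.x API used in this file (max pooling and f32 hard-coded for brevity; PreferredPoolDstFormat is an illustrative helper, not a Paddle function):

```cpp
// Sketch (MKLDNN 0.x API, matching the calls in this file): create the
// destination descriptor with format::any, build the pooling primitive
// descriptor, and query the format the library actually selected.
#include <mkldnn.hpp>

mkldnn::memory::format PreferredPoolDstFormat(
    const mkldnn::engine& eng, const mkldnn::memory::dims& src_tz,
    const mkldnn::memory::dims& dst_tz, const mkldnn::memory::dims& ksize,
    const mkldnn::memory::dims& strides,
    const mkldnn::memory::dims& paddings) {
  using namespace mkldnn;  // NOLINT
  auto src_md =
      memory::desc(src_tz, memory::data_type::f32, memory::format::nchw);
  // 'any' defers the choice of layout to the primitive.
  auto dst_md =
      memory::desc(dst_tz, memory::data_type::f32, memory::format::any);
  auto desc = pooling_forward::desc(prop_kind::forward_training,
                                    algorithm::pooling_max, src_md, dst_md,
                                    strides, ksize, paddings, paddings,
                                    padding_kind::zero);
  auto pd = pooling_forward::primitive_desc(desc, eng);
  // The concrete format (e.g. nchw or a blocked variant such as nChw8c)
  // is now fixed and can be recorded on the output tensor.
  return static_cast<memory::format>(
      pd.dst_primitive_desc().desc().data.format);
}
```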
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       // save pool_workspace_memory to be referred in backward path
       dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
 
-      auto pool_src_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{src_md, mkldnn_engine},
-          static_cast<void*>(const_cast<T*>(input_data)));
-      dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
-
-      auto pool_dst_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{dst_md, mkldnn_engine},
-          static_cast<void*>(output_data));
-      dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
+      auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
+                                                 to_void_cast<T>(input_data));
+      auto dst_memory =
+          std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
 
-      pool_p = std::make_shared<pooling_forward>(
-          *pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
-          *workspace_memory);
+      dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
+      dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
+
+      pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
+                                                 *(dst_memory.get()),
+                                                 *workspace_memory);
       dev_ctx.SetBlob(key_pool_p, pool_p);
+
+      output_format =
+          (memory::format)dst_memory->get_primitive_desc().desc().data.format;
     } else {
       // Primitives already exist
       auto pool_src_memory_p =
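Both branches rely on the same caching idiom: everything created on the first iteration (primitive, memories, workspace) is stashed in the device context under string keys derived from gethash, and later iterations fetch it back with std::static_pointer_cast. A minimal, self-contained sketch of that pattern (BlobMap, SetBlob, and GetBlob are stand-ins, not Paddle's actual DeviceContext API):

```cpp
// Type-erased blob cache in the style of dev_ctx.SetBlob/GetBlob.
#include <map>
#include <memory>
#include <string>

using BlobMap = std::map<std::string, std::shared_ptr<void>>;

void SetBlob(BlobMap* blobs, const std::string& key,
             std::shared_ptr<void> value) {
  (*blobs)[key] = std::move(value);
}

template <typename T>
std::shared_ptr<T> GetBlob(const BlobMap& blobs, const std::string& key) {
  auto it = blobs.find(key);
  // static_pointer_cast is safe only because each key is written with a
  // single, known type -- the same contract the kernels above depend on.
  return it == blobs.end() ? nullptr : std::static_pointer_cast<T>(it->second);
}
```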
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
       PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
                      "Fail to find pooling dst mem_p in device context");
-      pool_src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(const_cast<T*>(input_data)));
+      pool_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
       pool_dst_memory_p->set_data_handle(output_data);
+
+      output_format = (memory::format)pool_dst_memory_p->get_primitive_desc()
+                          .desc()
+                          .data.format;
     }
 
     // push primitive to stream and wait until it's executed
     std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
-    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+    stream(stream::kind::eager).submit(pipeline).wait();
+
+    output->set_layout(DataLayout::kMKLDNN);
+    output->set_format(output_format);
   }
 
  private:
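On a cache hit the forward kernel only swaps data handles; nothing is redescribed or rebuilt. A condensed sketch of that fast path, assuming MKLDNN 0.x and a float specialization (RerunCachedPool and its parameter names are illustrative):

```cpp
// Sketch of the cache-hit path: rebind the cached memories to this
// iteration's buffers and resubmit the cached primitive.
#include <mkldnn.hpp>
#include <memory>
#include <vector>

void RerunCachedPool(const std::shared_ptr<mkldnn::memory>& src_mem,
                     const std::shared_ptr<mkldnn::memory>& dst_mem,
                     const std::shared_ptr<mkldnn::pooling_forward>& pool_p,
                     const float* input_data, float* output_data) {
  // Only the data handles change between iterations; descriptors, formats,
  // and the primitive itself are reused as cached.
  src_mem->set_data_handle(const_cast<float*>(input_data));
  dst_mem->set_data_handle(output_data);
  std::vector<mkldnn::primitive> pipeline{*pool_p};
  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
```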
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
 
+    PADDLE_ENFORCE(in_x->layout() == DataLayout::kMKLDNN &&
+                       in_x->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input X tensor");
+    PADDLE_ENFORCE(out_grad->layout() == DataLayout::kMKLDNN &&
+                       out_grad->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input output_grad tensor");
+
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const T* out_grad_data = out_grad->data<T>();
     T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
+    memory::format in_x_grad_format{memory::format::format_undef};
 
     std::vector<int> diff_src_tz =
         paddle::framework::vectorize2int(in_x_grad->dims());
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_pool_bwd_p = key + "@pool_bwd_p";
     const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
     const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
+    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
+    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
     const std::string key_pool_pd = key + "@pool_pd";
     const std::string key_pool_workspace_memory =
         key + "@pool_workspace_memory";
 
+    auto user_diff_dst_memory =
+        memory({{{diff_dst_tz}, memory::data_type::f32, out_grad->format()},
+                mkldnn_engine},
+               to_void_cast<T>(out_grad_data));
+
+    std::shared_ptr<memory> diff_src_memory;
+    std::shared_ptr<memory> diff_dst_memory;
+    auto dst_memory =
+        std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
+    PADDLE_ENFORCE(dst_memory != nullptr,
+                   "Fail to find dst_memory in device context");
+
+    primitive reorder_diff_dst;
+    bool is_diff_dst_reordered = false;
     auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
         dev_ctx.GetBlob(key_pool_bwd_p));
     if (pool_bwd_p == nullptr) {
-      auto diff_src_md =
-          platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-      auto diff_dst_md =
-          platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
+      // Retrieve src_memory/dst_memory saved in forward pass
+      auto src_memory =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
+      PADDLE_ENFORCE(src_memory != nullptr,
+                     "Fail to find src_memory in device context");
       // Retrieve pool_pd/pool_workspace_memory from device context
       auto pool_pd =
           std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
               dev_ctx.GetBlob(key_pool_pd));
       PADDLE_ENFORCE(pool_pd != nullptr,
                      "Fail to find pool_pd in device context");
-
-      auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
+      auto workspace_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_workspace_memory));
       PADDLE_ENFORCE(workspace_memory != nullptr,
                      "Fail to find workspace_memory in device context");
 
-      auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
-          {diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
-      dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
-
-      auto pool_diff_dst_memory_p = std::make_shared<memory>(
-          memory({diff_dst_md, mkldnn_engine},
-                 static_cast<void*>(const_cast<T*>(out_grad_data))));
-      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
+      // create memory descriptors for pooling
+      auto diff_src_md = src_memory.get()->get_primitive_desc().desc();
+      auto diff_dst_md = dst_memory.get()->get_primitive_desc().desc();
 
       auto pool_bwd_desc = mkldnn::pooling_backward::desc(
           pooling_type == "max" ? mkldnn::algorithm::pooling_max
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
       auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
           pool_bwd_desc, mkldnn_engine, *pool_pd);
 
+      // reorder between user_diff_dst and pool diff_dst if needed
+      diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
+      if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory =
+            std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
+
+      diff_src_memory = std::make_shared<memory>(
+          pool_bwd_pd.diff_src_primitive_desc(), in_x_grad_data);
+
+      dev_ctx.SetBlob(key_pool_diff_src_mem_p, diff_src_memory);
+      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory);
+
       pool_bwd_p = std::make_shared<pooling_backward>(
-          pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
-          *(pool_diff_src_memory_p));
+          pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory,
+          *(diff_src_memory));
       dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
     } else {
       // Primitives already exist
-      auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
+      diff_src_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_diff_src_mem_p));
-      PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
+      PADDLE_ENFORCE(diff_src_memory != nullptr,
                      "Fail to find pooling src mem_p in device context");
-      auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
+      diff_dst_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
-      PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
+      PADDLE_ENFORCE(diff_dst_memory != nullptr,
                      "Fail to find pooling dst mem_p in device context");
-      pool_diff_src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(in_x_grad_data));
-      pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
+
+      diff_src_memory->set_data_handle(reinterpret_cast<void*>(in_x_grad_data));
+      diff_dst_memory->set_data_handle(const_cast<T*>(out_grad_data));
+
+      // reorder between user_diff_dst and pool diff_dst if needed
+      if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory =
+            std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
     }
 
+    in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc()
+                           .desc()
+                           .data.format;
+
     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
+    std::vector<mkldnn::primitive> pipeline;
+    if (is_diff_dst_reordered) {
+      pipeline.push_back(reorder_diff_dst);
+    }
+    pipeline.push_back(*(pool_bwd_p.get()));
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 
+    in_x_grad->set_layout(DataLayout::kMKLDNN);
+    in_x_grad->set_format(in_x_grad_format);
   }  // Compute()
 };
 
 }  // namespace operators
 }  // namespace paddle
 
+namespace ops = paddle::operators;
+
 REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::PoolMKLDNNOpKernel<float>);
+                   ops::PoolMKLDNNOpKernel<float>);
 REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::PoolMKLDNNGradOpKernel<float>);
+                   ops::PoolMKLDNNGradOpKernel<float>);
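The reorder logic in both branches of the backward kernel follows one idiom: compare primitive descriptors, and only if they differ allocate an intermediate buffer in the wanted format and prepend a reorder to the pipeline. Factored out as a sketch (MaybeReorder is not a Paddle helper; MKLDNN 0.x API assumed):

```cpp
// Sketch (MKLDNN 0.x) of the conditional-reorder idiom used above.
#include <mkldnn.hpp>
#include <memory>
#include <vector>

std::shared_ptr<mkldnn::memory> MaybeReorder(
    const mkldnn::memory& user_mem,
    const mkldnn::memory::primitive_desc& wanted_pd,
    std::vector<mkldnn::primitive>* pipeline) {
  if (user_mem.get_primitive_desc() == wanted_pd) {
    // Formats already match: run directly on the user's memory.
    return std::make_shared<mkldnn::memory>(user_mem);
  }
  // Allocate an intermediate buffer in the wanted format and schedule the
  // reorder ahead of the consuming primitive in the same pipeline.
  auto reordered = std::make_shared<mkldnn::memory>(wanted_pd);
  pipeline->push_back(mkldnn::reorder(user_mem, *reordered));
  return reordered;
}
```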