diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc index 63eaaedcd5fc3df17902511dc02b25bf43ccd241..60e936298defe7c6ce8a33bdc7de05b52eb950e7 100644 --- a/paddle/fluid/operators/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/pool_mkldnn_op.cc @@ -18,6 +18,26 @@ limitations under the License. */ namespace paddle { namespace operators { +using mkldnn::memory; // Note: paddle has also "memory" namespace +using mkldnn::pooling_forward; +using mkldnn::pooling_backward; + +// Generate keys for storing/retriving primitives for this operator +// TODO(jczaja): Make hashing function more optimial +static std::string gethash(memory::dims& input_dims, std::string& pooling_type, + std::vector& ksize, std::vector& strides, + std::vector& paddings, std::string suffix) { + auto dims2str = [](memory::dims& operand_dims) { + std::string dstr = ""; + for (size_t i = 0; i < operand_dims.size(); ++i) { + dstr += std::to_string(operand_dims[i]) + "-"; + } + return dstr; + }; + return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) + + dims2str(paddings) + pooling_type + suffix; +} + template class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { // Get an unique name from "argument" name of "Out" variable // This name will be used as key when saving info into device context - const std::string key = ctx.op().Output("Out"); - const std::string key_pool_pd = key + "@pool_pd"; - const std::string key_pool_workspace_memory = - key + "@pool_workspace_memory"; std::string pooling_type = ctx.Attr("pooling_type"); std::vector ksize = ctx.Attr>("ksize"); @@ -63,37 +79,71 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector src_tz = paddle::framework::vectorize2int(input->dims()); std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); - // TODO(pzelazko-intel): support more formats - auto src_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, - mkldnn::memory::format::nchw); - auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32, - mkldnn::memory::format::nchw); - - std::shared_ptr pool_pd = - CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize, - pooling_type, mkldnn_engine); - - // save pool_pd into global device context to be referred in backward path - dev_ctx.SetBlob(key_pool_pd, pool_pd); - - std::shared_ptr workspace_memory = - CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine); - - // save pool_workspace_memory to be referred in backward path - dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); - - auto src_memory = - mkldnn::memory({src_md, mkldnn_engine}, - static_cast(const_cast(input_data))); - auto dst_memory = - mkldnn::memory({dst_md, mkldnn_engine}, - static_cast(const_cast(output_data))); + const std::string key = gethash(src_tz, pooling_type, ksize, strides, + paddings, ctx.op().Output("Out")); + const std::string key_pool_p = key + "@pool_p"; + const std::string key_pool_pd = key + "@pool_pd"; + const std::string key_pool_src_mem_p = key + "@pool_src_mem_p"; + const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p"; + const std::string key_pool_workspace_memory = + key + "@pool_workspace_memory"; - auto pool_prim = mkldnn::pooling_forward(*pool_pd, src_memory, dst_memory, - *workspace_memory); + auto pool_p = + std::static_pointer_cast(dev_ctx.GetBlob(key_pool_p)); + if (pool_p == nullptr) { + // TODO(pzelazko-intel): support more formats + + auto src_md = + platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType(), + mkldnn::memory::format::nchw); + auto dst_md = + platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType(), + mkldnn::memory::format::nchw); + + std::shared_ptr pool_pd = + CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize, + pooling_type, mkldnn_engine); + + // save pool_pd into global device context to be referred in backward path + dev_ctx.SetBlob(key_pool_pd, pool_pd); + + std::shared_ptr workspace_memory = + CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine); + + // save pool_workspace_memory to be referred in backward path + dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory); + + auto pool_src_memory_p = std::make_shared( + memory::primitive_desc{src_md, mkldnn_engine}, + static_cast(const_cast(input_data))); + dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p); + + auto pool_dst_memory_p = std::make_shared( + memory::primitive_desc{dst_md, mkldnn_engine}, + static_cast(output_data)); + dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p); + + pool_p = std::make_shared( + *pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()), + *workspace_memory); + dev_ctx.SetBlob(key_pool_p, pool_p); + } else { + // Primitives already exist + auto pool_src_memory_p = + std::static_pointer_cast(dev_ctx.GetBlob(key_pool_src_mem_p)); + PADDLE_ENFORCE(pool_src_memory_p != nullptr, + "Fail to find pooling src mem_p in device context"); + auto pool_dst_memory_p = + std::static_pointer_cast(dev_ctx.GetBlob(key_pool_dst_mem_p)); + PADDLE_ENFORCE(pool_dst_memory_p != nullptr, + "Fail to find pooling dst mem_p in device context"); + pool_src_memory_p->set_data_handle( + reinterpret_cast(const_cast(input_data))); + pool_dst_memory_p->set_data_handle(output_data); + } // push primitive to stream and wait until it's executed - std::vector pipeline{pool_prim}; + std::vector pipeline{*(pool_p.get())}; mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); } @@ -120,9 +170,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::primitive_desc workspace_md = pooling_type == "max" ? pool_pd->workspace_primitive_desc() - : mkldnn::memory::primitive_desc( - {{}, mkldnn::memory::f32, mkldnn::memory::format::nchw}, - engine); + : mkldnn::memory::primitive_desc({{}, + platform::MKLDNNGetDataType(), + mkldnn::memory::format::nchw}, + engine); auto p_workspace_memory = new mkldnn::memory(workspace_md); return std::unique_ptr(p_workspace_memory); @@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); Tensor* in_x_grad = ctx.Output(framework::GradVarName("X")); - // Get an unique name from "argument" name of "Out" variable - // This name will be used as key when referring info from device context - const std::string key = ctx.op().Input("Out"); - const std::string key_pool_pd = key + "@pool_pd"; - const std::string key_pool_workspace_memory = - key + "@pool_workspace_memory"; - std::string pooling_type = ctx.Attr("pooling_type"); std::vector ksize = ctx.Attr>("ksize"); std::vector strides = ctx.Attr>("strides"); @@ -171,43 +215,76 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { std::vector diff_dst_tz = paddle::framework::vectorize2int(out_grad->dims()); - auto diff_src_md = platform::MKLDNNMemDesc(diff_src_tz, mkldnn::memory::f32, - mkldnn::memory::format::nchw); - auto diff_dst_md = platform::MKLDNNMemDesc(diff_dst_tz, mkldnn::memory::f32, - mkldnn::memory::format::nchw); - - // Retrieve pool_pd/pool_workspace_memory from device context - auto pool_pd = - std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_pd)); - PADDLE_ENFORCE(pool_pd != nullptr, - "Fail to find pool_pd in device context"); - - auto workspace_memory = std::static_pointer_cast( - dev_ctx.GetBlob(key_pool_workspace_memory)); - PADDLE_ENFORCE(workspace_memory != nullptr, - "Fail to find workspace_memory in device context"); - - auto pool_bwd_desc = mkldnn::pooling_backward::desc( - pooling_type == "max" ? mkldnn::algorithm::pooling_max - : mkldnn::algorithm::pooling_avg, - diff_src_md, diff_dst_md, strides, ksize, paddings, paddings, - mkldnn::padding_kind::zero); - auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc( - pool_bwd_desc, mkldnn_engine, *pool_pd); - - auto diff_src_memory = - mkldnn::memory({diff_src_md, mkldnn_engine}, - static_cast(const_cast(in_x_grad_data))); - auto diff_dst_memory = - mkldnn::memory({diff_dst_md, mkldnn_engine}, - static_cast(const_cast(out_grad_data))); + // Get an unique name from "argument" name of "Out" variable + // This name will be used as key when referring info from device context + const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides, + paddings, ctx.op().Input("Out")); + const std::string key_pool_bwd_p = key + "@pool_bwd_p"; + const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p"; + const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p"; + const std::string key_pool_pd = key + "@pool_pd"; + const std::string key_pool_workspace_memory = + key + "@pool_workspace_memory"; - auto bwd_prim = mkldnn::pooling_backward( - pool_bwd_pd, diff_dst_memory, *workspace_memory, diff_src_memory); + auto pool_bwd_p = std::static_pointer_cast( + dev_ctx.GetBlob(key_pool_bwd_p)); + if (pool_bwd_p == nullptr) { + auto diff_src_md = + platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType(), + mkldnn::memory::format::nchw); + auto diff_dst_md = + platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType(), + mkldnn::memory::format::nchw); + // Retrieve pool_pd/pool_workspace_memory from device context + auto pool_pd = + std::static_pointer_cast( + dev_ctx.GetBlob(key_pool_pd)); + PADDLE_ENFORCE(pool_pd != nullptr, + "Fail to find pool_pd in device context"); + + auto workspace_memory = std::static_pointer_cast( + dev_ctx.GetBlob(key_pool_workspace_memory)); + PADDLE_ENFORCE(workspace_memory != nullptr, + "Fail to find workspace_memory in device context"); + + auto pool_diff_src_memory_p = std::make_shared(memory( + {diff_src_md, mkldnn_engine}, static_cast(in_x_grad_data))); + dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p); + + auto pool_diff_dst_memory_p = std::make_shared( + memory({diff_dst_md, mkldnn_engine}, + static_cast(const_cast(out_grad_data)))); + dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p); + + auto pool_bwd_desc = mkldnn::pooling_backward::desc( + pooling_type == "max" ? mkldnn::algorithm::pooling_max + : mkldnn::algorithm::pooling_avg, + diff_src_md, diff_dst_md, strides, ksize, paddings, paddings, + mkldnn::padding_kind::zero); + auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc( + pool_bwd_desc, mkldnn_engine, *pool_pd); + + pool_bwd_p = std::make_shared( + pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory, + *(pool_diff_src_memory_p)); + dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p); + } else { + // Primitives already exist + auto pool_diff_src_memory_p = std::static_pointer_cast( + dev_ctx.GetBlob(key_pool_diff_src_mem_p)); + PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr, + "Fail to find pooling src mem_p in device context"); + auto pool_diff_dst_memory_p = std::static_pointer_cast( + dev_ctx.GetBlob(key_pool_diff_dst_mem_p)); + PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr, + "Fail to find pooling dst mem_p in device context"); + pool_diff_src_memory_p->set_data_handle( + reinterpret_cast(in_x_grad_data)); + pool_diff_dst_memory_p->set_data_handle(const_cast(out_grad_data)); + } // push primitive to stream and wait until it's executed - std::vector pipeline{bwd_prim}; + std::vector pipeline{*(pool_bwd_p.get())}; mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); } // Compute() }; diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 23f1d615daab91f0e4b353bc7d9a3ca7f5cec5ae..56ed5912a15437b72b769610912c7493d77e5964 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -71,5 +71,15 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); } +template +mkldnn::memory::data_type MKLDNNGetDataType() { + return mkldnn::memory::data_undef; +} + +template <> +inline mkldnn::memory::data_type MKLDNNGetDataType() { + return mkldnn::memory::f32; +} + } // namespace platform } // namespace paddle