/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using mkldnn::memory;  // Note: paddle has also "memory" namespace
using mkldnn::pooling_forward;
using mkldnn::pooling_backward;

// Generate keys for storing/retrieving primitives for this operator.
// Each dimension vector is serialized as "v0-v1-...-" (input dims, ksize,
// strides, paddings, in that order), followed by the pooling type and the
// caller-provided suffix (the "Out" variable name).
// TODO(jczaja): Make hashing function more optimal
static std::string gethash(const memory::dims& input_dims,
                           const std::string& pooling_type,
                           const std::vector<int>& ksize,
                           const std::vector<int>& strides,
                           const std::vector<int>& paddings,
                           const std::string& suffix) {
  std::string key;
  // Append in place instead of concatenating temporaries.
  auto append_dims = [&key](const memory::dims& operand_dims) {
    for (int dim : operand_dims) {
      key += std::to_string(dim);
      key += '-';
    }
  };
  append_dims(input_dims);
  append_dims(ksize);
  append_dims(strides);
  append_dims(paddings);
  key += pooling_type;
  key += suffix;
  return key;
}

// Forward 2D pooling ("max" or "avg") on NCHW tensors using MKL-DNN.
//
// On the first call for a given key the primitive descriptor, workspace,
// src/dst memory objects and the pooling primitive are created and stored in
// the device context. Later calls with the same key only rebind the data
// handles of the cached src/dst memories. The primitive descriptor and the
// workspace are also stored so that PoolMKLDNNGradOpKernel can reuse them.
template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("X");
    Tensor* output = ctx.Output<Tensor>("Out");

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
    std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");

    // Validate before the global-pooling rewrite below: it indexes
    // input->dims()[i + 2] for every i < ksize.size(), which is only in range
    // once ksize is 2D and the input is 4D.
    // Only 2D pooling is supported now
    PADDLE_ENFORCE(ksize.size() == 2, "ksize must be 2D, i.e. 2D pooling");
    PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "avg",
                   "pooling_type must be 'max' or 'avg'");
    PADDLE_ENFORCE(input->dims().size() == 4,
                   "Input dim must be with 4, i.e. NCHW");

    if (ctx.Attr<bool>("global_pooling")) {
      // The window covers the full spatial extent (dims 2 and 3) with no
      // padding.
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
        ksize[i] = static_cast<int>(input->dims()[i + 2]);
      }
    }

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());

    // Get an unique name from "argument" name of "Out" variable
    // This name will be used as key when saving info into device context
    const std::string key = gethash(src_tz, pooling_type, ksize, strides,
                                    paddings, ctx.op().Output("Out"));
    const std::string key_pool_p = key + "@pool_p";
    const std::string key_pool_pd = key + "@pool_pd";
    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
    const std::string key_pool_workspace_memory =
        key + "@pool_workspace_memory";

    auto pool_p =
        std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
    if (pool_p == nullptr) {
      // TODO(pzelazko-intel): support more formats
      auto src_md =
          platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
                                  mkldnn::memory::format::nchw);
      auto dst_md =
          platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
                                  mkldnn::memory::format::nchw);

      std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
          CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
                              pooling_type, mkldnn_engine);

      // save pool_pd into global device context to be referred in backward path
      dev_ctx.SetBlob(key_pool_pd, pool_pd);

      std::shared_ptr<memory> workspace_memory =
          CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);

      // save pool_workspace_memory to be referred in backward path
      dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);

      // MKL-DNN memory wraps a non-const handle; the src is only read.
      auto pool_src_memory_p = std::make_shared<memory>(
          memory::primitive_desc{src_md, mkldnn_engine},
          static_cast<void*>(const_cast<T*>(input_data)));
      dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);

      auto pool_dst_memory_p = std::make_shared<memory>(
          memory::primitive_desc{dst_md, mkldnn_engine},
          static_cast<void*>(output_data));
      dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);

      pool_p = std::make_shared<pooling_forward>(
          *pool_pd, *pool_src_memory_p, *pool_dst_memory_p, *workspace_memory);

      dev_ctx.SetBlob(key_pool_p, pool_p);
    } else {
      // Primitives already exist: only point the cached memories at this
      // call's buffers.
      auto pool_src_memory_p =
          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
      PADDLE_ENFORCE(pool_src_memory_p != nullptr,
                     "Fail to find pooling src mem_p in device context");
      auto pool_dst_memory_p =
          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
      PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
                     "Fail to find pooling dst mem_p in device context");
      pool_src_memory_p->set_data_handle(
          static_cast<void*>(const_cast<T*>(input_data)));
      pool_dst_memory_p->set_data_handle(static_cast<void*>(output_data));
    }

    // push primitive to stream and wait until it's executed
    std::vector<mkldnn::primitive> pipeline{*pool_p};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
  }

 private:
  // Builds the forward pooling primitive descriptor for src -> dst.
  // The same `padding` vector is used for both the leading and trailing
  // padding, with mkldnn::padding_kind::zero.
  std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc(
      const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
      const std::vector<int>& stride, const std::vector<int>& padding,
      const std::vector<int>& kernel, const std::string& pooling_type,
      const mkldnn::engine& engine) const {
    auto pool_desc = mkldnn::pooling_forward::desc(
        mkldnn::prop_kind::forward,
        pooling_type == "max" ? mkldnn::algorithm::pooling_max
                              : mkldnn::algorithm::pooling_avg,
        src, dst, stride, kernel, padding, padding, mkldnn::padding_kind::zero);

    // Hand ownership to the smart pointer immediately (no raw owning local).
    return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(
        new mkldnn::pooling_forward::primitive_desc(pool_desc, engine));
  }

  // Allocates the workspace passed to the forward/backward primitives.
  // For "max" it uses the layout the primitive descriptor reports. For other
  // types it creates a memory with empty dims, presumably a placeholder since
  // only max pooling needs a workspace (NOTE(review): confirm).
  std::unique_ptr<mkldnn::memory> CreateWorkspaceMemory(
      const std::shared_ptr<mkldnn::pooling_forward::primitive_desc>& pool_pd,
      const std::string& pooling_type, const mkldnn::engine& engine) const {
    mkldnn::memory::primitive_desc workspace_md =
        pooling_type == "max"
            ? pool_pd->workspace_primitive_desc()
            : mkldnn::memory::primitive_desc(
                  {{}, platform::MKLDNNGetDataType<T>(),
                   mkldnn::memory::format::nchw},
                  engine);

    return std::unique_ptr<mkldnn::memory>(new mkldnn::memory(workspace_md));
  }
};

// Backward 2D pooling using MKL-DNN.
//
// Looks up the forward primitive descriptor and workspace that
// PoolMKLDNNOpKernel stored in the device context. The key is built with
// gethash from the same fields: the X/X@GRAD dims, attributes and the "Out"
// variable name. The backward primitive and diff memories are cached the same
// way as in the forward kernel.
template <typename T>
class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    const Tensor* in_x = ctx.Input<Tensor>("X");
    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
    std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");

    // Same preconditions as the forward kernel. They must hold before the
    // global-pooling rewrite indexes in_x->dims()[i + 2].
    PADDLE_ENFORCE(ksize.size() == 2, "ksize must be 2D, i.e. 2D pooling");
    PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "avg",
                   "pooling_type must be 'max' or 'avg'");
    PADDLE_ENFORCE(in_x->dims().size() == 4,
                   "Input dim must be with 4, i.e. NCHW");

    if (ctx.Attr<bool>("global_pooling")) {
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
      }
    }

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const mkldnn::engine& mkldnn_engine = dev_ctx.GetEngine();

    const T* out_grad_data = out_grad->data<T>();
    T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());

    std::vector<int> diff_src_tz =
        paddle::framework::vectorize2int(in_x_grad->dims());
    std::vector<int> diff_dst_tz =
        paddle::framework::vectorize2int(out_grad->dims());

    // Get an unique name from "argument" name of "Out" variable
    // This name will be used as key when referring info from device context
    const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
                                    paddings, ctx.op().Input("Out"));
    const std::string key_pool_bwd_p = key + "@pool_bwd_p";
    const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
    const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
    const std::string key_pool_pd = key + "@pool_pd";
    const std::string key_pool_workspace_memory =
        key + "@pool_workspace_memory";

    auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
        dev_ctx.GetBlob(key_pool_bwd_p));
    if (pool_bwd_p == nullptr) {
      auto diff_src_md =
          platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
                                  mkldnn::memory::format::nchw);
      auto diff_dst_md =
          platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
                                  mkldnn::memory::format::nchw);

      // Retrieve pool_pd/pool_workspace_memory from device context
      auto pool_pd =
          std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
              dev_ctx.GetBlob(key_pool_pd));
      PADDLE_ENFORCE(pool_pd != nullptr,
                     "Fail to find pool_pd in device context");

      auto workspace_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_workspace_memory));
      PADDLE_ENFORCE(workspace_memory != nullptr,
                     "Fail to find workspace_memory in device context");

      auto pool_diff_src_memory_p = std::make_shared<memory>(
          memory::primitive_desc{diff_src_md, mkldnn_engine},
          static_cast<void*>(in_x_grad_data));
      dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);

      // MKL-DNN memory wraps a non-const handle; diff_dst is only read.
      auto pool_diff_dst_memory_p = std::make_shared<memory>(
          memory::primitive_desc{diff_dst_md, mkldnn_engine},
          static_cast<void*>(const_cast<T*>(out_grad_data)));
      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);

      auto pool_bwd_desc = mkldnn::pooling_backward::desc(
          pooling_type == "max" ? mkldnn::algorithm::pooling_max
                                : mkldnn::algorithm::pooling_avg,
          diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
          mkldnn::padding_kind::zero);
      // The forward primitive descriptor is passed as the hint.
      auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
          pool_bwd_desc, mkldnn_engine, *pool_pd);

      pool_bwd_p = std::make_shared<pooling_backward>(
          pool_bwd_pd, *pool_diff_dst_memory_p, *workspace_memory,
          *pool_diff_src_memory_p);
      dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
    } else {
      // Primitives already exist: only point the cached memories at this
      // call's buffers.
      auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_diff_src_mem_p));
      PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
                     "Fail to find pooling diff_src mem_p in device context");
      auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
      PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
                     "Fail to find pooling diff_dst mem_p in device context");
      pool_diff_src_memory_p->set_data_handle(
          static_cast<void*>(in_x_grad_data));
      pool_diff_dst_memory_p->set_data_handle(
          static_cast<void*>(const_cast<T*>(out_grad_data)));
    }

    // push primitive to stream and wait until it's executed
    std::vector<mkldnn::primitive> pipeline{*pool_bwd_p};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
  }  // Compute()
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::PoolMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::PoolMKLDNNGradOpKernel<float>);