From 35bab4f2b3af6ff3de999759786c75c8ea83100a Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Thu, 16 Jan 2020 10:16:24 +0100 Subject: [PATCH] Add caching mechanism to requantize_mkldnn_op (#22267) --- .../operators/mkldnn/requantize_mkldnn_op.cc | 92 +++++++++++-------- 1 file changed, 56 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index 21a49a153d..92e7744e3c 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -21,14 +21,10 @@ limitations under the License. */ namespace paddle { namespace operators { -using mkldnn::memory; -using mkldnn::primitive; -using mkldnn::reorder; +using dnnl::memory; +using dnnl::reorder; using platform::to_void_cast; using Tensor = framework::Tensor; -using framework::DataLayout; -using mkldnn::stream; -using platform::GetMKLDNNFormat; template class ReQuantOpKernel : public framework::OpKernel { @@ -42,42 +38,66 @@ class ReQuantOpKernel : public framework::OpKernel { ctx.template device_context(); const auto& engine = dev_ctx.GetEngine(); - std::vector pipeline; - auto src_tz = paddle::framework::vectorize(input->dims()); - auto dst_tz = paddle::framework::vectorize(output->dims()); - mkldnn::memory::data_type src_dt = - paddle::framework::ToMKLDNNDataType(input->type()); - mkldnn::memory::data_type dst_dt = src_dt; - MKLDNNMemoryFormat src_fmt = MKLDNNMemoryFormat::nhwc; - MKLDNNMemoryFormat dst_fmt = MKLDNNMemoryFormat::nhwc; + auto src_tz = paddle::framework::vectorize(input->dims()); - const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); - float scale_shift = scale_out / scale_in; - - mkldnn::primitive_attr attri; - int mask = 0; - attri.set_output_scales(mask, {scale_shift}); - - auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); - auto src_memory = 
std::make_shared( - src_md, engine, to_void_cast(input_data)); + std::string key = platform::CreateKey(src_tz, scale_in, scale_out, + ctx.OutputName("Output")); + const std::string key_prim = key + "@reorder_p"; + const std::string key_src_mem = key + "@src_mem"; + const std::string key_dst_mem = key + "@dst_mem"; - auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt); - auto dst_memory = - mkldnn::memory(dst_md, engine, to_void_cast(output_data)); + std::shared_ptr src_memory; + std::shared_ptr dst_memory; + std::shared_ptr reorder_p; + reorder_p = std::static_pointer_cast(dev_ctx.GetBlob(key_prim)); - auto reorder_pd = std::shared_ptr( - new reorder::primitive_desc(*src_memory, dst_memory, attri)); - - auto reorder_p = std::shared_ptr(new reorder(*reorder_pd)); + const T* input_data = input->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); - mkldnn::stream astream(engine); - reorder_p->execute(astream, *src_memory, dst_memory); + if (reorder_p == nullptr) { + dnnl::primitive_attr attri; + int mask = 0; + float scale_shift = scale_out / scale_in; + attri.set_output_scales(mask, {scale_shift}); + + auto dst_tz = paddle::framework::vectorize(output->dims()); + dnnl::memory::data_type src_dt = + paddle::framework::ToMKLDNNDataType(input->type()); + dnnl::memory::data_type dst_dt = src_dt; + + auto src_md = + platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc); + src_memory = std::make_shared(src_md, engine, + to_void_cast(input_data)); + + auto dst_md = + platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc); + dst_memory = std::make_shared(dst_md, engine, + to_void_cast(output_data)); + + auto reorder_pd = + reorder::primitive_desc(*src_memory, *dst_memory, attri); + reorder_p = std::make_shared(reorder_pd); + + dev_ctx.SetBlob(key_prim, reorder_p); + dev_ctx.SetBlob(key_src_mem, src_memory); + dev_ctx.SetBlob(key_dst_mem, dst_memory); + } else { + src_memory = + 
std::static_pointer_cast(dev_ctx.GetBlob(key_src_mem)); + src_memory->set_data_handle(to_void_cast(input_data)); + + dst_memory = + std::static_pointer_cast(dev_ctx.GetBlob(key_dst_mem)); + dst_memory->set_data_handle(output_data); + } + + dnnl::stream astream(engine); + reorder_p->execute(astream, *src_memory, *dst_memory); astream.wait(); - output->set_layout(DataLayout::kMKLDNN); - output->set_format(GetMKLDNNFormat(dst_memory)); + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; -- GitLab