diff --git a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc index 262b7408a7f5f65c4d97120914c16f38ce5fdbe7..accc9a9d71ffccf2812d57a7516eaf7e0f83275c 100644 --- a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/dequantize_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -30,6 +31,18 @@ using framework::DataLayout; using mkldnn::stream; using platform::GetMKLDNNFormat; +std::string CreateKey(const paddle::framework::ExecutionContext& ctx, + const mkldnn::memory::data_type& src_dt, + const std::vector& src_tz, const float scale_data) { + std::string key; + key.reserve(platform::MKLDNNHandler::MaxKeyLength); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(src_dt)); + platform::MKLDNNHandler::AppendKeyDims(&key, src_tz); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data)); + platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output")); + return key; +} + template class DeQuantOpKernel : public framework::OpKernel { public: @@ -51,31 +64,55 @@ class DeQuantOpKernel : public framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); mkldnn::memory::format src_fmt = input->format(); + std::string key = CreateKey(ctx, src_dt, src_tz, reorder_scale[0]); + const std::string key_prim = key + "@reorder_p"; + const std::string key_src_mem = key + "@src_mem"; + const std::string key_dst_mem = key + "@dst_mem"; + + std::shared_ptr src_memory; + std::shared_ptr dst_memory; + std::shared_ptr reorder_p; + reorder_p = std::static_pointer_cast(dev_ctx.GetBlob(key_prim)); + + if (reorder_p == nullptr) { + mkldnn::primitive_attr attri; + int mask = 0; + attri.set_output_scales(mask, reorder_scale); + + auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); + auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); + src_memory = + std::make_shared(src_pd, to_void_cast(input_data)); + std::shared_ptr src_memory_p = + std::shared_ptr(new primitive::at(*src_memory)); + + auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32, + memory::format::nchw); + auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); + dst_memory = std::make_shared( + dst_pd, to_void_cast(output_data)); + + auto reorder_pd = std::shared_ptr( + new reorder::primitive_desc(src_pd, dst_pd, attri)); + reorder_p = std::shared_ptr( + new reorder(*reorder_pd, *src_memory_p, *dst_memory)); + dev_ctx.SetBlob(key_prim, reorder_p); + dev_ctx.SetBlob(key_src_mem, src_memory); + dev_ctx.SetBlob(key_dst_mem, dst_memory); + } else { + src_memory = std::static_pointer_cast( + dev_ctx.GetBlob(key_src_mem)); + src_memory->set_data_handle(to_void_cast(input_data)); + + dst_memory = std::static_pointer_cast( + dev_ctx.GetBlob(key_dst_mem)); + dst_memory->set_data_handle(output->mutable_data(ctx.GetPlace())); + } - mkldnn::primitive_attr attri; - int mask = 0; - attri.set_output_scales(mask, reorder_scale); - - auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); - auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); - auto src_memory = - std::make_shared(src_pd, to_void_cast(input_data)); - std::shared_ptr src_memory_p = - std::shared_ptr(new primitive::at(*src_memory)); - - auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32, - memory::format::nchw); - auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); - auto dst_memory = mkldnn::memory(dst_pd, to_void_cast(output_data)); - - auto reorder_pd = std::shared_ptr( - new reorder::primitive_desc(src_pd, dst_pd, attri)); - auto reorder_p = std::shared_ptr( - new reorder(*reorder_pd, *src_memory_p, dst_memory)); pipeline.push_back(*reorder_p); stream(stream::kind::eager).submit(pipeline).wait(); - output->set_format(GetMKLDNNFormat(dst_memory)); + output->set_format(GetMKLDNNFormat(*dst_memory)); } };