diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index 4ac14d5ff95e5fb1a664946ecf892af72f581075..c9b80ba1e7a5607694698d9571dbf56e263c7f42 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -17,7 +17,8 @@ limitations under the License. */ #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/requantize_op.h" -#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/phi/backends/onednn/onednn_helper.h" +#include "paddle/phi/backends/onednn/onednn_reuse.h" namespace paddle { namespace operators { @@ -56,101 +57,56 @@ class ReQuantOpKernel : public framework::OpKernel { platform::errors::InvalidArgument("Scale of output cannot be 0.0")); if (shift_in != 0.0f) { PADDLE_ENFORCE_EQ( - framework::TransToProtoVarType(input->dtype()), - framework::proto::VarType::UINT8, + input->dtype(), + DataType::UINT8, platform::errors::Unimplemented("Requantize does not support nonzero " "shift for signed input.")); } auto& dev_ctx = ctx.template device_context(); - const auto& engine = dev_ctx.GetEngine(); auto src_tz = phi::vectorize(input->dims()); - float reorder_scale = scale_out / scale_in; + auto src_paddle_dt = input->dtype(); + auto dst_paddle_dt = with_shift ? DataType::UINT8 : src_paddle_dt; + + auto xstrides = input->mem_desc().data.format_desc.blocking.strides; + std::vector vstrides(xstrides, + xstrides + input->mem_desc().data.ndims); - std::string key = platform::CreateKey( - dev_ctx, src_tz, scale_in, scale_out, ctx.OutputName("Output")); - key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); - const std::string key_prim = key + "@r"; - const std::string key_src_mem = key + "@s"; - const std::string key_dst_mem = key + "@d"; - - std::shared_ptr src_memory; - std::shared_ptr dst_memory; - std::shared_ptr reorder_p; - reorder_p = std::static_pointer_cast(dev_ctx.GetBlob(key_prim)); - - const T* input_data = input->data(); - - if (reorder_p == nullptr) { - auto src_dt = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(input->dtype())); - auto dst_dt = with_shift ? framework::OneDNNDataType::u8 : src_dt; - - src_memory = std::make_shared( - input->mem_desc(), engine, phi::funcs::to_void_cast(input_data)); - - auto xstrides = input->mem_desc().data.format_desc.blocking.strides; - - std::vector vstrides(xstrides, - xstrides + input->mem_desc().data.ndims); - - auto dst_md = dnnl::memory::desc({src_tz}, dst_dt, vstrides); - - dnnl::primitive_attr attri; - int mask = 0; - attri.set_output_scales(mask, {reorder_scale}); - if (with_shift) { - dnnl::post_ops post_operations; - post_operations.append_sum(); - attri.set_post_ops(post_operations); - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - uint8_t reorder_shift = - clip_to_uint8(shift_out - reorder_scale * shift_in); - std::memset(output_data, reorder_shift, output->numel()); - dst_memory = std::make_shared( - dst_md, engine, phi::funcs::to_void_cast(output_data)); - } else { - T* output_data = output->mutable_data(ctx.GetPlace()); - dst_memory = std::make_shared( - dst_md, engine, phi::funcs::to_void_cast(output_data)); - } - - auto reorder_pd = - reorder::primitive_desc(*src_memory, *dst_memory, attri); - reorder_p = std::make_shared(reorder_pd); - - dev_ctx.SetBlob(key_prim, reorder_p); - dev_ctx.SetBlob(key_src_mem, src_memory); - dev_ctx.SetBlob(key_dst_mem, dst_memory); - } else { - src_memory = - std::static_pointer_cast(dev_ctx.GetBlob(key_src_mem)); - src_memory->set_data_handle(phi::funcs::to_void_cast(input_data)); - - dst_memory = - std::static_pointer_cast(dev_ctx.GetBlob(key_dst_mem)); - if (with_shift) { - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - uint8_t reorder_shift = - clip_to_uint8(shift_out - reorder_scale * shift_in); - std::memset(output_data, reorder_shift, output->numel()); - dst_memory->set_data_handle(output_data); - - } else { - T* output_data = output->mutable_data(ctx.GetPlace()); - dst_memory->set_data_handle(output_data); - } + dnnl::primitive_attr attrs; + int mask = 0; + float reorder_scale = scale_out / scale_in; + attrs.set_output_scales(mask, {reorder_scale}); + if (with_shift) { + uint8_t reorder_shift = + clip_to_uint8(shift_out - reorder_scale * shift_in); + attrs.set_zero_points( + DNNL_ARG_DST, mask, {static_cast(reorder_shift)}); } - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + phi::funcs::ReorderOneDNNHandler reorder_handler( + src_tz, + src_paddle_dt, + phi::funcs::ToOneDNNDataType(src_paddle_dt), + dst_paddle_dt, + phi::funcs::ToOneDNNDataType(dst_paddle_dt), + dev_ctx.GetEngine()); - reorder_p->execute(astream, *src_memory, *dst_memory); + auto src_memory_p = reorder_handler.AcquireSrcMemory( + input->mem_desc(), phi::funcs::to_void_cast(input->data())); + auto dst_memory_p = reorder_handler.AcquireDstMemory( + output, src_tz, vstrides, dev_ctx.GetPlace()); + + auto reorder_p = + reorder_handler.AcquireReorder(dst_memory_p, src_memory_p, attrs); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + reorder_p->execute(astream, *src_memory_p, *dst_memory_p); astream.wait(); - output->set_mem_desc(dst_memory->get_desc()); + output->set_mem_desc(dst_memory_p->get_desc()); } };