From 5c2b9258a615ea09d6a86f030515856690074c8e Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Sun, 23 Aug 2020 15:08:47 +0200 Subject: [PATCH] Fix (de/re)quantize cache keys (#26549) --- paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc | 10 +++++----- paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc | 11 ++++++----- paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc | 11 ++++++----- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc index 86c1c323264..540642c7140 100644 --- a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc @@ -51,11 +51,11 @@ class DeQuantOpKernel : public framework::OpKernel { mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); MKLDNNMemoryFormat src_fmt = input->format(); - std::string key = - platform::CreateKey(src_dt, src_tz, ctx.OutputName("Output")); - const std::string key_prim = key + "@reorder_p"; - const std::string key_src_mem = key + "@src_mem"; - const std::string key_dst_mem = key + "@dst_mem"; + std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt, + src_tz, ctx.OutputName("Output")); + const std::string key_prim = key + "@r"; + const std::string key_src_mem = key + "@s"; + const std::string key_dst_mem = key + "@d"; std::shared_ptr src_memory; std::shared_ptr dst_memory; diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 55bd683f8f4..29a86a35d7b 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -48,11 +48,12 @@ class QuantOpKernel : public framework::OpKernel { const T* input_data = input->data(); bool is_negative = ctx.Attr("is_negative_input"); - std::string key = platform::CreateKey(src_tz, scale_data, is_negative, - ctx.OutputName("Output")); - const std::string key_prim = key + "@reorder_p"; - const std::string key_src_mem = key + "@src_mem"; - const std::string key_dst_mem = key + "@dst_mem"; + std::string key = + platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data, + is_negative, ctx.OutputName("Output")); + const std::string key_prim = key + "@r"; + const std::string key_src_mem = key + "@s"; + const std::string key_dst_mem = key + "@d"; std::shared_ptr src_memory; std::shared_ptr dst_memory; diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index 92e7744e3c0..5ad5ad94505 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -40,11 +40,12 @@ class ReQuantOpKernel : public framework::OpKernel { auto src_tz = paddle::framework::vectorize(input->dims()); - std::string key = platform::CreateKey(src_tz, scale_in, scale_out, - ctx.OutputName("Output")); - const std::string key_prim = key + "@reorder_p"; - const std::string key_src_mem = key + "@src_mem"; - const std::string key_dst_mem = key + "@dst_mem"; + std::string key = + platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in, + scale_out, ctx.OutputName("Output")); + const std::string key_prim = key + "@r"; + const std::string key_src_mem = key + "@s"; + const std::string key_dst_mem = key + "@d"; std::shared_ptr src_memory; std::shared_ptr dst_memory; -- GitLab