diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 04cd60be964a3967a45e73122324c4b3fdf0b3d0..11c2b83d6814ba5e926be68081cf64d0f726395a 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -85,11 +85,11 @@ class QuantOpKernel : public framework::OpKernel { std::shared_ptr dst_pd; if (is_negative) { - platform::ConvMKLDNNHandler::SetDstMemory( - ctx, output, dst_tz, engine, dst_pd, dst_memory); + platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, + dst_pd, dst_memory); } else { - platform::ConvMKLDNNHandler::SetDstMemory( - ctx, output, dst_tz, engine, dst_pd, dst_memory); + platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, + dst_pd, dst_memory); } auto reorder_pd = std::shared_ptr( new reorder::primitive_desc(src_pd, *dst_pd, attri)); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index de53abf9e031021287bcbd0b23fc2e380623d0c9..e7a7fa2ca36071033a2338aa51d2744e0f6de707 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -27,6 +27,7 @@ namespace paddle { namespace platform { using user_function = std::function(const float*)>; +using memory = mkldnn::memory; class MKLDNNHandler { public: @@ -196,21 +197,6 @@ class MKLDNNHandler { return dims2str(operand_dims) + suffix; } - template - static void SetDstMemory( - const framework::ExecutionContext& ctx, framework::Tensor* output, - std::vector dst_tz, const mkldnn::engine& engine, - std::shared_ptr& dst_pd, // NOLINT - std::shared_ptr& dst_memory) { // NOLINT - T* output_data = output->mutable_data(ctx.GetPlace()); - auto dst_md = platform::MKLDNNMemDesc( - {dst_tz}, paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType), - mkldnn::memory::format::nhwc); - dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine)); - dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast(output_data))); - } - static void AppendKey( std::string* key, const mkldnn::memory::dims& input_dims, const mkldnn::memory::dims& weights_dims, const std::vector& strides, @@ -915,5 +901,26 @@ static void SetDstMemoryHandler( (*dst_memory_p)->set_data_handle(to_void_cast(output_data)); } +template +static void SetDstMemoryQuantized( + const framework::ExecutionContext& ctx, framework::Tensor* output, + std::vector dst_tz, const mkldnn::engine& engine, + std::shared_ptr& dst_pd, // NOLINT + std::shared_ptr& dst_memory) { // NOLINT + T* output_data = output->mutable_data(ctx.GetPlace()); + const size_t dst_dims = dst_tz.size(); + memory::format dst_fmt; + PADDLE_ENFORCE(dst_dims <= 5, + "Dst memory for quantization can not have dims > 5"); + dst_fmt = platform::MKLDNNFormatForSize(dst_dims, memory::format::nhwc); + + auto dst_md = platform::MKLDNNMemDesc( + {dst_tz}, paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType), + dst_fmt); + dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine)); + dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast(output_data))); +} + } // namespace platform } // namespace paddle