From d83d0f33fdea50d28f1f4bd22eb5f4988228d43b Mon Sep 17 00:00:00 2001 From: "xiaoli.liu@intel.com" Date: Mon, 24 Dec 2018 16:09:09 +0800 Subject: [PATCH] extract templated function test=develop --- paddle/fluid/operators/quantize_mkldnn_op.cc | 25 ++++++-------------- paddle/fluid/platform/mkldnn_reuse.h | 16 +++++++++++++ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/quantize_mkldnn_op.cc b/paddle/fluid/operators/quantize_mkldnn_op.cc index 400ba383e7..0638e42873 100644 --- a/paddle/fluid/operators/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/quantize_mkldnn_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/quantize_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -59,37 +60,25 @@ class QuantOpKernel : public framework::OpKernel { std::shared_ptr(new primitive::at(*src_memory)); bool is_negative = ctx.Attr("is_negative_input"); - mkldnn::memory::primitive_desc dst_pd; + std::shared_ptr dst_pd; std::shared_ptr dst_memory; if (is_negative) { - int8_t* output_data = output->mutable_data(ctx.GetPlace()); - auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::s8, - memory::format::nhwc); - dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); - dst_memory.reset( - new mkldnn::memory(dst_pd, to_void_cast(output_data))); + platform::ConvMKLDNNHandler::SetDstMemory( + ctx, output, dst_tz, engine, dst_pd, dst_memory); } else { - uint8_t* output_data = output->mutable_data(ctx.GetPlace()); - auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::u8, - memory::format::nhwc); - dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); - dst_memory.reset( - new mkldnn::memory(dst_pd, to_void_cast(output_data))); + platform::ConvMKLDNNHandler::SetDstMemory( + ctx, output, dst_tz, engine, dst_pd, dst_memory); } - auto reorder_pd = std::shared_ptr( - new reorder::primitive_desc(src_pd, dst_pd, attri)); + new reorder::primitive_desc(src_pd, *dst_pd, attri)); auto reorder_p = std::shared_ptr( new reorder(*reorder_pd, *src_memory_p, *dst_memory)); - pipeline.push_back(*reorder_p); stream(stream::kind::eager).submit(pipeline).wait(); - output->set_layout(DataLayout::kMKLDNN); output->set_format(GetMKLDNNFormat(*dst_memory)); } }; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 1c6421f3fa..febd776a27 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/place.h" @@ -181,6 +182,21 @@ class MKLDNNHandler { return dims2str(operand_dims) + suffix; } + template + static void SetDstMemory( + const framework::ExecutionContext& ctx, framework::Tensor* output, + std::vector dst_tz, const mkldnn::engine& engine, + std::shared_ptr& dst_pd, // NOLINT + std::shared_ptr& dst_memory) { // NOLINT + M* output_data = output->mutable_data(ctx.GetPlace()); + auto dst_md = platform::MKLDNNMemDesc( + {dst_tz}, paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType), + mkldnn::memory::format::nhwc); + dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine)); + dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast(output_data))); + } + protected: static std::string dims2str(const mkldnn::memory::dims& operand_dims) { std::string dstr = ""; -- GitLab