From 840969327271159ef46917508f6c50a004ba9e62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Gallus?=
Date: Thu, 27 Jun 2019 04:32:53 +0200
Subject: [PATCH] Reset DeviceContext after quantization warmup (#18182)

test=develop
---
 paddle/fluid/inference/api/mkldnn_quantizer.cc | 8 ++++++++
 paddle/fluid/inference/api/mkldnn_quantizer.h  | 1 +
 paddle/fluid/platform/device_context.cc        | 2 ++
 paddle/fluid/platform/device_context.h         | 3 +++
 4 files changed, 14 insertions(+)

diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 9d560ddd2e..d3c1fa7117 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -355,6 +355,13 @@ AnalysisPredictor::MkldnnQuantizer::Histogram(
   return std::make_pair(std::move(hist), std::move(bin_width));
 }
 
+void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  platform::MKLDNNDeviceContext* dev_ctx =
+      (platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
+  dev_ctx->ResetBlobMap();
+}
+
 void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   auto& arg = predictor_.argument_;
   if (!arg.scope_valid()) arg.SetScope(new framework::Scope);
@@ -380,6 +387,7 @@
 bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
   if (!RunWarmup()) return false;
   if (!CalculateScales()) return false;
+  ClearDeviceContext();
   predictor_.PrepareScope(predictor_.scope_);
   predictor_.CreateExecutor();
   if (!RunQuantizePasses()) return false;
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.h b/paddle/fluid/inference/api/mkldnn_quantizer.h
index aea4a0ac93..6c438265f0 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.h
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.h
@@ -68,6 +68,7 @@ class AnalysisPredictor::MkldnnQuantizer {
                              const framework::LoDTensor& var_tensor,
                              bool is_unsigned);
   void PrepareArgument() const;
+  void ClearDeviceContext() const;
   bool RunQuantizePasses() const;
 
   std::vector<int> ExpandQuantizedBins(std::vector<int> quantized_bins,
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 4f048d4468..59ba3b6351 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -408,6 +408,8 @@ thread_local int cur_thread_id = 0;
 void set_cur_thread_id(int tid) { cur_thread_id = tid; }
 int get_cur_thread_id(void) { return cur_thread_id; }
 
+void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
+
 void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                   std::shared_ptr<void> data) const {
   BlobMap* pMap = p_blobmap_.get();
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 812181563e..0da64aea42 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -391,6 +391,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   /* \brief Get the active engine */
   const mkldnn::engine& GetEngine() const { return engine_; }
 
+  // Remove all entries from the blob map
+  void ResetBlobMap() const;
+
   // Set data to blob (i.e. name/data pair). Create blob if not existing
   void SetBlob(const std::string& name,
                std::shared_ptr<void> data) const;
-- 
GitLab