diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 9d560ddd2e039cfa01c01ae159cb9f7b95cb638a..d3c1fa7117f6fead0d24840a08d462e79251fac9 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -355,6 +355,13 @@ AnalysisPredictor::MkldnnQuantizer::Histogram(
   return std::make_pair(std::move(hist), std::move(bin_width));
 }
 
+void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  platform::MKLDNNDeviceContext* dev_ctx =
+      (platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
+  dev_ctx->ResetBlobMap();
+}
+
 void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   auto& arg = predictor_.argument_;
   if (!arg.scope_valid()) arg.SetScope(new framework::Scope);
@@ -380,6 +387,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
 bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
   if (!RunWarmup()) return false;
   if (!CalculateScales()) return false;
+  ClearDeviceContext();
   predictor_.PrepareScope(predictor_.scope_);
   predictor_.CreateExecutor();
   if (!RunQuantizePasses()) return false;
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.h b/paddle/fluid/inference/api/mkldnn_quantizer.h
index aea4a0ac93d253fe6b81fb726b8b19369dabd169..6c438265f0b8e2a65c0475f0b11064042549269e 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.h
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.h
@@ -68,6 +68,7 @@ class AnalysisPredictor::MkldnnQuantizer {
                             const framework::LoDTensor& var_tensor,
                             bool is_unsigned);
   void PrepareArgument() const;
+  void ClearDeviceContext() const;
   bool RunQuantizePasses() const;
 
   std::vector<int> ExpandQuantizedBins(std::vector<int> quantized_bins,
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 4f048d44685a88c3342de48dc6f364c950605be9..59ba3b63519625fc74fa1a37e5eec2e72e13995a 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -408,6 +408,8 @@ thread_local int cur_thread_id = 0;
 void set_cur_thread_id(int tid) { cur_thread_id = tid; }
 int get_cur_thread_id(void) { return cur_thread_id; }
 
+void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
+
 void MKLDNNDeviceContext::SetBlob(const std::string& name,
                                   std::shared_ptr<void> data) const {
   BlobMap* pMap = p_blobmap_.get();
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 812181563e6e55455a5c08a0ba1b7ca343ebf851..0da64aea4297d1b7df0b003d0fdae864d19102b0 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -391,6 +391,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   /* \brief Get the active engine */
   const mkldnn::engine& GetEngine() const { return engine_; }
 
+  // Remove all entries from the blob map
+  void ResetBlobMap() const;
+
   // Set data to blob (i.e. name/data pair). Create blob if not existing
   void SetBlob(const std::string& name,
                std::shared_ptr<void> data) const;
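
For context, below is a minimal standalone sketch of the name-to-blob cache that the new ResetBlobMap() empties. It is a simplified stand-in rather than the real MKLDNNDeviceContext: the BlobCache class and the main() walkthrough are hypothetical, while the SetBlob/GetBlob/ResetBlobMap signatures mirror the ones touched by this patch (the actual class additionally keys blobs per thread and guards the map).

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Simplified, hypothetical stand-in for MKLDNNDeviceContext's blob cache;
// only the three methods relevant to this patch are modeled.
class BlobCache {
 public:
  using BlobMap = std::map<std::string, std::shared_ptr<void>>;

  // Store a named blob, overwriting any previous entry (cf. SetBlob).
  void SetBlob(const std::string& name, std::shared_ptr<void> data) {
    blob_map_[name] = std::move(data);
  }

  // Look a blob up by name; returns nullptr when absent (cf. GetBlob).
  std::shared_ptr<void> GetBlob(const std::string& name) const {
    auto it = blob_map_.find(name);
    return it == blob_map_.end() ? nullptr : it->second;
  }

  // Drop every cached entry, mirroring the new ResetBlobMap().
  void ResetBlobMap() { blob_map_.clear(); }

  std::size_t size() const { return blob_map_.size(); }

 private:
  BlobMap blob_map_;
};

int main() {
  BlobCache cache;

  // The FP32 warmup run fills the cache with primitives keyed by op name.
  cache.SetBlob("conv1_fwd_primitive", std::make_shared<int>(42));
  std::cout << "after warmup: " << cache.size() << " cached blob(s)\n";

  // Clearing before the quantized run guarantees a lookup under the same
  // key misses instead of returning a stale FP32 primitive.
  cache.ResetBlobMap();
  std::cout << "cache hit after reset? "
            << (cache.GetBlob("conv1_fwd_primitive") != nullptr) << "\n";
  return 0;
}
```

Judging by where Quantize() now calls ClearDeviceContext() — after CalculateScales() but before the scope and executor are rebuilt — the reset appears intended to keep MKL-DNN primitives cached during the FP32 warmup run from being picked up by the quantized graph, whose operators would otherwise look up blobs under the same names.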