提交 84096932 编写于 作者: M Michał Gallus 提交者: Tao Luo

Reset DeviceContext after quantization warmup (#18182)

test=develop
上级 b7128bac
...@@ -355,6 +355,13 @@ AnalysisPredictor::MkldnnQuantizer::Histogram( ...@@ -355,6 +355,13 @@ AnalysisPredictor::MkldnnQuantizer::Histogram(
return std::make_pair(std::move(hist), std::move(bin_width)); return std::make_pair(std::move(hist), std::move(bin_width));
} }
void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
dev_ctx->ResetBlobMap();
}
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
auto& arg = predictor_.argument_; auto& arg = predictor_.argument_;
if (!arg.scope_valid()) arg.SetScope(new framework::Scope); if (!arg.scope_valid()) arg.SetScope(new framework::Scope);
...@@ -380,6 +387,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { ...@@ -380,6 +387,7 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
bool AnalysisPredictor::MkldnnQuantizer::Quantize() { bool AnalysisPredictor::MkldnnQuantizer::Quantize() {
if (!RunWarmup()) return false; if (!RunWarmup()) return false;
if (!CalculateScales()) return false; if (!CalculateScales()) return false;
ClearDeviceContext();
predictor_.PrepareScope(predictor_.scope_); predictor_.PrepareScope(predictor_.scope_);
predictor_.CreateExecutor(); predictor_.CreateExecutor();
if (!RunQuantizePasses()) return false; if (!RunQuantizePasses()) return false;
......
...@@ -68,6 +68,7 @@ class AnalysisPredictor::MkldnnQuantizer { ...@@ -68,6 +68,7 @@ class AnalysisPredictor::MkldnnQuantizer {
const framework::LoDTensor& var_tensor, const framework::LoDTensor& var_tensor,
bool is_unsigned); bool is_unsigned);
void PrepareArgument() const; void PrepareArgument() const;
void ClearDeviceContext() const;
bool RunQuantizePasses() const; bool RunQuantizePasses() const;
std::vector<int> ExpandQuantizedBins(std::vector<int> quantized_bins, std::vector<int> ExpandQuantizedBins(std::vector<int> quantized_bins,
......
...@@ -408,6 +408,8 @@ thread_local int cur_thread_id = 0; ...@@ -408,6 +408,8 @@ thread_local int cur_thread_id = 0;
void set_cur_thread_id(int tid) { cur_thread_id = tid; } void set_cur_thread_id(int tid) { cur_thread_id = tid; }
int get_cur_thread_id(void) { return cur_thread_id; } int get_cur_thread_id(void) { return cur_thread_id; }
void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
void MKLDNNDeviceContext::SetBlob(const std::string& name, void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const { std::shared_ptr<void> data) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
......
...@@ -391,6 +391,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -391,6 +391,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
/* \brief Get the active engine */ /* \brief Get the active engine */
const mkldnn::engine& GetEngine() const { return engine_; } const mkldnn::engine& GetEngine() const { return engine_; }
// Remove all entries from the blob map
void ResetBlobMap() const;
// Set data to blob (i.e. name/data pair). Create blob if not existing // Set data to blob (i.e. name/data pair). Create blob if not existing
void SetBlob(const std::string& name, std::shared_ptr<void> data) const; void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册