Unverified commit abb0b2d6, authored by Guanghua Yu, committed via GitHub

[cherry-pick]Add progress bar and speed up Quantization Pass (#43454)

* Add progress bar and speed up Quantization Pass

* fix typo
Parent commit: 7e940b84
......@@ -17,6 +17,10 @@ import re
import logging
import numpy as np
import shutil
try:
from tqdm import tqdm
except:
from .utils import tqdm
from inspect import isgeneratorfunction
from .... import io
from .... import core
......@@ -357,38 +361,40 @@ class PostTrainingQuantization(object):
self._set_activation_persistable()
if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
self._init_sampling_act_histogram()
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
self._sampling()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish preparation stage, all batch:" + str(batch_id))
self._init_sampling_act_histogram()
_logger.info("Sampling stage ...")
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
......@@ -823,8 +829,9 @@ class PostTrainingQuantization(object):
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if var_name not in self._sampling_act_abs_min_max:
self._sampling_act_abs_min_max[
var_name] = [min_value, max_value]
self._sampling_act_abs_min_max[var_name] = [
min_value, max_value
]
else:
if min_value < self._sampling_act_abs_min_max[var_name][0]:
self._sampling_act_abs_min_max[var_name][0] = min_value
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from ....framework import IrNode
from ....framework import Operator
......@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
"leaky_relu",
"tanh",
"swish",
"scale",
"transpose",
"transpose2",
"sigmoid",
......@@ -162,7 +162,6 @@ _op_real_in_out_name = {
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"scale": [["X"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
def calculate_quant_cos_error(orig_tensor, qdq_tensor):
    """Return the cosine similarity between an original tensor and its
    quantize-dequantize counterpart.

    A value close to 1.0 means quantization introduced little error.
    """
    flat_orig = orig_tensor.flatten()
    flat_qdq = qdq_tensor.flatten()
    norm_product = np.linalg.norm(flat_orig) * np.linalg.norm(flat_qdq)
    cos_sim = np.inner(flat_orig, flat_qdq) / norm_product
    return cos_sim
class tqdm(object):
    """Minimal stand-in for ``tqdm.tqdm``, used when the real package is
    not installed (see the ``try: from tqdm import tqdm`` fallback at the
    top of the quantization module).

    Renders a plain-text progress bar on stderr and supports the small
    subset of the tqdm API the quantization passes use: construction with
    ``total``/``bar_format``/``ncols``, ``update()``, and use as a
    context manager.
    """

    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
        # NOTE: callers pass total=self._batch_nums, which may be None or 0
        # when the batch count is unknown (the loops guard with
        # `if self._batch_nums and ...`), so update() must not assume a
        # positive numeric total.
        self.total = total
        self.bar_format = bar_format
        self.ncols = ncols
        self.n = 0  # number of completed steps

    def update(self, n=1):
        """Advance the counter by ``n`` and redraw the bar on stderr."""
        self.n += n
        if self.total:
            # Clamp so the bar never overflows ncols when n exceeds total.
            filled = min(round((self.n / self.total) * self.ncols),
                         self.ncols)
        else:
            # Unknown total: keep the counter moving with an empty bar
            # instead of dividing by None/0.
            filled = 0
        bar = "=" * filled
        pad = " " * (self.ncols - filled)
        # Everything before the first '|' of bar_format is the label.
        prefix = self.bar_format.split('|')[0]
        sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, bar, pad,
                                                     self.n, self.total))
        sys.stderr.flush()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Move past the in-place bar line; exceptions are not suppressed.
        sys.stderr.write('\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册