未验证 提交 abb0b2d6 编写于 作者: G Guanghua Yu 提交者: GitHub

[cherry-pick]Add progress bar and speed up Quantization Pass (#43454)

* Add progress bar and speed up Quantization Pass

* fix typo
上级 7e940b84
...@@ -17,6 +17,10 @@ import re ...@@ -17,6 +17,10 @@ import re
import logging import logging
import numpy as np import numpy as np
import shutil import shutil
try:
from tqdm import tqdm
except:
from .utils import tqdm
from inspect import isgeneratorfunction from inspect import isgeneratorfunction
from .... import io from .... import io
from .... import core from .... import core
...@@ -357,38 +361,40 @@ class PostTrainingQuantization(object): ...@@ -357,38 +361,40 @@ class PostTrainingQuantization(object):
self._set_activation_persistable() self._set_activation_persistable()
if self._algo in ["KL", "hist"]: if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0 batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
self._init_sampling_act_histogram()
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run(program=self._program, self._executor.run(program=self._program,
feed=data, feed=data,
fetch_list=self._fetch_list, fetch_list=self._fetch_list,
return_numpy=False, return_numpy=False,
scope=self._scope) scope=self._scope)
self._collect_activation_abs_min_max() self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1 batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums: if self._batch_nums and batch_id >= self._batch_nums:
break break
_logger.info("Finish preparation stage, all batch:" + str(batch_id))
self._init_sampling_act_histogram()
_logger.info("Sampling stage ...")
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))
if self._algo == 'avg': if self._algo == 'avg':
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
...@@ -823,8 +829,9 @@ class PostTrainingQuantization(object): ...@@ -823,8 +829,9 @@ class PostTrainingQuantization(object):
min_value = float(np.min(var_tensor)) min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor)) max_value = float(np.max(var_tensor))
if var_name not in self._sampling_act_abs_min_max: if var_name not in self._sampling_act_abs_min_max:
self._sampling_act_abs_min_max[ self._sampling_act_abs_min_max[var_name] = [
var_name] = [min_value, max_value] min_value, max_value
]
else: else:
if min_value < self._sampling_act_abs_min_max[var_name][0]: if min_value < self._sampling_act_abs_min_max[var_name][0]:
self._sampling_act_abs_min_max[var_name][0] = min_value self._sampling_act_abs_min_max[var_name][0] = min_value
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import numpy as np import numpy as np
from ....framework import IrNode from ....framework import IrNode
from ....framework import Operator from ....framework import Operator
...@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [ ...@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
"leaky_relu", "leaky_relu",
"tanh", "tanh",
"swish", "swish",
"scale",
"transpose", "transpose",
"transpose2", "transpose2",
"sigmoid", "sigmoid",
...@@ -162,7 +162,6 @@ _op_real_in_out_name = { ...@@ -162,7 +162,6 @@ _op_real_in_out_name = {
"sigmoid": [["X"], ["Out"]], "sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]], "elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]], "elementwise_pow": [["X", "Y"], ["Out"]],
"scale": [["X"], ["Out"]],
"hard_swish": [["X"], ["Out"]], "hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]], "hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]], "gru": [["Input", "Weight"], ["Hidden"]],
...@@ -414,3 +413,26 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): ...@@ -414,3 +413,26 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \ cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
/ (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten())) / (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
return cos_sim return cos_sim
class tqdm(object):
    """Minimal stand-in for the third-party ``tqdm`` progress bar.

    Imported as a fallback when the real ``tqdm`` package is unavailable.
    Supports only the subset of the API the quantization passes use:
    construction with ``total``/``bar_format``/``ncols``, ``update()``,
    and use as a context manager.  The bar is redrawn in place on stderr.
    """

    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
        # ``total`` may be None: callers pass ``self._batch_nums``, which
        # is optional (they guard loops with ``if self._batch_nums and ...``).
        # ``update()`` must therefore tolerate a missing total.
        self.total = total
        self.bar_format = bar_format
        self.ncols = ncols
        self.n = 0  # steps completed so far

    def update(self, n=1):
        """Advance the bar by ``n`` steps and redraw it on stderr."""
        self.n += n
        if self.total:
            # Clamp so overshooting ``total`` cannot render a bar wider
            # than ``ncols`` (the unclamped ratio could exceed 1.0).
            filled = min(round((self.n / self.total) * self.ncols),
                         self.ncols)
            total_fmt = self.total
        else:
            # Unknown total: draw an empty track instead of dividing by
            # None/zero, which would raise in the original implementation.
            filled = 0
            total_fmt = '?'
        bar = "=" * filled
        pad = " " * (self.ncols - filled)
        # Only the text before the first '|' of bar_format is used as the
        # line prefix; the '{bar}' placeholder itself is not substituted.
        prefix = self.bar_format.split('|')[0]
        sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, bar, pad,
                                                     self.n, total_fmt))
        sys.stderr.flush()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Terminate the in-place (\r-overwritten) line so later output
        # starts on a fresh line.
        sys.stderr.write('\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册