Unverified commit 89069af5, authored by Guanghua Yu, committed by GitHub

Support quantization of condition block (#37498)

* Support sub graph quant-post
Parent 76c73226
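For context, a minimal sketch of driving PostTrainingQuantization on a model that contains a while op, mirroring the new test added at the bottom of this commit (the model directory and output path are placeholders, not values from the patch):

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="mnist_while",            # placeholder: inference model with a while op
    model_filename="model.pdmodel",
    params_filename="model.pdiparams",
    sample_generator=paddle.dataset.mnist.train(),
    batch_size=10,
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])
ptq.quantize()                          # multi-block programs are supported after this commit
ptq.save_quantized_model("post_training_model",   # placeholder output directory
                         model_filename="model.pdmodel",
                         params_filename="model.pdiparams")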
@@ -410,6 +410,23 @@ class PostTrainingQuantization(object):
                     for op_type in self._dynamic_quantize_op_type):
             self._collect_dynamic_quantize_op_threshold(
                 self._dynamic_quantize_op_type)
+
+        # Move sub blocks persistable var to global block
+        global_block = self._program.global_block()
+        for _op in global_block.ops:
+            if _op.type == "while":
+                _block_id = _op.attr("sub_block").id
+                _block = self._program.block(_block_id)
+                persistables = []
+                for _name, _var in _block.vars.items():
+                    if _var.persistable:
+                        global_block._clone_variable(_var)
+                        persistables.append(_name)
+                for _name in persistables:
+                    _block._remove_var(_name)
+                persistables.extend(_op.input('X'))
+                _op.desc.set_input("X", persistables)
+
         return self._program

     def save_quantized_model(self,
@@ -451,10 +468,6 @@ class PostTrainingQuantization(object):
             model_filename=self._model_filename,
             params_filename=self._params_filename)

-        if self._program.num_blocks > 1:
-            _logger.error("The post training quantization requires that the "
-                          "program only has one block.")
-
         if self._optimize_model:
             self._optimize_fp32_model()
@@ -505,15 +518,18 @@ class PostTrainingQuantization(object):
                     self._quantized_act_var_name.add(var_name)

         persistable_var_names = _all_persistable_var_names(self._program)
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if self._is_full_quantize and \
-                op_type not in self._quantizable_op_type:
-                _logger.warning(op_type + " is not supported for quantization.")
-            # For quantized ops, sample inputs and outputs
-            if op_type in self._quantizable_op_type:
-                collect_var_name(
-                    _get_op_input_var_names(op), persistable_var_names, op_type)
-                collect_var_name(
-                    _get_op_output_var_names(op), persistable_var_names,
-                    op_type)
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                op_type = op.type
+                if self._is_full_quantize and \
+                    op_type not in self._quantizable_op_type:
+                    _logger.warning(op_type +
+                                    " is not supported for quantization.")
+                # For quantized ops, sample inputs and outputs
+                if op_type in self._quantizable_op_type:
+                    collect_var_name(
+                        _get_op_input_var_names(op), persistable_var_names,
+                        op_type)
+                    collect_var_name(
+                        _get_op_output_var_names(op), persistable_var_names,
+                        op_type)
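The hunks below repeat the same change: every traversal that formerly read only self._program.global_block().ops now walks every block of the program, so ops inside while sub-blocks are visited as well. A minimal sketch of the pattern (iter_all_ops is a hypothetical helper, not part of the patch):

def iter_all_ops(program):
    # Block 0 is the global block; later blocks hold the bodies of
    # control-flow ops such as while and conditional_block.
    for block in program.blocks:
        for op in block.ops:
            yield op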
@@ -696,7 +712,8 @@ class PostTrainingQuantization(object):
         '''
         assert self._algo == "min_max", \
             "The algo should be min_max to save input threshold."
-        for op in self._program.global_block().ops:
-            if op.type in self._quantizable_op_type:
-                for var_name in _get_op_input_var_names(op):
-                    assert var_name in self._quantized_var_min
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                if op.type in self._quantizable_op_type:
+                    for var_name in _get_op_input_var_names(op):
+                        assert var_name in self._quantized_var_min
@@ -795,7 +812,12 @@ class PostTrainingQuantization(object):
             activation_quantize_type=self._activation_quantize_type,
             weight_quantize_type=self._weight_quantize_type,
             quantizable_op_type=major_quantizable_op_types)
-        transform_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            # The fake_quant/fake_dequantize ops must be inserted into a
+            # test graph, so set each sub graph's _for_test to True.
+            sub_graph._for_test = True
+            transform_pass.apply(sub_graph)

         # use AddQuantDequantPass to insert fake_quant_dequant op
         minor_quantizable_op_types = []
@@ -806,7 +828,10 @@ class PostTrainingQuantization(object):
             scope=self._scope,
             place=self._place,
             quantizable_op_type=minor_quantizable_op_types)
-        add_quant_dequant_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            add_quant_dequant_pass.apply(sub_graph)

         # save threshold to scale var node
         if self._algo in ["KL", "hist"]:
@@ -836,7 +861,11 @@ class PostTrainingQuantization(object):
             activation_bits=self._activation_bits,
             weight_quantize_type=self._weight_quantize_type,
             quantizable_op_type=major_quantizable_op_types)
-        freeze_pass.apply(graph)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            freeze_pass.apply(sub_graph)
+
         self._program = graph.to_program()

     def _save_output_threshold(self):
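For orientation, the three passes above now share the shape sketched here. This is a condensed view, not the literal control flow: in the real code, activation sampling happens between inserting the fake ops and freezing, and the pass objects are the ones constructed in the surrounding hunks. The graph is built with the usual IrGraph idiom:

from paddle.fluid import core
from paddle.fluid.framework import IrGraph

graph = IrGraph(core.Graph(program.desc), for_test=True)
for quant_pass in (transform_pass, add_quant_dequant_pass, freeze_pass):
    for sub_graph in graph.all_sub_graphs():
        sub_graph._for_test = True  # quant ops may only be inserted into test graphs
        quant_pass.apply(sub_graph)
program = graph.to_program()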
@@ -888,8 +917,10 @@ class PostTrainingQuantization(object):
                 save_info(op_node, out_var_name, self._quantized_var_max,
                           "out_max", "post_min_max")

-        for op in self._program.global_block().ops:
-            if op.type in (self._quantizable_op_type + self._out_scale_op_list):
-                out_var_names = _get_op_output_var_names(op)
-                assert len(out_var_names) == 1, "Post training " + \
-                    "quantization only support one output for " + op.type
+        for block_id in range(len(self._program.blocks)):
+            for op in self._program.blocks[block_id].ops:
+                if op.type in (
+                        self._quantizable_op_type + self._out_scale_op_list):
+                    out_var_names = _get_op_output_var_names(op)
+                    assert len(out_var_names) == 1, "Post training " + \
+                        "quantization only support one output for " + op.type
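The slim tests CMakeLists.txt registers the new test, excluding it on Windows and giving it a 120-second timeout: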
@@ -139,6 +139,7 @@ endfunction()
 if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist)
+    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_while)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model)
@@ -336,6 +337,7 @@ if(NOT WIN32)
     set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
     set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
     set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT 120)
     set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120)
     set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT 120)
 endif()
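The patch also adds the new test file, test_post_training_quantization_while.py, which downloads an MNIST model containing a while op and verifies each calibration algorithm end to end: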
# copyright (c) 2021 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import unittest
import os
import time
import sys
import random
import math
import functools
import contextlib
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.dataset.common import download
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
paddle.enable_static()

random.seed(0)
np.random.seed(0)


class TestPostTrainingQuantization(unittest.TestCase):
    def setUp(self):
        self.download_path = 'int8/download'
        self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
                                               self.download_path)
        self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        self.int8_model_path = os.path.join(os.getcwd(),
                                            "post_training_" + self.timestamp)
        try:
            os.system("mkdir -p " + self.int8_model_path)
        except Exception as e:
            print("Failed to create {} due to {}".format(self.int8_model_path,
                                                         str(e)))
            sys.exit(-1)

    def tearDown(self):
        try:
            os.system("rm -rf {}".format(self.int8_model_path))
        except Exception as e:
            print("Failed to delete {} due to {}".format(self.int8_model_path,
                                                         str(e)))

    def cache_unzipping(self, target_folder, zip_path):
        cmd = 'tar xf {0} -C {1}'.format(zip_path, target_folder)
        os.system(cmd)

    def download_model(self, data_url, data_md5, folder_name):
        download(data_url, self.download_path, data_md5)
        file_name = data_url.split('/')[-1]
        zip_path = os.path.join(self.cache_folder, file_name)
        print('Data is downloaded at {0}'.format(zip_path))

        data_cache_folder = os.path.join(self.cache_folder, folder_name)
        self.cache_unzipping(self.cache_folder, zip_path)
        return data_cache_folder

    def run_program(self, model_path, batch_size, infer_iterations):
        print("test model path:" + model_path)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        [infer_program, feed_dict, fetch_targets] = \
            fluid.io.load_inference_model(model_path,
                                          model_filename='model.pdmodel',
                                          params_filename='model.pdiparams',
                                          executor=exe)
        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size)

        img_shape = [1, 28, 28]
        test_info = []
        cnt = 0
        periods = []
        for batch_id, data in enumerate(val_reader()):
            image = np.array(
                [x[0].reshape(img_shape) for x in data]).astype("float32")
            input_label = np.array([x[1] for x in data]).astype("int64")

            t1 = time.time()
            out = exe.run(infer_program,
                          feed={feed_dict[0]: image},
                          fetch_list=fetch_targets)
            t2 = time.time()
            period = t2 - t1
            periods.append(period)

            out_label = np.argmax(np.array(out[0]), axis=1)
            top1_num = sum(input_label == out_label)
            test_info.append(top1_num)
            cnt += len(data)

            if (batch_id + 1) == infer_iterations:
                break

        throughput = cnt / np.sum(periods)
        latency = np.average(periods)
        acc1 = np.sum(test_info) / cnt
        return (throughput, latency, acc1)
    def generate_quantized_model(self,
                                 model_path,
                                 algo="KL",
                                 quantizable_op_type=["conv2d"],
                                 is_full_quantize=False,
                                 is_use_cache_file=False,
                                 is_optimize_model=False,
                                 batch_size=10,
                                 batch_nums=10):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.global_scope()
        val_reader = paddle.dataset.mnist.train()

        ptq = PostTrainingQuantization(
            executor=exe,
            model_dir=model_path,
            model_filename='model.pdmodel',
            params_filename='model.pdiparams',
            sample_generator=val_reader,
            batch_size=batch_size,
            batch_nums=batch_nums,
            algo=algo,
            quantizable_op_type=quantizable_op_type,
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
            is_use_cache_file=is_use_cache_file)
        ptq.quantize()
        ptq.save_quantized_model(
            self.int8_model_path,
            model_filename='model.pdmodel',
            params_filename='model.pdiparams')
    def run_test(self,
                 model_name,
                 data_url,
                 data_md5,
                 algo,
                 quantizable_op_type,
                 is_full_quantize,
                 is_use_cache_file,
                 is_optimize_model,
                 diff_threshold,
                 batch_size=10,
                 infer_iterations=10,
                 quant_iterations=5):
        origin_model_path = self.download_model(data_url, data_md5, model_name)
        #origin_model_path = os.path.join(origin_model_path, model_name)

        print("Start FP32 inference for {0} on {1} images ...".format(
            model_name, infer_iterations * batch_size))
        (fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
            origin_model_path, batch_size, infer_iterations)

        print("Start INT8 post training quantization for {0} on {1} images ...".
              format(model_name, quant_iterations * batch_size))
        self.generate_quantized_model(
            origin_model_path, algo, quantizable_op_type, is_full_quantize,
            is_use_cache_file, is_optimize_model, batch_size, quant_iterations)

        print("Start INT8 inference for {0} on {1} images ...".format(
            model_name, infer_iterations * batch_size))
        (int8_throughput, int8_latency, int8_acc1) = self.run_program(
            self.int8_model_path, batch_size, infer_iterations)

        print("---Post training quantization of {} method---".format(algo))
        print(
            "FP32 {0}: batch_size {1}, throughput {2} img/s, latency {3} s, acc1 {4}.".
            format(model_name, batch_size, fp32_throughput, fp32_latency,
                   fp32_acc1))
        print(
            "INT8 {0}: batch_size {1}, throughput {2} img/s, latency {3} s, acc1 {4}.\n".
            format(model_name, batch_size, int8_throughput, int8_latency,
                   int8_acc1))
        sys.stdout.flush()

        delta_value = fp32_acc1 - int8_acc1
        self.assertLess(delta_value, diff_threshold)


class TestPostTrainingKLForWhile(TestPostTrainingQuantization):
    def test_post_training_kl(self):
        model_name = "mnist_while"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
        data_md5 = "2387390beeb37b51dec041c27b8a681f"
        algo = "KL"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


class TestPostTraininghistForWhile(TestPostTrainingQuantization):
    def test_post_training_hist(self):
        model_name = "mnist_while"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
        data_md5 = "2387390beeb37b51dec041c27b8a681f"
        algo = "hist"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


class TestPostTrainingmseForWhile(TestPostTrainingQuantization):
    def test_post_training_mse(self):
        model_name = "mnist_while"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
        data_md5 = "2387390beeb37b51dec041c27b8a681f"
        algo = "mse"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


class TestPostTrainingavgForWhile(TestPostTrainingQuantization):
    def test_post_training_avg(self):
        model_name = "mnist_while"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
        data_md5 = "2387390beeb37b51dec041c27b8a681f"
        algo = "avg"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


class TestPostTrainingMinMaxForWhile(TestPostTrainingQuantization):
    def test_post_training_min_max(self):
        model_name = "mnist_while"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
        data_md5 = "2387390beeb37b51dec041c27b8a681f"
        algo = "min_max"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


class TestPostTrainingAbsMaxForWhile(TestPostTrainingQuantization):
    def test_post_training_abs_max(self):
        model_name = "mnist_while"
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
        data_md5 = "2387390beeb37b51dec041c27b8a681f"
        algo = "abs_max"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.01
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, batch_size, infer_iterations,
                      quant_iterations)


if __name__ == '__main__':
    unittest.main()