未验证 提交 bf123166 编写于 作者: G Guanghua Yu 提交者: GitHub

add post_quant CE (#961)

上级 06175aea
......@@ -20,6 +20,7 @@ import contextlib
import os
import time
import math
import random
import numpy as np
from PIL import Image
......@@ -28,7 +29,6 @@ from paddle.io import Dataset
from paddle.vision.transforms import transforms
import paddle.vision.models as models
import paddle.nn as nn
from paddleslim import PTQ
import sys
......@@ -41,7 +41,6 @@ from models.dygraph.mobilenet_v3 import MobileNetV3_large_x1_0
class ImageNetValDataset(Dataset):
def __init__(self, data_dir, image_size=224, resize_short_size=256):
super(ImageNetValDataset, self).__init__()
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt')
self.data_dir = data_dir
......@@ -68,9 +67,9 @@ class ImageNetValDataset(Dataset):
return len(self.data)
def calibrate(model, dataset, batch_num, batch_size):
def calibrate(model, dataset, batch_num, batch_size, num_workers=5):
data_loader = paddle.io.DataLoader(
dataset, batch_size=batch_size, num_workers=5)
dataset, batch_size=batch_size, num_workers=num_workers)
for idx, data in enumerate(data_loader()):
img = data[0]
......@@ -85,6 +84,15 @@ def calibrate(model, dataset, batch_num, batch_size):
def main():
num_workers = 5
if FLAGS.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
# 1 load model
model_list = [x for x in models.__dict__["__all__"]]
model_list.append('mobilenet_v3')
......@@ -126,8 +134,12 @@ def main():
quant_model = ptq.quantize(fp32_model, fuse=FLAGS.fuse, fuse_list=fuse_list)
print("Start calibrate...")
calibrate(quant_model, val_dataset, FLAGS.quant_batch_num,
FLAGS.quant_batch_size)
calibrate(
quant_model,
val_dataset,
FLAGS.quant_batch_num,
FLAGS.quant_batch_size,
num_workers=num_workers)
# 3 save
quant_output_dir = os.path.join(FLAGS.output_dir, FLAGS.model, "int8_infer",
......@@ -172,6 +184,8 @@ if __name__ == '__main__':
"--quant_batch_num", default=10, type=int, help="batch num for quant")
parser.add_argument(
"--quant_batch_size", default=10, type=int, help="batch size for quant")
parser.add_argument(
'--ce_test', default=False, type=bool, help="Whether to CE test.")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
......
......@@ -10,6 +10,7 @@ import paddle
import six
import reader
from net import skip_gram_word2vec
import paddle
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("paddle")
......@@ -77,6 +78,12 @@ def parse_args():
required=False,
default=False,
help='print speed or not , (default: False)')
parser.add_argument(
'--ce_test',
required=False,
default=False,
help='Whether to CE test, (default: False)')
return parser.parse_args()
......@@ -185,6 +192,12 @@ def GetFileList(data_path):
def train(args):
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
......
......@@ -6,7 +6,9 @@ import argparse
import functools
import math
import time
import random
import numpy as np
import paddle
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
......@@ -29,12 +31,13 @@ add_arg('params_filename', str, None, "params file name")
add_arg('algo', str, 'hist', "calibration algorithm")
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
add_arg('bias_correction', bool, False, "Whether to use bias correction")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable
def quantize(args):
val_reader = reader.train()
val_reader = reader.val()
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
......@@ -59,6 +62,12 @@ def quantize(args):
def main():
args = parser.parse_args()
print_arguments(args)
if args.ce_test:
# set seed
seed = 111
np.random.seed(seed)
paddle.seed(seed)
random.seed(seed)
quantize(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册