未验证 提交 bf123166 编写于 作者: G Guanghua Yu 提交者: GitHub

add post_quant CE (#961)

上级 06175aea
...@@ -20,6 +20,7 @@ import contextlib ...@@ -20,6 +20,7 @@ import contextlib
import os import os
import time import time
import math import math
import random
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -28,7 +29,6 @@ from paddle.io import Dataset ...@@ -28,7 +29,6 @@ from paddle.io import Dataset
from paddle.vision.transforms import transforms from paddle.vision.transforms import transforms
import paddle.vision.models as models import paddle.vision.models as models
import paddle.nn as nn import paddle.nn as nn
from paddleslim import PTQ from paddleslim import PTQ
import sys import sys
...@@ -41,7 +41,6 @@ from models.dygraph.mobilenet_v3 import MobileNetV3_large_x1_0 ...@@ -41,7 +41,6 @@ from models.dygraph.mobilenet_v3 import MobileNetV3_large_x1_0
class ImageNetValDataset(Dataset): class ImageNetValDataset(Dataset):
def __init__(self, data_dir, image_size=224, resize_short_size=256): def __init__(self, data_dir, image_size=224, resize_short_size=256):
super(ImageNetValDataset, self).__init__() super(ImageNetValDataset, self).__init__()
train_file_list = os.path.join(data_dir, 'train_list.txt')
val_file_list = os.path.join(data_dir, 'val_list.txt') val_file_list = os.path.join(data_dir, 'val_list.txt')
test_file_list = os.path.join(data_dir, 'test_list.txt') test_file_list = os.path.join(data_dir, 'test_list.txt')
self.data_dir = data_dir self.data_dir = data_dir
...@@ -68,9 +67,9 @@ class ImageNetValDataset(Dataset): ...@@ -68,9 +67,9 @@ class ImageNetValDataset(Dataset):
return len(self.data) return len(self.data)
def calibrate(model, dataset, batch_num, batch_size): def calibrate(model, dataset, batch_num, batch_size, num_workers=5):
data_loader = paddle.io.DataLoader( data_loader = paddle.io.DataLoader(
dataset, batch_size=batch_size, num_workers=5) dataset, batch_size=batch_size, num_workers=num_workers)
for idx, data in enumerate(data_loader()): for idx, data in enumerate(data_loader()):
img = data[0] img = data[0]
...@@ -85,6 +84,15 @@ def calibrate(model, dataset, batch_num, batch_size): ...@@ -85,6 +84,15 @@ def calibrate(model, dataset, batch_num, batch_size):
def main(): def main():
num_workers = 5
if FLAGS.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
# 1 load model # 1 load model
model_list = [x for x in models.__dict__["__all__"]] model_list = [x for x in models.__dict__["__all__"]]
model_list.append('mobilenet_v3') model_list.append('mobilenet_v3')
...@@ -126,8 +134,12 @@ def main(): ...@@ -126,8 +134,12 @@ def main():
quant_model = ptq.quantize(fp32_model, fuse=FLAGS.fuse, fuse_list=fuse_list) quant_model = ptq.quantize(fp32_model, fuse=FLAGS.fuse, fuse_list=fuse_list)
print("Start calibrate...") print("Start calibrate...")
calibrate(quant_model, val_dataset, FLAGS.quant_batch_num, calibrate(
FLAGS.quant_batch_size) quant_model,
val_dataset,
FLAGS.quant_batch_num,
FLAGS.quant_batch_size,
num_workers=num_workers)
# 3 save # 3 save
quant_output_dir = os.path.join(FLAGS.output_dir, FLAGS.model, "int8_infer", quant_output_dir = os.path.join(FLAGS.output_dir, FLAGS.model, "int8_infer",
...@@ -172,6 +184,8 @@ if __name__ == '__main__': ...@@ -172,6 +184,8 @@ if __name__ == '__main__':
"--quant_batch_num", default=10, type=int, help="batch num for quant") "--quant_batch_num", default=10, type=int, help="batch num for quant")
parser.add_argument( parser.add_argument(
"--quant_batch_size", default=10, type=int, help="batch size for quant") "--quant_batch_size", default=10, type=int, help="batch size for quant")
parser.add_argument(
'--ce_test', default=False, type=bool, help="Whether to CE test.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path" assert FLAGS.data, "error: must provide data path"
......
...@@ -10,6 +10,7 @@ import paddle ...@@ -10,6 +10,7 @@ import paddle
import six import six
import reader import reader
from net import skip_gram_word2vec from net import skip_gram_word2vec
import paddle
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("paddle") logger = logging.getLogger("paddle")
...@@ -77,6 +78,12 @@ def parse_args(): ...@@ -77,6 +78,12 @@ def parse_args():
required=False, required=False,
default=False, default=False,
help='print speed or not , (default: False)') help='print speed or not , (default: False)')
parser.add_argument(
'--ce_test',
required=False,
default=False,
help='Whether to CE test, (default: False)')
return parser.parse_args() return parser.parse_args()
...@@ -185,6 +192,12 @@ def GetFileList(data_path): ...@@ -185,6 +192,12 @@ def GetFileList(data_path):
def train(args): def train(args):
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
if not os.path.isdir(args.model_output_dir): if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir) os.mkdir(args.model_output_dir)
......
...@@ -6,7 +6,9 @@ import argparse ...@@ -6,7 +6,9 @@ import argparse
import functools import functools
import math import math
import time import time
import random
import numpy as np import numpy as np
import paddle
sys.path[0] = os.path.join( sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir) os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
...@@ -29,12 +31,13 @@ add_arg('params_filename', str, None, "params file name") ...@@ -29,12 +31,13 @@ add_arg('params_filename', str, None, "params file name")
add_arg('algo', str, 'hist', "calibration algorithm") add_arg('algo', str, 'hist', "calibration algorithm")
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist") add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
add_arg('bias_correction', bool, False, "Whether to use bias correction") add_arg('bias_correction', bool, False, "Whether to use bias correction")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable # yapf: enable
def quantize(args): def quantize(args):
val_reader = reader.train() val_reader = reader.val()
place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace() place = paddle.CUDAPlace(0) if args.use_gpu else paddle.CPUPlace()
...@@ -59,6 +62,12 @@ def quantize(args): ...@@ -59,6 +62,12 @@ def quantize(args):
def main(): def main():
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
if args.ce_test:
# set seed
seed = 111
np.random.seed(seed)
paddle.seed(seed)
random.seed(seed)
quantize(args) quantize(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册