Commit d3ef4d2e authored by: chenguowei01

fix quant online program

Parent 1a8bd484
@@ -242,30 +242,11 @@ class SegModel(object):
         if self.status == 'Normal':
             fluid.save(self.train_prog, osp.join(save_dir, 'model'))
             model_info['status'] = 'Normal'
         elif self.status == 'Quant':
-            float_prog, _ = slim.quant.convert(
-                self.test_prog, self.exe.place, save_int8=True)
-            test_input_names = [
-                var.name for var in list(self.test_inputs.values())
-            ]
-            test_outputs = list(self.test_outputs.values())
-            fluid.io.save_inference_model(
-                dirname=save_dir,
-                executor=self.exe,
-                params_filename='__params__',
-                feeded_var_names=test_input_names,
-                target_vars=test_outputs,
-                main_program=float_prog)
+            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
+            model_info['status'] = 'QuantOnline'
-        model_info['_ModelInputsOutputs'] = dict()
-        model_info['_ModelInputsOutputs']['test_inputs'] = [
-            [k, v.name] for k, v in self.test_inputs.items()
-        ]
-        model_info['_ModelInputsOutputs']['test_outputs'] = [
-            [k, v.name] for k, v in self.test_outputs.items()
-        ]
-        model_info['status'] = self.status
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
                 mode='w') as f:
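This first hunk changes what a quant-aware ('Quant' status) checkpoint means: instead of freezing the instrumented program into an inference model at save time via `slim.quant.convert`, `save_model` now checkpoints `test_prog` as a plain program and records the status `QuantOnline`, deferring conversion to `export_quant_model`. A minimal sketch of the new branch, assuming a model object with the `status`/`test_prog`/`train_prog` fields used in the diff (the helper name is hypothetical):

```python
import os.path as osp
import paddle.fluid as fluid

def save_checkpoint(model, save_dir, model_info):
    # Hypothetical helper mirroring the new save_model branch:
    # quant-aware programs are checkpointed unconverted, so training can
    # resume from them and conversion happens once, at export time.
    if model.status == 'Normal':
        fluid.save(model.train_prog, osp.join(save_dir, 'model'))
        model_info['status'] = 'Normal'
    elif model.status == 'Quant':
        # The fake-quant ops stay in the graph; only the status changes.
        fluid.save(model.test_prog, osp.join(save_dir, 'model'))
        model_info['status'] = 'QuantOnline'
```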
@@ -307,11 +288,13 @@ class SegModel(object):
         logging.info("Model for inference deploy saved in {}.".format(save_dir))
 
     def export_quant_model(self,
-                           dataset,
-                           save_dir,
+                           dataset=None,
+                           save_dir=None,
                            batch_size=1,
                            batch_nums=10,
-                           cache_dir="./.temp"):
+                           cache_dir="./.temp",
+                           quant_type="offline"):
+        if quant_type == "offline":
             self.arrange_transform(transforms=dataset.transforms, mode='quant')
             dataset.num_samples = batch_size * batch_nums
             try:
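`export_quant_model` now dispatches on `quant_type`: the offline (post-training) path calibrates on `batch_size * batch_nums` samples drawn from `dataset`, while the online (quant-aware) path needs no data at all, which is why `dataset` and `save_dir` became keyword arguments defaulting to `None`. A sketch of the dispatch, with a defensive check the diff itself does not add:

```python
def export_quant_model(self,
                       dataset=None,
                       save_dir=None,
                       batch_size=1,
                       batch_nums=10,
                       cache_dir="./.temp",
                       quant_type="offline"):
    # Offline quantization estimates activation ranges from real data,
    # so a calibration dataset is mandatory for that path only.
    if quant_type == "offline" and dataset is None:  # added guard, not in the diff
        raise ValueError("offline quantization requires a calibration dataset")
    ...
```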
@@ -340,7 +323,22 @@ class SegModel(object):
             post_training_quantization.quantize()
             post_training_quantization.save_quantized_model(save_dir)
             if cache_dir is not None:
-                os.system('rm -r' + cache_dir)
+                os.system('rm -r ' + cache_dir)
+        else:
+            float_prog, _ = slim.quant.convert(
+                self.test_prog, self.exe.place, save_int8=True)
+            test_input_names = [
+                var.name for var in list(self.test_inputs.values())
+            ]
+            test_outputs = list(self.test_outputs.values())
+            fluid.io.save_inference_model(
+                dirname=save_dir,
+                executor=self.exe,
+                params_filename='__params__',
+                feeded_var_names=test_input_names,
+                target_vars=test_outputs,
+                main_program=float_prog)
 
         model_info = self.get_model_info()
         model_info['status'] = 'Quant'
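Two things happen in this hunk. The one-character fix adds the missing space in `'rm -r' + cache_dir`, which previously produced a command like `rm -r./.temp` that fails. The new `else` branch is the code removed from `save_model`: `slim.quant.convert` folds the fake-quant ops into an INT8-ready graph, and `fluid.io.save_inference_model` freezes it with the test inputs/outputs. If one wanted to avoid shelling out for the cleanup entirely, a portable equivalent is:

```python
import shutil

cache_dir = "./.temp"  # the default from the signature above

# Same effect as os.system('rm -r ' + cache_dir) but without a shell,
# so paths with spaces are safe; ignore_errors keeps the
# fire-and-forget behaviour of the original call.
if cache_dir is not None:
    shutil.rmtree(cache_dir, ignore_errors=True)
```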
@@ -592,6 +590,16 @@ class SegModel(object):
                     'Current evaluated best model in eval_dataset is epoch_{}, miou={}'
                     .format(best_model_epoch, best_miou))
+        if quant:
+            if osp.exists(osp.join(save_dir, "best_model")):
+                fluid.load(
+                    program=self.test_prog,
+                    model_path=osp.join(save_dir, "best_model"),
+                    executor=self.exe)
+            self.export_quant_model(
+                save_dir=osp.join(save_dir, "best_model_export"),
+                quant_type="online")
+
     def evaluate(self, eval_dataset, batch_size=1, epoch_id=None):
         """Evaluate.
......
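The training-loop hunk wires the pieces together: when quantization is enabled, the best quant-aware checkpoint is loaded back into `test_prog` and exported through the new online path. Assuming the enclosing method is the class's `train()` and it exposes this flag as `quant` (the diff only shows the `if quant:` block), end-to-end usage might look like:

```python
# Hypothetical usage; parameter names other than quant/save_dir are
# illustrative, not taken from the diff. model, train_ds and eval_ds
# are assumed to have been built beforehand.
model.train(
    num_epochs=40,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    save_dir='output',
    quant=True)
# After training: output/best_model holds the quant-aware checkpoint,
# output/best_model_export the converted inference model.
```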
@@ -33,7 +33,7 @@ def load_model(model_dir):
         raise Exception("There's no attribute {} in models".format(
             info['Model']))
     model = getattr(models, info['Model'])(**info['_init_params'])
-    if status == "Normal":
+    if status in ["Normal", "QuantOnline"]:
         startup_prog = fluid.Program()
         model.test_prog = fluid.Program()
         with fluid.program_guard(model.test_prog, startup_prog):
@@ -41,11 +41,16 @@ def load_model(model_dir):
             model.test_inputs, model.test_outputs = model.build_net(
                 mode='test')
         model.test_prog = model.test_prog.clone(for_test=True)
+        if status == "QuantOnline":
+            print('test quant online')
+            import paddleslim as slim
+            model.test_prog = slim.quant.quant_aware(
+                model.test_prog, model.exe.place, for_test=True)
         model.exe.run(startup_prog)
-        import pickle
-        with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
-            load_dict = pickle.load(f)
-        fluid.io.set_program_state(model.test_prog, load_dict)
+        fluid.load(model.test_prog, osp.join(model_dir, 'model'))
+        if status == "QuantOnline":
+            model.test_prog = slim.quant.convert(model.test_prog,
+                                                 model.exe.place)
     elif status in ['Infer', 'Quant']:
         [prog, input_names, outputs] = fluid.io.load_inference_model(
......
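Loading mirrors saving. A `QuantOnline` checkpoint is a plain program plus fake-quant ops, so `load_model` now routes that status down the graph-rebuilding path: it rebuilds the network, re-applies `slim.quant.quant_aware(..., for_test=True)` so the graph and parameter names match what `fluid.save` wrote, restores weights with `fluid.load` (which replaces the old manual `pickle` + `set_program_state` pair and validates names and shapes), and only then calls `slim.quant.convert` to freeze the quantized graph. A condensed sketch of that sequence, assuming a model object with `test_prog`/`exe` as in `load_model`:

```python
import os.path as osp
import paddle.fluid as fluid
import paddleslim as slim

def restore_quant_online(model, model_dir, startup_prog):
    # 1. Instrument the freshly built test program so its structure and
    #    parameter names match the quant-aware checkpoint on disk.
    model.test_prog = slim.quant.quant_aware(
        model.test_prog, model.exe.place, for_test=True)
    # 2. Initialize variables, then restore the checkpointed parameters.
    model.exe.run(startup_prog)
    fluid.load(model.test_prog, osp.join(model_dir, 'model'))
    # 3. Fold the fake-quant ops into the final inference graph.
    model.test_prog = slim.quant.convert(model.test_prog, model.exe.place)
    return model
```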