Unverified commit e634afe7, authored by Double_V, committed by GitHub

Merge pull request #839 from baiyfbupt/qat_issue

fix quant module
@@ -135,7 +135,7 @@ def main():
     if alg in ['EAST', 'DB']:
         program.train_eval_det_run(
-            config, exe, train_info_dict, eval_info_dict, is_pruning=True)
+            config, exe, train_info_dict, eval_info_dict, is_slim="prune")
     else:
         program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
......
@@ -155,14 +155,13 @@ def main():
         act_preprocess_func=act_preprocess_func,
         optimizer_func=optimizer_func,
         executor=executor,
-        for_test=False,
-        return_program=True)
+        for_test=False)
     # compile program for multi-devices
     train_compile_program = program.create_multi_devices_program(
         quant_train_program, train_opt_loss_name, for_quant=True)
-    init_model(config, quant_train_program, exe)
+    init_model(config, train_program, exe)
     train_info_dict = {'compile_program':train_compile_program,\
         'train_program':quant_train_program,\
@@ -177,9 +176,11 @@ def main():
         'fetch_varname_list':eval_fetch_varname_list}
     if train_alg_type == 'det':
-        program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
+        program.train_eval_det_run(
+            config, exe, train_info_dict, eval_info_dict, is_slim="quant")
     else:
-        program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
+        program.train_eval_rec_run(
+            config, exe, train_info_dict, eval_info_dict, is_slim="quant")
 
 
 if __name__ == '__main__':
......
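A note on the two quantization-script changes above: PaddleSlim's `quant_aware` returns a `fluid.CompiledProgram` when called with `for_test=False` (and a plain `Program` when `for_test=True`), and a `CompiledProgram` cannot be used to load or save persistable parameters. That appears to be why `return_program=True` is dropped, why `init_model` now receives the original `train_program`, and why the evaluation-side program is the one saved in the `program.py` hunks below. A minimal sketch of the call pattern, with hypothetical names (`place`, `quant_config`, `train_program`, `eval_program`) standing in for variables defined elsewhere in the script:

```python
# Sketch only -- not the verbatim script; it illustrates the return types
# that motivate the change above.
import paddleslim as slim


def build_quant_programs(train_program, eval_program, place, quant_config):
    # for_test=False: returns a fluid.CompiledProgram with fake-quant ops
    # inserted; it can be executed but not used to save/load parameters.
    quant_train_program = slim.quant.quant_aware(
        train_program, place, quant_config, for_test=False)
    # for_test=True: returns a plain Program, which is what gets saved as
    # eval_info_dict['program'] in the hunks below.
    quant_eval_program = slim.quant.quant_aware(
        eval_program, place, quant_config, for_test=True)
    return quant_train_program, quant_eval_program
```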
@@ -241,9 +241,11 @@ def create_multi_devices_program(program, loss_var_name, for_quant=False):
     build_strategy.enable_inplace = True
     if for_quant:
         build_strategy.fuse_all_reduce_ops = False
+    else:
+        program = fluid.CompiledProgram(program)
     exec_strategy = fluid.ExecutionStrategy()
     exec_strategy.num_iteration_per_drop_scope = 1
-    compile_program = fluid.CompiledProgram(program).with_data_parallel(
+    compile_program = program.with_data_parallel(
         loss_name=loss_var_name,
         build_strategy=build_strategy,
         exec_strategy=exec_strategy)
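For readability, here is roughly how `create_multi_devices_program` reads after this hunk. This is a reconstruction from the diff, not the verbatim file: the `BuildStrategy` setup and the trailing `return` are assumed from context. The quant-aware training program arriving here is already a `CompiledProgram`, so only the regular path wraps the program before enabling data parallelism.

```python
import paddle.fluid as fluid


def create_multi_devices_program(program, loss_var_name, for_quant=False):
    # Reconstructed sketch of the function after this change (not verbatim).
    build_strategy = fluid.BuildStrategy()
    build_strategy.enable_inplace = True
    if for_quant:
        # quant_aware(for_test=False) already returns a CompiledProgram;
        # all-reduce fusion is disabled for it.
        build_strategy.fuse_all_reduce_ops = False
    else:
        # Regular Program: wrap it before calling with_data_parallel.
        program = fluid.CompiledProgram(program)
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_iteration_per_drop_scope = 1
    compile_program = program.with_data_parallel(
        loss_name=loss_var_name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
    return compile_program
```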
@@ -254,7 +256,7 @@ def train_eval_det_run(config,
                        exe,
                        train_info_dict,
                        eval_info_dict,
-                       is_pruning=False):
+                       is_slim=None):
     '''
     main program of evaluation for detection
     '''
@@ -313,14 +315,21 @@ def train_eval_det_run(config,
                     best_batch_id = train_batch_id
                     best_epoch = epoch
                     save_path = save_model_dir + "/best_accuracy"
-                    if is_pruning:
-                        import paddleslim as slim
-                        slim.prune.save_model(
-                            exe, train_info_dict['train_program'],
-                            save_path)
-                    else:
+                    if is_slim is None:
                         save_model(train_info_dict['train_program'],
                                    save_path)
+                    else:
+                        import paddleslim as slim
+                        if is_slim == "prune":
+                            slim.prune.save_model(
+                                exe, train_info_dict['train_program'],
+                                save_path)
+                        elif is_slim == "quant":
+                            save_model(eval_info_dict['program'], save_path)
+                        else:
+                            raise ValueError(
+                                "Only quant and prune are supported currently. But received {}".
+                                format(is_slim))
                 strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
                     train_batch_id, metrics, best_eval_hmean, best_epoch,
                     best_batch_id)
@@ -331,24 +340,42 @@ def train_eval_det_run(config,
         train_loader.reset()
         if epoch == 0 and save_epoch_step == 1:
             save_path = save_model_dir + "/iter_epoch_0"
-            if is_pruning:
-                import paddleslim as slim
-                slim.prune.save_model(exe, train_info_dict['train_program'],
-                                      save_path)
-            else:
+            if is_slim is None:
                 save_model(train_info_dict['train_program'], save_path)
+            else:
+                import paddleslim as slim
+                if is_slim == "prune":
+                    slim.prune.save_model(exe, train_info_dict['train_program'],
+                                          save_path)
+                elif is_slim == "quant":
+                    save_model(eval_info_dict['program'], save_path)
+                else:
+                    raise ValueError(
+                        "Only quant and prune are supported currently. But received {}".
+                        format(is_slim))
         if epoch > 0 and epoch % save_epoch_step == 0:
             save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
-            if is_pruning:
-                import paddleslim as slim
-                slim.prune.save_model(exe, train_info_dict['train_program'],
-                                      save_path)
-            else:
+            if is_slim is None:
                 save_model(train_info_dict['train_program'], save_path)
+            else:
+                import paddleslim as slim
+                if is_slim == "prune":
+                    slim.prune.save_model(exe, train_info_dict['train_program'],
+                                          save_path)
+                elif is_slim == "quant":
+                    save_model(eval_info_dict['program'], save_path)
+                else:
+                    raise ValueError(
+                        "Only quant and prune are supported currently. But received {}".
+                        format(is_slim))
     return
 
 
-def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
+def train_eval_rec_run(config,
+                       exe,
+                       train_info_dict,
+                       eval_info_dict,
+                       is_slim=None):
     '''
     main program of evaluation for recognition
     '''
@@ -428,7 +455,21 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
                     best_batch_id = train_batch_id
                     best_epoch = epoch
                     save_path = save_model_dir + "/best_accuracy"
-                    save_model(train_info_dict['train_program'], save_path)
+                    if is_slim is None:
+                        save_model(train_info_dict['train_program'],
+                                   save_path)
+                    else:
+                        import paddleslim as slim
+                        if is_slim == "prune":
+                            slim.prune.save_model(
+                                exe, train_info_dict['train_program'],
+                                save_path)
+                        elif is_slim == "quant":
+                            save_model(eval_info_dict['program'], save_path)
+                        else:
+                            raise ValueError(
+                                "Only quant and prune are supported currently. But received {}".
+                                format(is_slim))
                 strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
                     train_batch_id, eval_acc, best_eval_acc, best_epoch,
                     best_batch_id, eval_sample_num)
@@ -439,14 +480,42 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
         train_loader.reset()
         if epoch == 0 and save_epoch_step == 1:
             save_path = save_model_dir + "/iter_epoch_0"
-            save_model(train_info_dict['train_program'], save_path)
+            if is_slim is None:
+                save_model(train_info_dict['train_program'], save_path)
+            else:
+                import paddleslim as slim
+                if is_slim == "prune":
+                    slim.prune.save_model(exe, train_info_dict['train_program'],
+                                          save_path)
+                elif is_slim == "quant":
+                    save_model(eval_info_dict['program'], save_path)
+                else:
+                    raise ValueError(
+                        "Only quant and prune are supported currently. But received {}".
+                        format(is_slim))
         if epoch > 0 and epoch % save_epoch_step == 0:
             save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
-            save_model(train_info_dict['train_program'], save_path)
+            if is_slim is None:
+                save_model(train_info_dict['train_program'], save_path)
+            else:
+                import paddleslim as slim
+                if is_slim == "prune":
+                    slim.prune.save_model(exe, train_info_dict['train_program'],
+                                          save_path)
+                elif is_slim == "quant":
+                    save_model(eval_info_dict['program'], save_path)
+                else:
+                    raise ValueError(
+                        "Only quant and prune are supported currently. But received {}".
+                        format(is_slim))
     return
 
 
-def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
+def train_eval_cls_run(config,
+                       exe,
+                       train_info_dict,
+                       eval_info_dict,
+                       is_slim=None):
     train_batch_id = 0
     log_smooth_window = config['Global']['log_smooth_window']
     epoch_num = config['Global']['epoch_num']
@@ -509,7 +578,21 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
                     best_batch_id = train_batch_id
                     best_epoch = epoch
                     save_path = save_model_dir + "/best_accuracy"
-                    save_model(train_info_dict['train_program'], save_path)
+                    if is_slim is None:
+                        save_model(train_info_dict['train_program'],
+                                   save_path)
+                    else:
+                        import paddleslim as slim
+                        if is_slim == "prune":
+                            slim.prune.save_model(
+                                exe, train_info_dict['train_program'],
+                                save_path)
+                        elif is_slim == "quant":
+                            save_model(eval_info_dict['program'], save_path)
+                        else:
+                            raise ValueError(
+                                "Only quant and prune are supported currently. But received {}".
+                                format(is_slim))
                 strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
                     train_batch_id, eval_acc, best_eval_acc, best_epoch,
                     best_batch_id, eval_sample_num)
@@ -520,10 +603,34 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
         train_loader.reset()
         if epoch == 0 and save_epoch_step == 1:
             save_path = save_model_dir + "/iter_epoch_0"
-            save_model(train_info_dict['train_program'], save_path)
+            if is_slim is None:
+                save_model(train_info_dict['train_program'], save_path)
+            else:
+                import paddleslim as slim
+                if is_slim == "prune":
+                    slim.prune.save_model(exe, train_info_dict['train_program'],
+                                          save_path)
+                elif is_slim == "quant":
+                    save_model(eval_info_dict['program'], save_path)
+                else:
+                    raise ValueError(
+                        "Only quant and prune are supported currently. But received {}".
+                        format(is_slim))
         if epoch > 0 and epoch % save_epoch_step == 0:
             save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
-            save_model(train_info_dict['train_program'], save_path)
+            if is_slim is None:
+                save_model(train_info_dict['train_program'], save_path)
+            else:
+                import paddleslim as slim
+                if is_slim == "prune":
+                    slim.prune.save_model(exe, train_info_dict['train_program'],
+                                          save_path)
+                elif is_slim == "quant":
+                    save_model(eval_info_dict['program'], save_path)
+                else:
+                    raise ValueError(
+                        "Only quant and prune are supported currently. But received {}".
+                        format(is_slim))
     return
......
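The `is_slim` dispatch added in the hunks above is repeated verbatim at every checkpoint site (best accuracy, epoch 0, and every `save_epoch_step` epochs, for det, rec, and cls). Its logic amounts to the following sketch; `save_checkpoint` is a hypothetical helper used here only to summarize the pattern and is not part of this commit, while `save_model` and `slim.prune.save_model` are the functions already used in the diff:

```python
def save_checkpoint(exe, train_info_dict, eval_info_dict, save_path,
                    is_slim=None):
    # Hypothetical consolidation of the repeated branch above (not in the
    # commit); save_model is the repository's existing checkpoint helper.
    if is_slim is None:
        # Plain training: save the training program's parameters.
        save_model(train_info_dict['train_program'], save_path)
    else:
        import paddleslim as slim
        if is_slim == "prune":
            # Pruning: use PaddleSlim's pruning-aware saver.
            slim.prune.save_model(exe, train_info_dict['train_program'],
                                  save_path)
        elif is_slim == "quant":
            # Quant-aware training: the train program is a CompiledProgram,
            # so the plain evaluation program is saved instead.
            save_model(eval_info_dict['program'], save_path)
        else:
            raise ValueError(
                "Only quant and prune are supported currently. But received {}".
                format(is_slim))
```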