未验证 提交 e634afe7 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #839 from baiyfbupt/qat_issue

fix quant module
......@@ -135,7 +135,7 @@ def main():
if alg in ['EAST', 'DB']:
program.train_eval_det_run(
config, exe, train_info_dict, eval_info_dict, is_pruning=True)
config, exe, train_info_dict, eval_info_dict, is_slim="prune")
else:
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
......
......@@ -155,14 +155,13 @@ def main():
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=False,
return_program=True)
for_test=False)
# compile program for multi-devices
train_compile_program = program.create_multi_devices_program(
quant_train_program, train_opt_loss_name, for_quant=True)
init_model(config, quant_train_program, exe)
init_model(config, train_program, exe)
train_info_dict = {'compile_program':train_compile_program,\
'train_program':quant_train_program,\
......@@ -177,9 +176,11 @@ def main():
'fetch_varname_list':eval_fetch_varname_list}
if train_alg_type == 'det':
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
program.train_eval_det_run(
config, exe, train_info_dict, eval_info_dict, is_slim="quant")
else:
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
program.train_eval_rec_run(
config, exe, train_info_dict, eval_info_dict, is_slim="quant")
if __name__ == '__main__':
......
......@@ -241,9 +241,11 @@ def create_multi_devices_program(program, loss_var_name, for_quant=False):
build_strategy.enable_inplace = True
if for_quant:
build_strategy.fuse_all_reduce_ops = False
else:
program = fluid.CompiledProgram(program)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
compile_program = fluid.CompiledProgram(program).with_data_parallel(
compile_program = program.with_data_parallel(
loss_name=loss_var_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
......@@ -254,7 +256,7 @@ def train_eval_det_run(config,
exe,
train_info_dict,
eval_info_dict,
is_pruning=False):
is_slim=None):
'''
main program of evaluation for detection
'''
......@@ -313,14 +315,21 @@ def train_eval_det_run(config,
best_batch_id = train_batch_id
best_epoch = epoch
save_path = save_model_dir + "/best_accuracy"
if is_pruning:
import paddleslim as slim
slim.prune.save_model(
exe, train_info_dict['train_program'],
save_path)
else:
if is_slim is None:
save_model(train_info_dict['train_program'],
save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(
exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
train_batch_id, metrics, best_eval_hmean, best_epoch,
best_batch_id)
......@@ -331,24 +340,42 @@ def train_eval_det_run(config,
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
if is_pruning:
import paddleslim as slim
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
else:
if is_slim is None:
save_model(train_info_dict['train_program'], save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
if is_pruning:
import paddleslim as slim
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
else:
if is_slim is None:
save_model(train_info_dict['train_program'], save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
return
def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
def train_eval_rec_run(config,
exe,
train_info_dict,
eval_info_dict,
is_slim=None):
'''
main program of evaluation for recognition
'''
......@@ -428,7 +455,21 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
best_batch_id = train_batch_id
best_epoch = epoch
save_path = save_model_dir + "/best_accuracy"
save_model(train_info_dict['train_program'], save_path)
if is_slim is None:
save_model(train_info_dict['train_program'],
save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(
exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
train_batch_id, eval_acc, best_eval_acc, best_epoch,
best_batch_id, eval_sample_num)
......@@ -439,14 +480,42 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'], save_path)
if is_slim is None:
save_model(train_info_dict['train_program'], save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
if is_slim is None:
save_model(train_info_dict['train_program'], save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
return
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
def train_eval_cls_run(config,
exe,
train_info_dict,
eval_info_dict,
is_slim=None):
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
......@@ -509,7 +578,21 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
best_batch_id = train_batch_id
best_epoch = epoch
save_path = save_model_dir + "/best_accuracy"
save_model(train_info_dict['train_program'], save_path)
if is_slim is None:
save_model(train_info_dict['train_program'],
save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(
exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
train_batch_id, eval_acc, best_eval_acc, best_epoch,
best_batch_id, eval_sample_num)
......@@ -520,10 +603,34 @@ def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'], save_path)
if is_slim is None:
save_model(train_info_dict['train_program'], save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
if is_slim is None:
save_model(train_info_dict['train_program'], save_path)
else:
import paddleslim as slim
if is_slim == "prune":
slim.prune.save_model(exe, train_info_dict['train_program'],
save_path)
elif is_slim == "quant":
save_model(eval_info_dict['program'], save_path)
else:
raise ValueError(
"Only quant and prune are supported currently. But received {}".
format(is_slim))
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册