未验证 提交 524016b3 编写于 作者: D Double_V 提交者: GitHub

Fix pretrained models load problem when load params saved use API save_params (#4471)

* fix ocr pretrain_model problem

* fix ocr pretrain_model problem
上级 13bbc302
......@@ -66,9 +66,11 @@ def evaluate(args):
model_dir = args.model_path
if os.path.isdir(args.model_path):
raise Exception("{} should not be a directory".format(args.model_path))
fluid.load(program=fluid.default_main_program(),
model_path=model_dir,
executor=exe)
fluid.load(
program=fluid.default_main_program(),
model_path=model_dir,
executor=exe,
var_list=fluid.io.get_program_parameter(fluid.default_main_program()))
print("Init model from: %s." % args.model_path)
evaluator.reset(exe)
......
......@@ -87,7 +87,8 @@ def inference(args):
fluid.load(
program=fluid.default_main_program(),
model_path=model_dir,
executor=exe)
executor=exe,
var_list=fluid.io.get_program_parameter(fluid.default_main_program()))
print("Init model from: %s." % args.model_path)
batch_times = []
......
......@@ -106,7 +106,11 @@ def train(args):
# load init model
if args.init_model is not None:
model_dir = args.init_model
fluid.load(fluid.default_main_program(), model_dir)
fluid.load(
fluid.default_main_program(),
model_dir,
var_list=fluid.io.get_program_parameter(fluid.default_main_program(
)))
print("Init model from: %s." % args.init_model)
train_exe = exe
......@@ -135,7 +139,8 @@ def train(args):
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
print("\n[%s] - Iter[%d]; Test seq error: %s.\n" %
(time.asctime( time.localtime(time.time())), iter_num, str(test_seq_error[0])))
(time.asctime(time.localtime(time.time())), iter_num,
str(test_seq_error[0])))
#Note: The following logs are special for CE monitoring.
#Other situations do not need to care about these logs.
......@@ -175,15 +180,16 @@ def train(args):
iter_num += 1
# training log
if iter_num % args.log_period == 0:
print("\n[%s] - Iter[%d]; Avg loss: %.3f; Avg seq err: %.3f"
% (time.asctime( time.localtime(time.time())), iter_num,
total_loss / (args.log_period * args.batch_size),
total_seq_error / (args.log_period * args.batch_size)))
print("\n[%s] - Iter[%d]; Avg loss: %.3f; Avg seq err: %.3f" %
(time.asctime(time.localtime(time.time())), iter_num,
total_loss / (args.log_period * args.batch_size),
total_seq_error / (args.log_period * args.batch_size)))
if 'ce_mode' in os.environ:
print("kpis train_cost %f" % (total_loss / (args.log_period *
args.batch_size)))
print("kpis train_acc %f" % (
1 - total_seq_error / (args.log_period * args.batch_size)))
print("kpis train_cost %f" %
(total_loss / (args.log_period * args.batch_size)))
print("kpis train_acc %f" %
(1 - total_seq_error /
(args.log_period * args.batch_size)))
total_loss = 0.0
total_seq_error = 0.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册