提交 4907dcd9 编写于 作者: D Dun 提交者: qingqing01

fix #1585 (#1586)

* Fix Python3.
上级 76448c34
......@@ -34,7 +34,10 @@ def add_arguments():
add_argument('parallel', bool, False, "using ParallelExecutor.")
add_argument('use_gpu', bool, True, "Whether use GPU or CPU.")
add_argument('num_classes', int, 19, "Number of classes.")
parser.add_argument('--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.')
parser.add_argument(
'--enable_ce',
action='store_true',
help='If set, run the task with continuous evaluation logs.')
def load_model():
......@@ -52,7 +55,10 @@ def load_model():
else:
if args.num_classes == 19:
fluid.io.load_params(
exe, dirname=args.init_weights_path, main_program=tp)
exe,
dirname="",
filename=args.init_weights_path,
main_program=tp)
else:
fluid.io.load_vars(
exe, dirname="", filename=args.init_weights_path, vars=myvars)
......@@ -93,6 +99,7 @@ def get_cards(args):
else:
return args.num_devices
CityscapeDataset = reader.CityscapeDataset
parser = argparse.ArgumentParser()
......@@ -203,8 +210,7 @@ if args.enable_ce:
gpu_num = get_cards(args)
print("kpis\teach_pass_duration_card%s\t%s" %
(gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" %
(gpu_num, train_loss))
print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss))
print("Training done. Model is saved to", args.save_weights_path)
save_model()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册