提交 4907dcd9 编写于 作者: D Dun 提交者: qingqing01

fix #1585 (#1586)

* Fix Python3.
上级 76448c34
...@@ -34,7 +34,10 @@ def add_arguments(): ...@@ -34,7 +34,10 @@ def add_arguments():
add_argument('parallel', bool, False, "using ParallelExecutor.") add_argument('parallel', bool, False, "using ParallelExecutor.")
add_argument('use_gpu', bool, True, "Whether use GPU or CPU.") add_argument('use_gpu', bool, True, "Whether use GPU or CPU.")
add_argument('num_classes', int, 19, "Number of classes.") add_argument('num_classes', int, 19, "Number of classes.")
parser.add_argument('--enable_ce', action='store_true', help='If set, run the task with continuous evaluation logs.') parser.add_argument(
'--enable_ce',
action='store_true',
help='If set, run the task with continuous evaluation logs.')
def load_model(): def load_model():
...@@ -52,7 +55,10 @@ def load_model(): ...@@ -52,7 +55,10 @@ def load_model():
else: else:
if args.num_classes == 19: if args.num_classes == 19:
fluid.io.load_params( fluid.io.load_params(
exe, dirname=args.init_weights_path, main_program=tp) exe,
dirname="",
filename=args.init_weights_path,
main_program=tp)
else: else:
fluid.io.load_vars( fluid.io.load_vars(
exe, dirname="", filename=args.init_weights_path, vars=myvars) exe, dirname="", filename=args.init_weights_path, vars=myvars)
...@@ -93,6 +99,7 @@ def get_cards(args): ...@@ -93,6 +99,7 @@ def get_cards(args):
else: else:
return args.num_devices return args.num_devices
CityscapeDataset = reader.CityscapeDataset CityscapeDataset = reader.CityscapeDataset
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -202,9 +209,8 @@ for i, imgs, labels, names in batches: ...@@ -202,9 +209,8 @@ for i, imgs, labels, names in batches:
if args.enable_ce: if args.enable_ce:
gpu_num = get_cards(args) gpu_num = get_cards(args)
print("kpis\teach_pass_duration_card%s\t%s" % print("kpis\teach_pass_duration_card%s\t%s" %
(gpu_num, total_time / epoch_idx)) (gpu_num, total_time / epoch_idx))
print("kpis\ttrain_loss_card%s\t%s" % print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss))
(gpu_num, train_loss))
print("Training done. Model is saved to", args.save_weights_path) print("Training done. Model is saved to", args.save_weights_path)
save_model() save_model()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册