提交 afec85f2 编写于 作者: O ouyangyu

Add model save parameters

上级 f51c43e2
......@@ -92,6 +92,24 @@ def get_parser(parser=None):
parser.add_argument(
"--model_load_dir", type=str, default=None, help="model load directory if need"
)
parser.add_argument(
"--save_epoch_interval",
type=int,
default=10,
help="Number of iterations between checkpoint saves.",
)
parser.add_argument(
"--save_last",
action="store_true",
default=False,
help="save model snapshot for last iteration",
)
parser.add_argument(
"--save_init",
action="store_true",
default=False,
help="save model snapshot for inited",
)
parser.add_argument("--batch_size_per_device", type=int, default=64)
parser.add_argument("--val_batch_size_per_device", type=int, default=8)
......
......@@ -67,6 +67,7 @@ if args.nccl_fusion_max_ops:
if args.num_nodes > 1 and args.use_rdma:
flow.config.use_rdma(True)
def label_smoothing(labels, classes, eta, dtype):
assert classes > 0
assert eta >= 0.0 and eta < 1.0
......@@ -132,11 +133,11 @@ def main():
InitNodes(args)
flow.env.log_dir(args.log_dir)
snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
snapshot = Snapshot(args.model_save_dir, args.model_load_dir, args.save_init)
print(" {} iter per epoch...".format(epoch_size))
for epoch in range(args.num_epochs):
for epoch in range(1, args.num_epochs):
metric = Metric(
desc="train",
calculate_batches=args.loss_print_every_n_iter,
......@@ -154,7 +155,11 @@ def main():
)
for i in range(num_val_steps):
InferenceNet().async_get(metric.metric_cb(epoch, i))
snapshot.save("epoch_{}".format(epoch))
if epoch % args.save_epoch_interval == 0:
snapshot.save("epoch_{}".format(epoch))
if args.save_last:
snapshot.save("epoch_{}".format("last"))
if __name__ == "__main__":
......
......@@ -36,14 +36,14 @@ def InitNodes(args):
class Snapshot(object):
def __init__(self, model_save_dir, model_load_dir):
def __init__(self, model_save_dir, model_load_dir, save_init=False):
self._model_save_dir = model_save_dir
if model_load_dir:
assert os.path.isdir(model_load_dir)
print("Restoring model from {}.".format(model_load_dir))
flow.load_variables(flow.checkpoint.get(model_load_dir))
else:
# flow.checkpoint.save("initial_model")
elif save_init:
flow.checkpoint.save("initial_model")
print("Init model on demand.")
def save(self, name):
......
......@@ -121,6 +121,13 @@ def get_parser(parser=None):
required=False,
help="model save directory",
)
parser.add_argument(
"--model_save_init",
action="store_true",
default=False,
help="save model snapshot for inited",
)
parser.add_argument(
"--save_last_snapshot",
type=str2bool,
......
......@@ -124,7 +124,7 @@ def main():
InitNodes(args)
snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
snapshot = Snapshot(args.model_save_dir, args.model_load_dir, args.model_save_init)
print("num_accumulation_steps:", args.num_accumulation_steps)
metric = Metric(
......
......@@ -37,13 +37,13 @@ def InitNodes(args):
class Snapshot(object):
def __init__(self, model_save_dir, model_load_dir):
def __init__(self, model_save_dir, model_load_dir, model_save_init=False):
self._model_save_dir = model_save_dir
if model_load_dir:
assert os.path.isdir(model_load_dir)
print("Restoring model from {}.".format(model_load_dir))
flow.load_variables(flow.checkpoint.get(model_load_dir))
else:
elif model_save_init:
flow.checkpoint.save("initial_model")
print("Init model on demand.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册