提交 8a0c47d3 编写于 作者: Y yao_yf

wide&deep only save 0ckpt in data parallel

上级 245415f5
......@@ -109,8 +109,11 @@ def train_and_eval(config):
directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if get_rank() == 0:
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
callbacks=callback_list,
sink_size=ds_train.get_dataset_size())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册