提交 d98480cf 编写于 作者: T tangwei12

fix serial number

上级 22df4c27
...@@ -114,8 +114,8 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -114,8 +114,8 @@ class CheckpointLoadOp : public framework::OperatorBase {
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial"); std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(IsNumber(serial_num_attr), VLOG(3) << "CheckpointLoadOp get Attr dir: " << dir;
"Checkpoint Serial must be a number"); VLOG(3) << "CheckpointLoadOp get Attr Serial: " << serial_num_attr;
std::string serial_var_name = std::string(SERIAL_VAR); std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name); auto *serial_var = scope.FindVar(serial_var_name);
......
...@@ -654,6 +654,9 @@ class DistributeTranspiler: ...@@ -654,6 +654,9 @@ class DistributeTranspiler:
if os.path.isfile(success_path): if os.path.isfile(success_path):
return int(cur_dir) return int(cur_dir)
if os.path.isdir(checkpoint_dir):
return "-1"
current_dir = 0 current_dir = 0
dirs = os.listdir(checkpoint_dir) dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs: for cur_dir in dirs:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册