提交 22df4c27 编写于 作者: T tangwei12

fix serial number

上级 dbd02377
...@@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase {
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial"); std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(!IsNumber(serial_num_attr), PADDLE_ENFORCE(IsNumber(serial_num_attr),
"Checkpoint Serial must be a number"); "Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR); std::string serial_var_name = std::string(SERIAL_VAR);
......
...@@ -96,8 +96,7 @@ class CheckpointSaveOp : public framework::OperatorBase { ...@@ -96,8 +96,7 @@ class CheckpointSaveOp : public framework::OperatorBase {
int serials = 0; int serials = 0;
if (!serial_num->empty()) { if (!serial_num->empty()) {
std::string::size_type sz; serials = std::stoi(serial_num->data());
serials = std::stoi(serial_num->data, &sz);
serials += 1; serials += 1;
} }
......
...@@ -545,6 +545,7 @@ class DistributeTranspiler: ...@@ -545,6 +545,7 @@ class DistributeTranspiler:
startup_prog.global_block().append_op( startup_prog.global_block().append_op(
type="checkpoint_load", type="checkpoint_load",
inputs={"X": load_vars}, inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": checkpoint_load_dir, attrs={"dir": checkpoint_load_dir,
"Serial": serial_number}) "Serial": serial_number})
return startup_prog return startup_prog
...@@ -616,6 +617,7 @@ class DistributeTranspiler: ...@@ -616,6 +617,7 @@ class DistributeTranspiler:
s_prog.global_block().append_op( s_prog.global_block().append_op(
type="checkpoint_load", type="checkpoint_load",
inputs={"X": load_vars}, inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": checkpoint_load_dir, attrs={"dir": checkpoint_load_dir,
"Serial": serial_number}) "Serial": serial_number})
...@@ -640,7 +642,7 @@ class DistributeTranspiler: ...@@ -640,7 +642,7 @@ class DistributeTranspiler:
""" """
is _SUCCESS in this dir is _SUCCESS in this dir
""" """
if not os.path.isdir(cur_dir): if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1 return -1
try: try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册