From 22df4c278c19ab5eca71431d878eb78f053e6bc5 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 18 May 2018 21:17:37 +0800 Subject: [PATCH] fix serial number --- paddle/fluid/operators/checkpoint_load_op.cc | 2 +- paddle/fluid/operators/checkpoint_save_op.cc | 3 +-- python/paddle/fluid/transpiler/distribute_transpiler.py | 4 +++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index c18edf63062..6c88cbdab07 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase { std::string dir = Attr("dir"); std::string serial_num_attr = Attr("Serial"); - PADDLE_ENFORCE(!IsNumber(serial_num_attr), + PADDLE_ENFORCE(IsNumber(serial_num_attr), "Checkpoint Serial must be a number"); std::string serial_var_name = std::string(SERIAL_VAR); diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 1832c5792a1..f904cdc8269 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -96,8 +96,7 @@ class CheckpointSaveOp : public framework::OperatorBase { int serials = 0; if (!serial_num->empty()) { - std::string::size_type sz; - serials = std::stoi(serial_num->data, &sz); + serials = std::stoi(serial_num->data()); serials += 1; } diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index e1a2fe86a58..335dc2342d0 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -545,6 +545,7 @@ class DistributeTranspiler: startup_prog.global_block().append_op( type="checkpoint_load", inputs={"X": load_vars}, + outputs={"Argv": []}, attrs={"dir": checkpoint_load_dir, "Serial": serial_number}) return startup_prog @@ -616,6 +617,7 @@ class DistributeTranspiler: s_prog.global_block().append_op( type="checkpoint_load", inputs={"X": load_vars}, + outputs={"Argv": []}, attrs={"dir": checkpoint_load_dir, "Serial": serial_number}) @@ -640,7 +642,7 @@ class DistributeTranspiler: """ is _SUCCESS in this dir """ - if not os.path.isdir(cur_dir): + if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): return -1 try: -- GitLab