diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index c18edf63062044970a0ceb0385962546d43356c4..6c88cbdab0758da5160a14b9353f6d4a40074ebf 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase { std::string dir = Attr("dir"); std::string serial_num_attr = Attr("Serial"); - PADDLE_ENFORCE(!IsNumber(serial_num_attr), + PADDLE_ENFORCE(IsNumber(serial_num_attr), "Checkpoint Serial must be a number"); std::string serial_var_name = std::string(SERIAL_VAR); diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 1832c5792a18c7c65e689cd2b0d20df60ddfdd43..f904cdc8269e71d0a61074c8a995d30f50683aff 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -96,8 +96,7 @@ class CheckpointSaveOp : public framework::OperatorBase { int serials = 0; if (!serial_num->empty()) { - std::string::size_type sz; - serials = std::stoi(serial_num->data, &sz); + serials = std::stoi(serial_num->data()); serials += 1; } diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index e1a2fe86a5804546fdd976d39cfeedf9ca71cb80..335dc2342d08c01cafd6e7588d8470ea00a2c830 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -545,6 +545,7 @@ class DistributeTranspiler: startup_prog.global_block().append_op( type="checkpoint_load", inputs={"X": load_vars}, + outputs={"Argv": []}, attrs={"dir": checkpoint_load_dir, "Serial": serial_number}) return startup_prog @@ -616,6 +617,7 @@ class DistributeTranspiler: s_prog.global_block().append_op( type="checkpoint_load", inputs={"X": load_vars}, + outputs={"Argv": []}, attrs={"dir": checkpoint_load_dir, "Serial": serial_number}) @@ -640,7 +642,7 @@ class DistributeTranspiler: """ is _SUCCESS in this dir """ - if not os.path.isdir(cur_dir): + if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): return -1 try: