diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index d270ae31ed79115597756681ce77f2db1994b445..0f0d989ccd2f7f55be0aeb17e89b254973389fc5 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -108,15 +108,22 @@ class CheckpointLoadOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { std::string dir = Attr("dir"); - int serial_num = Attr("Serial"); + std::string serial_num = Attr("Serial"); + + std::string serial_var_name = std::string(SERIAL_VAR); + auto *serial_var = scope.FindVar(serial_var_name); + auto *serial_num; + if (serial_var == nullptr) { + *serial_var = scope.Var(serial_var_name); + *serial_num = serial_var->GetMutable(); + serial_num->append("0"); + } - auto *serial_var = scope.FindVar(SERIAL_VAR); - serial_var = serial_num; + *serial_num = serial_var->GetMutable(); VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER << " value: " << serial_num; - std::string success; - = GenePath(dir, std::to_string(serial_num)); + std::string success = GenePath(dir, serial_num); VLOG(3) << "Load checkpoint from dir: " << success; success = GenePath(success, SUCCESS); bool is_present = FileExists(success); @@ -157,9 +164,10 @@ This operator will serialize and write a list of input LoDTensor variables to a file on disk. )DOC"); - AddAttr("Serial", - "(int)" - "The serial number of the checkpoint will to be load."); + AddAttr( + "Serial", + "(std::string)" + "The serial number of the checkpoint will to be load."); AddAttr( "dir", "(string)" diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index ee494c68822c436871264815a8285a32b6a77d32..3c2cc50ac490a5dcd9b49b076fed9aeac56a1f42 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -82,13 +82,23 @@ class CheckpointSaveOp : public framework::OperatorBase { auto dir = Attr("dir"); auto overwrite = Attr("overwrite"); - auto serial_num = scope.FindVar(SERIAL_VAR); - if (serial_num == nullptr) { - serial_num = scope.Var(SERIAL_VAR); + std::string serial_var_name = std::string(SERIAL_VAR); + auto *serial_var = scope.FindVar(serial_var_name); + auto *serial_num; + if (serial_var == nullptr) { + *serial_var = scope.Var(serial_var_name); + *serial_num = serial_var->GetMutable(); + serial_num->append("0"); } - serial_num = serial_num + 1; - dir = GenePath(dir, std::to_string(serial_num)); + *serial_num = serial_var->GetMutable(); + VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER + << " value: " << serial_num; + + auto *serial_num = serial_var->GetMutable(); + serial_num->append("1"); + + dir = GenePath(dir, serial_num); bool is_present = FileExists(dir); if (is_present && !overwrite) { PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,