From 7b6c0abfc9b1e5ab44404ed0c253d4250d9a440a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 17 May 2018 22:41:02 +0800 Subject: [PATCH] modify variable point --- paddle/fluid/operators/checkpoint_load_op.cc | 24 +++++++++++++------- paddle/fluid/operators/checkpoint_save_op.cc | 20 ++++++++++++---- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index d270ae31ed7..0f0d989ccd2 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -108,15 +108,22 @@ class CheckpointLoadOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { std::string dir = Attr("dir"); - int serial_num = Attr("Serial"); + std::string serial_num = Attr("Serial"); + + std::string serial_var_name = std::string(SERIAL_VAR); + auto *serial_var = scope.FindVar(serial_var_name); + auto *serial_num; + if (serial_var == nullptr) { + *serial_var = scope.Var(serial_var_name); + *serial_num = serial_var->GetMutable(); + serial_num->append("0"); + } - auto *serial_var = scope.FindVar(SERIAL_VAR); - serial_var = serial_num; + *serial_num = serial_var->GetMutable(); VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER << " value: " << serial_num; - std::string success; - = GenePath(dir, std::to_string(serial_num)); + std::string success = GenePath(dir, serial_num); VLOG(3) << "Load checkpoint from dir: " << success; success = GenePath(success, SUCCESS); bool is_present = FileExists(success); @@ -157,9 +164,10 @@ This operator will serialize and write a list of input LoDTensor variables to a file on disk. )DOC"); - AddAttr("Serial", - "(int)" - "The serial number of the checkpoint will to be load."); + AddAttr( + "Serial", + "(std::string)" + "The serial number of the checkpoint will to be load."); AddAttr( "dir", "(string)" diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index ee494c68822..3c2cc50ac49 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -82,13 +82,23 @@ class CheckpointSaveOp : public framework::OperatorBase { auto dir = Attr("dir"); auto overwrite = Attr("overwrite"); - auto serial_num = scope.FindVar(SERIAL_VAR); - if (serial_num == nullptr) { - serial_num = scope.Var(SERIAL_VAR); + std::string serial_var_name = std::string(SERIAL_VAR); + auto *serial_var = scope.FindVar(serial_var_name); + auto *serial_num; + if (serial_var == nullptr) { + *serial_var = scope.Var(serial_var_name); + *serial_num = serial_var->GetMutable(); + serial_num->append("0"); } - serial_num = serial_num + 1; - dir = GenePath(dir, std::to_string(serial_num)); + *serial_num = serial_var->GetMutable(); + VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER + << " value: " << serial_num; + + auto *serial_num = serial_var->GetMutable(); + serial_num->append("1"); + + dir = GenePath(dir, serial_num); bool is_present = FileExists(dir); if (is_present && !overwrite) { PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, -- GitLab