diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index 0f0d989ccd2f7f55be0aeb17e89b254973389fc5..5fd3a7af9cf1f7151894579d3d1537e1a8399a6a 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -112,14 +112,14 @@ class CheckpointLoadOp : public framework::OperatorBase { std::string serial_var_name = std::string(SERIAL_VAR); auto *serial_var = scope.FindVar(serial_var_name); - auto *serial_num; + if (serial_var == nullptr) { *serial_var = scope.Var(serial_var_name); - *serial_num = serial_var->GetMutable(); - serial_num->append("0"); + auto *serial_tmp = serial_var->GetMutable(); + serial_tmp->append("0"); } - *serial_num = serial_var->GetMutable(); + auto *serial_num = serial_var->GetMutable(); VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER << " value: " << serial_num; diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 3c2cc50ac490a5dcd9b49b076fed9aeac56a1f42..5fccefeed251a2fde4c54462b3fbbbb36db6bc1b 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -84,14 +84,13 @@ class CheckpointSaveOp : public framework::OperatorBase { std::string serial_var_name = std::string(SERIAL_VAR); auto *serial_var = scope.FindVar(serial_var_name); - auto *serial_num; + if (serial_var == nullptr) { *serial_var = scope.Var(serial_var_name); - *serial_num = serial_var->GetMutable(); - serial_num->append("0"); + *serial_tmp = serial_var->GetMutable(); + serial_tmp->append("0"); } - - *serial_num = serial_var->GetMutable(); + auto *serial_num = serial_var->GetMutable(); VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER << " value: " << serial_num;