diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index d24c7819990f04e633358bacaa62e735f940e71e..a9676de369b4b45a097835ad104e37f745742ec5 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase { std::string dir = Attr("dir"); std::string serial_num_attr = Attr("Serial"); - PADDLE_ENFORCE(IsNumber(serial_num_attr), + PADDLE_ENFORCE(!IsNumber(serial_num_attr), "Checkpoint Serial must be a number"); std::string serial_var_name = std::string(SERIAL_VAR); @@ -124,7 +124,8 @@ class CheckpointLoadOp : public framework::OperatorBase { serial_var_name); auto *serial_num = serial_var->GetMutable(); - serial_num = serial_num_attr; + serial_num->clear(); + serial_num->append(serial_num_attr); VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR << " value: " << serial_num; diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc index 2acce227d23de5862784831e67fed3dc1b4c3a41..5312225e5f95230ff1e0946c7925c8b89253b3bd 100644 --- a/paddle/fluid/operators/checkpoint_op_test.cc +++ b/paddle/fluid/operators/checkpoint_op_test.cc @@ -69,6 +69,8 @@ TEST(CheckpointLoadOp, CPU) { } scope.Var("SERIAL_NUMBER"); + auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable(); + serial_num->append("0"); paddle::framework::AttributeMap attrs; attrs.insert({"dir", std::string("ckpt")}); diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index bab979e4074a613219bca8971b5b32250616eac6..30eda30c5f52fb3eea6c58015400a65161de3cf9 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -94,8 +94,8 @@ class CheckpointSaveOp : public framework::OperatorBase { VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR << " value: " << serial_num; - if (!IsNumber(serial_num)) { - serial_num = "0"; + if (serial_num->empty()) { + serial_num->append("0"); } std::string dir = GenePath(ck_dir, serial_num->c_str());