From f688652f1e3ee2eaf949ef79cbd56c05fc4980cd Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 18 May 2018 10:26:41 +0800 Subject: [PATCH] bug fix --- paddle/fluid/operators/checkpoint_load_op.cc | 5 +++-- paddle/fluid/operators/checkpoint_op_test.cc | 2 ++ paddle/fluid/operators/checkpoint_save_op.cc | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index d24c781999..a9676de369 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase { std::string dir = Attr("dir"); std::string serial_num_attr = Attr("Serial"); - PADDLE_ENFORCE(IsNumber(serial_num_attr), + PADDLE_ENFORCE(!IsNumber(serial_num_attr), "Checkpoint Serial must be a number"); std::string serial_var_name = std::string(SERIAL_VAR); @@ -124,7 +124,8 @@ class CheckpointLoadOp : public framework::OperatorBase { serial_var_name); auto *serial_num = serial_var->GetMutable(); - serial_num = serial_num_attr; + serial_num->clear(); + serial_num->append(serial_num_attr); VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR << " value: " << serial_num; diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc index 2acce227d2..5312225e5f 100644 --- a/paddle/fluid/operators/checkpoint_op_test.cc +++ b/paddle/fluid/operators/checkpoint_op_test.cc @@ -69,6 +69,8 @@ TEST(CheckpointLoadOp, CPU) { } scope.Var("SERIAL_NUMBER"); + auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable(); + serial_num->append("0"); paddle::framework::AttributeMap attrs; attrs.insert({"dir", std::string("ckpt")}); diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index bab979e407..30eda30c5f 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -94,8 +94,8 @@ class CheckpointSaveOp : public framework::OperatorBase { VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR << " value: " << serial_num; - if (!IsNumber(serial_num)) { - serial_num = "0"; + if (serial_num->empty()) { + serial_num->append("0"); } std::string dir = GenePath(ck_dir, serial_num->c_str()); -- GitLab