提交 f688652f 编写于 作者: T tangwei12

bug fix

上级 a4fd3756
...@@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase {
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial"); std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(IsNumber(serial_num_attr), PADDLE_ENFORCE(!IsNumber(serial_num_attr),
"Checkpoint Serial must be a number"); "Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR); std::string serial_var_name = std::string(SERIAL_VAR);
...@@ -124,7 +124,8 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -124,7 +124,8 @@ class CheckpointLoadOp : public framework::OperatorBase {
serial_var_name); serial_var_name);
auto *serial_num = serial_var->GetMutable<std::string>(); auto *serial_num = serial_var->GetMutable<std::string>();
serial_num = serial_num_attr; serial_num->clear();
serial_num->append(serial_num_attr);
VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
<< " value: " << serial_num; << " value: " << serial_num;
......
...@@ -69,6 +69,8 @@ TEST(CheckpointLoadOp, CPU) { ...@@ -69,6 +69,8 @@ TEST(CheckpointLoadOp, CPU) {
} }
scope.Var("SERIAL_NUMBER"); scope.Var("SERIAL_NUMBER");
auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable<std::string>();
serial_num->append("0");
paddle::framework::AttributeMap attrs; paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")}); attrs.insert({"dir", std::string("ckpt")});
......
...@@ -94,8 +94,8 @@ class CheckpointSaveOp : public framework::OperatorBase { ...@@ -94,8 +94,8 @@ class CheckpointSaveOp : public framework::OperatorBase {
VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR
<< " value: " << serial_num; << " value: " << serial_num;
if (!IsNumber(serial_num)) { if (serial_num->empty()) {
serial_num = "0"; serial_num->append("0");
} }
std::string dir = GenePath(ck_dir, serial_num->c_str()); std::string dir = GenePath(ck_dir, serial_num->c_str());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册