提交 f688652f 编写于 作者: T tangwei12

bug fix

上级 a4fd3756
......@@ -114,7 +114,7 @@ class CheckpointLoadOp : public framework::OperatorBase {
std::string dir = Attr<std::string>("dir");
std::string serial_num_attr = Attr<std::string>("Serial");
PADDLE_ENFORCE(IsNumber(serial_num_attr),
PADDLE_ENFORCE(!IsNumber(serial_num_attr),
"Checkpoint Serial must be a number");
std::string serial_var_name = std::string(SERIAL_VAR);
......@@ -124,7 +124,8 @@ class CheckpointLoadOp : public framework::OperatorBase {
serial_var_name);
auto *serial_num = serial_var->GetMutable<std::string>();
serial_num = serial_num_attr;
serial_num->clear();
serial_num->append(serial_num_attr);
VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR
<< " value: " << serial_num;
......
......@@ -69,6 +69,8 @@ TEST(CheckpointLoadOp, CPU) {
}
scope.Var("SERIAL_NUMBER");
auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable<std::string>();
serial_num->append("0");
paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});
......
......@@ -94,8 +94,8 @@ class CheckpointSaveOp : public framework::OperatorBase {
VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR
<< " value: " << serial_num;
if (!IsNumber(serial_num)) {
serial_num = "0";
if (serial_num->empty()) {
serial_num->append("0");
}
std::string dir = GenePath(ck_dir, serial_num->c_str());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册