提交 7b6c0abf 编写于 作者: T tangwei12

modify variable point

上级 8430c8d7
...@@ -108,15 +108,22 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -108,15 +108,22 @@ class CheckpointLoadOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir"); std::string dir = Attr<std::string>("dir");
int serial_num = Attr<int>("Serial"); std::string serial_num = Attr<std::string>("Serial");
std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name);
auto *serial_num;
if (serial_var == nullptr) {
*serial_var = scope.Var(serial_var_name);
*serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
}
auto *serial_var = scope.FindVar(SERIAL_VAR); *serial_num = serial_var->GetMutable<std::string>();
serial_var = serial_num;
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
<< " value: " << serial_num; << " value: " << serial_num;
std::string success; std::string success = GenePath(dir, serial_num);
= GenePath(dir, std::to_string(serial_num));
VLOG(3) << "Load checkpoint from dir: " << success; VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS); success = GenePath(success, SUCCESS);
bool is_present = FileExists(success); bool is_present = FileExists(success);
...@@ -157,8 +164,9 @@ This operator will serialize and write a list of input LoDTensor variables ...@@ -157,8 +164,9 @@ This operator will serialize and write a list of input LoDTensor variables
to a file on disk. to a file on disk.
)DOC"); )DOC");
AddAttr<int>("Serial", AddAttr<std::string>(
"(int)" "Serial",
"(std::string)"
"The serial number of the checkpoint will to be load."); "The serial number of the checkpoint will to be load.");
AddAttr<std::string>( AddAttr<std::string>(
"dir", "dir",
......
...@@ -82,13 +82,23 @@ class CheckpointSaveOp : public framework::OperatorBase { ...@@ -82,13 +82,23 @@ class CheckpointSaveOp : public framework::OperatorBase {
auto dir = Attr<std::string>("dir"); auto dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
auto serial_num = scope.FindVar(SERIAL_VAR); std::string serial_var_name = std::string(SERIAL_VAR);
if (serial_num == nullptr) { auto *serial_var = scope.FindVar(serial_var_name);
serial_num = scope.Var(SERIAL_VAR); auto *serial_num;
if (serial_var == nullptr) {
*serial_var = scope.Var(serial_var_name);
*serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
} }
serial_num = serial_num + 1;
dir = GenePath(dir, std::to_string(serial_num)); *serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER
<< " value: " << serial_num;
auto *serial_num = serial_var->GetMutable<std::string>();
serial_num->append("1");
dir = GenePath(dir, serial_num);
bool is_present = FileExists(dir); bool is_present = FileExists(dir);
if (is_present && !overwrite) { if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册