提交 7b6c0abf 编写于 作者: T tangwei12

modify variable point

上级 8430c8d7
......@@ -108,15 +108,22 @@ class CheckpointLoadOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
int serial_num = Attr<int>("Serial");
std::string serial_num = Attr<std::string>("Serial");
std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name);
auto *serial_num;
if (serial_var == nullptr) {
*serial_var = scope.Var(serial_var_name);
*serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
}
auto *serial_var = scope.FindVar(SERIAL_VAR);
serial_var = serial_num;
*serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
<< " value: " << serial_num;
std::string success;
= GenePath(dir, std::to_string(serial_num));
std::string success = GenePath(dir, serial_num);
VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS);
bool is_present = FileExists(success);
......@@ -157,8 +164,9 @@ This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC");
AddAttr<int>("Serial",
"(int)"
AddAttr<std::string>(
"Serial",
"(std::string)"
"The serial number of the checkpoint will to be load.");
AddAttr<std::string>(
"dir",
......
......@@ -82,13 +82,23 @@ class CheckpointSaveOp : public framework::OperatorBase {
auto dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite");
auto serial_num = scope.FindVar(SERIAL_VAR);
if (serial_num == nullptr) {
serial_num = scope.Var(SERIAL_VAR);
std::string serial_var_name = std::string(SERIAL_VAR);
auto *serial_var = scope.FindVar(serial_var_name);
auto *serial_num;
if (serial_var == nullptr) {
*serial_var = scope.Var(serial_var_name);
*serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
}
serial_num = serial_num + 1;
dir = GenePath(dir, std::to_string(serial_num));
*serial_num = serial_var->GetMutable<std::string>();
VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER
<< " value: " << serial_num;
auto *serial_num = serial_var->GetMutable<std::string>();
serial_num->append("1");
dir = GenePath(dir, serial_num);
bool is_present = FileExists(dir);
if (is_present && !overwrite) {
PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册