提交 6d53dcee 编写于 作者: T tangwei12

optimized checkpoint serial number and folder

上级 4220b31d
......@@ -17,6 +17,10 @@ limitations under the License. */
#include <fstream>
#include <numeric>
#include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
......@@ -30,41 +34,24 @@ namespace operators {
constexpr char kSEP = '/';
// write empty file named _SUCCESS
const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";
static bool FileExists(const std::string &filepath) {
struct stat buffer;
return (stat(filepath.c_str(), &buffer) == 0);
}
class CheckpointLoadOp : public framework::OperatorBase {
public:
CheckpointLoadOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
VLOG(3) << "Load checkpoint from dir: " << dir;
std::string success;
success.append(dir);
success.append("/");
success.append(SUCCESS);
bool is_present = FileExists(success);
if (!is_present) {
VLOG(3) << "can not find _SUCCESS from path: " << success;
return;
}
static std::string GenePath(const std::string &dir, const std::string &file) {
boost::filesystem::path dir(dir);
boost::filesystem::path file(file);
boost::filesystem::path full_path = dir / file;
return full_path;
}
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
static void LoadInputVars(const framework::Scope &scope,
const platform::Place &place,
const std::vector<std::string> &inp_var_names,
const std::string &dir) {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
......@@ -80,21 +67,76 @@ class CheckpointLoadOp : public framework::OperatorBase {
"SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
std::string var_file;
var_file.append(dir);
var_file.append("/");
var_file.append(inp_var_names[i]);
VLOG(3) << "ready to load var: " << inp_var_names[i];
std::string var_file = GenePath(dir, inp_var_names[i]);
auto *tensor = var->GetMutable<framework::LoDTensor>();
std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file);
framework::DeserializeFromStream(fin, tensor, dev_ctx);
fin.close();
VLOG(3) << " load var: " << inp_var_names[i] << " finished";
}
}
static void LoadStringArgv(const framework::Scope &scope,
const platform::Place &place,
const std::string &argv, const std::string &dir) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < argv.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]);
std::string *var_str = var->GetMutable<std::string>();
std::string var_file = GenePath(dir, argv);
std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file);
std::getline(fin, var_str);
fin.close();
VLOG(3) << " load String argv: " << argv << " value is: " << var_str;
}
}
class CheckpointLoadOp : public framework::OperatorBase {
public:
CheckpointLoadOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
int serial_num = Attr<int>("Serial");
auto *serial_var = scope.FindVar(SERIAL_VAR);
serial_var = serial_num;
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
<< " value: " << serial_num;
std::string success;
= GenePath(dir, std::to_string(serial_num));
VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS);
bool is_present = FileExists(success);
if (!is_present) {
VLOG(1) << "CheckpointLoadOp can not find " << SUCCESS
<< " from: " << success;
return;
}
VLOG(3) << "Ready to load vars to scope";
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
LoadInputVars(scope, place, &inp_var_names);
VLOG(3) << "Ready to load string argv to scope";
auto argv = Inputs("Argv");
LoadStringArgv(scope, place, &argv, &dir);
}
};
......@@ -106,6 +148,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
AddInput(
"Argv",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
AddComment(R"DOC(
CheckpointLoad operator
......@@ -113,6 +159,9 @@ This operator will serialize and write a list of input LoDTensor variables
to a file on disk.
)DOC");
AddAttr<int>("Serial",
"(int)"
"The serial number of the checkpoint will to be load.");
AddAttr<std::string>(
"dir",
"(string)"
......
......@@ -44,8 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}},
attrs);
"checkpoint_save", {{"X", {"test_var"}}}, attrs);
save_op->Run(scope, place);
}
......@@ -58,7 +57,8 @@ TEST(CheckpointLoadOp, CPU) {
paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_load", {{"X", {"test_var"}}}, {}, attrs);
save_op->Run(scope, place);
auto load_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}},
attrs);
load_op->Run(scope, place);
}
......@@ -17,6 +17,10 @@ limitations under the License. */
#include <fstream>
#include <numeric>
#include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h"
......@@ -30,6 +34,14 @@ namespace operators {
constexpr char kSEP = '/';
// write empty file named _SUCCESS
const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";
static std::string GenePath(const std::string &dir, const std::string &file) {
boost::filesystem::path dir(dir);
boost::filesystem::path file(file);
boost::filesystem::path full_path = dir / file;
return full_path;
}
static bool FileExists(const std::string &filepath) {
struct stat buffer;
......@@ -72,24 +84,20 @@ class CheckpointSaveOp : public framework::OperatorBase {
auto dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite");
auto serial_num = scope.FindVar(SERIAL_VAR);
if (serial_num == nullptr) {
serial_num = scope.Var(SERIAL_VAR);
}
serial_num = serial_num + 1;
dir = GenePath(dir, std::to_string(serial_num));
bool is_present = FileExists(dir);
if (is_present && !overwrite) {
return;
// todo(tangwei) judge the folder is exist
// PADDLE_THROW("%s exists!, cannot save_combine to it when
// overwrite=false",
// dir, overwrite);
PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
overwrite);
}
MkDirRecursively(dir.c_str());
auto serial_var_name = Output("Serial");
auto *serial_var = scope.FindVar(serial_var_name);
std::string *serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
dir.append("/");
dir.append(serial_num->c_str());
MkDirRecursively(dir.c_str());
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
......@@ -101,30 +109,24 @@ class CheckpointSaveOp : public framework::OperatorBase {
// todo (tangwei) made it async
for (size_t i = 0; i < inp_var_names.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]);
std::string var_file;
var_file.append(dir);
var_file.append("/");
var_file.append(inp_var_names[i]);
PADDLE_ENFORCE(var != nullptr,
"Cannot find variable %s for save_combine_op",
"Cannot find variable %s for checkpoint save op",
inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(),
"SaveCombineOp only supports LoDTensor, %s has wrong type",
PADDLE_ENFORCE(
var->IsType<framework::LoDTensor>(),
"CheckpointSaveOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]);
auto &tensor = var->Get<framework::LoDTensor>();
// Serialize tensors one by one
std::string var_file = GenePath(dir, inp_var_names[i]);
std::ofstream fout(var_file);
framework::SerializeToStream(fout, tensor, dev_ctx);
fout.close();
}
std::string success;
success.append(dir);
success.append("/");
success.append(SUCCESS);
std::string success = GenePath(dir, SUCCESS);
std::ofstream fout(success);
fout.close();
}
......@@ -138,7 +140,6 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
AddOutput("Serial", "the serial number");
AddComment(R"DOC(
CheckpointSave operator
......@@ -150,30 +151,29 @@ to a file on disk.
"Delete the output dir if it exists.")
.SetDefault(false);
AddAttr<std::string>(
"dir",
AddAttr<std::string>("dir",
"(string)"
"The \"file_path\" where the LoDTensor variables will be saved.")
"The dir where the LoDTensor variables will be saved.")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
}
};
class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Serial").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
// class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference {
// public:
// void operator()(const framework::OpDesc &op_desc,
// framework::BlockDesc *block) const override {
// auto out_var_name = op_desc.Output("Serial").front();
// auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
// auto var_type = framework::proto::VarType::RAW;
// out_var.SetType(var_type);
// }
// };
// class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
// public:
// void operator()(framework::InferShapeContext *ctx) const override {}
// };
} // namespace operators
} // namespace paddle
......@@ -181,7 +181,10 @@ class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointSaveOpProtoMaker,
ops::CheckpointSaveOpVarTypeInference,
ops::CheckpointSaveOpShapeInference);
ops::CheckpointSaveOpProtoMaker);
// REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
// paddle::framework::EmptyGradOpMaker,
// ops::CheckpointSaveOpProtoMaker,
// ops::CheckpointSaveOpVarTypeInference,
// ops::CheckpointSaveOpShapeInference);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册