提交 6d53dcee 编写于 作者: T tangwei12

optimized checkpoint serial number and folder

上级 4220b31d
...@@ -17,6 +17,10 @@ limitations under the License. */ ...@@ -17,6 +17,10 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -30,41 +34,24 @@ namespace operators { ...@@ -30,41 +34,24 @@ namespace operators {
constexpr char kSEP = '/'; constexpr char kSEP = '/';
// write empty file named _SUCCESS // write empty file named _SUCCESS
const char SUCCESS[] = "_SUCCESS"; const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";
static bool FileExists(const std::string &filepath) { static bool FileExists(const std::string &filepath) {
struct stat buffer; struct stat buffer;
return (stat(filepath.c_str(), &buffer) == 0); return (stat(filepath.c_str(), &buffer) == 0);
} }
class CheckpointLoadOp : public framework::OperatorBase { static std::string GenePath(const std::string &dir, const std::string &file) {
public: boost::filesystem::path dir(dir);
CheckpointLoadOp(const std::string &type, boost::filesystem::path file(file);
const framework::VariableNameMap &inputs, boost::filesystem::path full_path = dir / file;
const framework::VariableNameMap &outputs, return full_path;
const framework::AttributeMap &attrs) }
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
VLOG(3) << "Load checkpoint from dir: " << dir;
std::string success;
success.append(dir);
success.append("/");
success.append(SUCCESS);
bool is_present = FileExists(success);
if (!is_present) {
VLOG(3) << "can not find _SUCCESS from path: " << success;
return;
}
auto inp_var_names = Inputs("X"); static void LoadInputVars(const framework::Scope &scope,
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0, const platform::Place &place,
"The number of input variables should be greater than 0"); const std::vector<std::string> &inp_var_names,
const std::string &dir) {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
...@@ -80,21 +67,76 @@ class CheckpointLoadOp : public framework::OperatorBase { ...@@ -80,21 +67,76 @@ class CheckpointLoadOp : public framework::OperatorBase {
"SaveCombineOp only supports LoDTensor, %s has wrong type", "SaveCombineOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]); inp_var_names[i]);
std::string var_file; std::string var_file = GenePath(dir, inp_var_names[i]);
var_file.append(dir);
var_file.append("/");
var_file.append(inp_var_names[i]);
VLOG(3) << "ready to load var: " << inp_var_names[i];
auto *tensor = var->GetMutable<framework::LoDTensor>(); auto *tensor = var->GetMutable<framework::LoDTensor>();
std::ifstream fin(var_file); std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file); var_file);
framework::DeserializeFromStream(fin, tensor, dev_ctx); framework::DeserializeFromStream(fin, tensor, dev_ctx);
fin.close(); fin.close();
VLOG(3) << " load var: " << inp_var_names[i] << " finished"; VLOG(3) << " load var: " << inp_var_names[i] << " finished";
} }
}
static void LoadStringArgv(const framework::Scope &scope,
const platform::Place &place,
const std::string &argv, const std::string &dir) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
for (size_t i = 0; i < argv.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]);
std::string *var_str = var->GetMutable<std::string>();
std::string var_file = GenePath(dir, argv);
std::ifstream fin(var_file);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
var_file);
std::getline(fin, var_str);
fin.close();
VLOG(3) << " load String argv: " << argv << " value is: " << var_str;
}
}
class CheckpointLoadOp : public framework::OperatorBase {
public:
CheckpointLoadOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::string dir = Attr<std::string>("dir");
int serial_num = Attr<int>("Serial");
auto *serial_var = scope.FindVar(SERIAL_VAR);
serial_var = serial_num;
VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER
<< " value: " << serial_num;
std::string success;
= GenePath(dir, std::to_string(serial_num));
VLOG(3) << "Load checkpoint from dir: " << success;
success = GenePath(success, SUCCESS);
bool is_present = FileExists(success);
if (!is_present) {
VLOG(1) << "CheckpointLoadOp can not find " << SUCCESS
<< " from: " << success;
return;
}
VLOG(3) << "Ready to load vars to scope";
auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0");
LoadInputVars(scope, place, &inp_var_names);
VLOG(3) << "Ready to load string argv to scope";
auto argv = Inputs("Argv");
LoadStringArgv(scope, place, &argv, &dir);
} }
}; };
...@@ -106,6 +148,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -106,6 +148,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X", "X",
"(vector) Input LoDTensors that need to be saved together in a file.") "(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable(); .AsDuplicable();
AddInput(
"Argv",
"(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
CheckpointLoad operator CheckpointLoad operator
...@@ -113,6 +159,9 @@ This operator will serialize and write a list of input LoDTensor variables ...@@ -113,6 +159,9 @@ This operator will serialize and write a list of input LoDTensor variables
to a file on disk. to a file on disk.
)DOC"); )DOC");
AddAttr<int>("Serial",
"(int)"
"The serial number of the checkpoint will to be load.");
AddAttr<std::string>( AddAttr<std::string>(
"dir", "dir",
"(string)" "(string)"
......
...@@ -44,8 +44,7 @@ TEST(CheckpointSaveOp, CPU) { ...@@ -44,8 +44,7 @@ TEST(CheckpointSaveOp, CPU) {
attrs.insert({"dir", std::string("ckpt")}); attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp( auto save_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_save", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}}, "checkpoint_save", {{"X", {"test_var"}}}, attrs);
attrs);
save_op->Run(scope, place); save_op->Run(scope, place);
} }
...@@ -58,7 +57,8 @@ TEST(CheckpointLoadOp, CPU) { ...@@ -58,7 +57,8 @@ TEST(CheckpointLoadOp, CPU) {
paddle::framework::AttributeMap attrs; paddle::framework::AttributeMap attrs;
attrs.insert({"dir", std::string("ckpt")}); attrs.insert({"dir", std::string("ckpt")});
auto save_op = paddle::framework::OpRegistry::CreateOp( auto load_op = paddle::framework::OpRegistry::CreateOp(
"checkpoint_load", {{"X", {"test_var"}}}, {}, attrs); "checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}},
save_op->Run(scope, place); attrs);
load_op->Run(scope, place);
} }
...@@ -17,6 +17,10 @@ limitations under the License. */ ...@@ -17,6 +17,10 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include <string>
#include <boost/filesystem.hpp>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -30,6 +34,14 @@ namespace operators { ...@@ -30,6 +34,14 @@ namespace operators {
constexpr char kSEP = '/'; constexpr char kSEP = '/';
// write empty file named _SUCCESS // write empty file named _SUCCESS
const char SUCCESS[] = "_SUCCESS"; const char SUCCESS[] = "_SUCCESS";
const char SERIAL_VAR[] = "SERIAL_NUMBER";
static std::string GenePath(const std::string &dir, const std::string &file) {
boost::filesystem::path dir(dir);
boost::filesystem::path file(file);
boost::filesystem::path full_path = dir / file;
return full_path;
}
static bool FileExists(const std::string &filepath) { static bool FileExists(const std::string &filepath) {
struct stat buffer; struct stat buffer;
...@@ -72,24 +84,20 @@ class CheckpointSaveOp : public framework::OperatorBase { ...@@ -72,24 +84,20 @@ class CheckpointSaveOp : public framework::OperatorBase {
auto dir = Attr<std::string>("dir"); auto dir = Attr<std::string>("dir");
auto overwrite = Attr<bool>("overwrite"); auto overwrite = Attr<bool>("overwrite");
auto serial_num = scope.FindVar(SERIAL_VAR);
if (serial_num == nullptr) {
serial_num = scope.Var(SERIAL_VAR);
}
serial_num = serial_num + 1;
dir = GenePath(dir, std::to_string(serial_num));
bool is_present = FileExists(dir); bool is_present = FileExists(dir);
if (is_present && !overwrite) { if (is_present && !overwrite) {
return; PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir,
// todo(tangwei) judge the folder is exist overwrite);
// PADDLE_THROW("%s exists!, cannot save_combine to it when
// overwrite=false",
// dir, overwrite);
} }
MkDirRecursively(dir.c_str()); MkDirRecursively(dir.c_str());
auto serial_var_name = Output("Serial");
auto *serial_var = scope.FindVar(serial_var_name);
std::string *serial_num = serial_var->GetMutable<std::string>();
serial_num->append("0");
dir.append("/");
dir.append(serial_num->c_str());
MkDirRecursively(dir.c_str());
auto inp_var_names = Inputs("X"); auto inp_var_names = Inputs("X");
PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0, PADDLE_ENFORCE_GT(static_cast<int>(inp_var_names.size()), 0,
"The number of input variables should be greater than 0"); "The number of input variables should be greater than 0");
...@@ -101,30 +109,24 @@ class CheckpointSaveOp : public framework::OperatorBase { ...@@ -101,30 +109,24 @@ class CheckpointSaveOp : public framework::OperatorBase {
// todo (tangwei) made it async // todo (tangwei) made it async
for (size_t i = 0; i < inp_var_names.size(); i++) { for (size_t i = 0; i < inp_var_names.size(); i++) {
auto *var = scope.FindVar(inp_var_names[i]); auto *var = scope.FindVar(inp_var_names[i]);
std::string var_file;
var_file.append(dir);
var_file.append("/");
var_file.append(inp_var_names[i]);
PADDLE_ENFORCE(var != nullptr, PADDLE_ENFORCE(var != nullptr,
"Cannot find variable %s for save_combine_op", "Cannot find variable %s for checkpoint save op",
inp_var_names[i]); inp_var_names[i]);
PADDLE_ENFORCE(var->IsType<framework::LoDTensor>(), PADDLE_ENFORCE(
"SaveCombineOp only supports LoDTensor, %s has wrong type", var->IsType<framework::LoDTensor>(),
"CheckpointSaveOp only supports LoDTensor, %s has wrong type",
inp_var_names[i]); inp_var_names[i]);
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
// Serialize tensors one by one // Serialize tensors one by one
std::string var_file = GenePath(dir, inp_var_names[i]);
std::ofstream fout(var_file); std::ofstream fout(var_file);
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(fout, tensor, dev_ctx);
fout.close(); fout.close();
} }
std::string success; std::string success = GenePath(dir, SUCCESS);
success.append(dir);
success.append("/");
success.append(SUCCESS);
std::ofstream fout(success); std::ofstream fout(success);
fout.close(); fout.close();
} }
...@@ -138,7 +140,6 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -138,7 +140,6 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"X", "X",
"(vector) Input LoDTensors that need to be saved together in a file.") "(vector) Input LoDTensors that need to be saved together in a file.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Serial", "the serial number");
AddComment(R"DOC( AddComment(R"DOC(
CheckpointSave operator CheckpointSave operator
...@@ -150,30 +151,29 @@ to a file on disk. ...@@ -150,30 +151,29 @@ to a file on disk.
"Delete the output dir if it exists.") "Delete the output dir if it exists.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>("dir",
"dir",
"(string)" "(string)"
"The \"file_path\" where the LoDTensor variables will be saved.") "The dir where the LoDTensor variables will be saved.")
.AddCustomChecker( .AddCustomChecker(
[](const std::string &path) { return !path.empty(); }); [](const std::string &path) { return !path.empty(); });
} }
}; };
class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference { // class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference {
public: // public:
void operator()(const framework::OpDesc &op_desc, // void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { // framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Serial").front(); // auto out_var_name = op_desc.Output("Serial").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); // auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW; // auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type); // out_var.SetType(var_type);
} // }
}; // };
class CheckpointSaveOpShapeInference : public framework::InferShapeBase { // class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
public: // public:
void operator()(framework::InferShapeContext *ctx) const override {} // void operator()(framework::InferShapeContext *ctx) const override {}
}; // };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -181,7 +181,10 @@ class CheckpointSaveOpShapeInference : public framework::InferShapeBase { ...@@ -181,7 +181,10 @@ class CheckpointSaveOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
paddle::framework::EmptyGradOpMaker, ops::CheckpointSaveOpProtoMaker);
ops::CheckpointSaveOpProtoMaker,
ops::CheckpointSaveOpVarTypeInference, // REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp,
ops::CheckpointSaveOpShapeInference); // paddle::framework::EmptyGradOpMaker,
// ops::CheckpointSaveOpProtoMaker,
// ops::CheckpointSaveOpVarTypeInference,
// ops::CheckpointSaveOpShapeInference);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册