diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index 72cfccaaa22b7d354208a9b744c5ea1e32364653..ad237a889ad0a2ca8a7b6c2fea9061eb8350ffdf 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -17,6 +17,10 @@ limitations under the License. */ #include #include #include +#include + +#include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/framework.pb.h" @@ -30,12 +34,70 @@ namespace operators { constexpr char kSEP = '/'; // write empty file named _SUCCESS const char SUCCESS[] = "_SUCCESS"; +const char SERIAL_VAR[] = "SERIAL_NUMBER"; static bool FileExists(const std::string &filepath) { struct stat buffer; return (stat(filepath.c_str(), &buffer) == 0); } +static std::string GenePath(const std::string &dir, const std::string &file) { + boost::filesystem::path dir(dir); + boost::filesystem::path file(file); + boost::filesystem::path full_path = dir / file; + return full_path; +} + +static void LoadInputVars(const framework::Scope &scope, + const platform::Place &place, + const std::vector &inp_var_names, + const std::string &dir) { + // get device context from pool + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + // todo (tangwei) made it async + for (size_t i = 0; i < inp_var_names.size(); i++) { + auto *var = scope.FindVar(inp_var_names[i]); + + PADDLE_ENFORCE(var != nullptr, + "Cannot find variable %s for save_combine_op", + inp_var_names[i]); + PADDLE_ENFORCE(var->IsType(), + "SaveCombineOp only supports LoDTensor, %s has wrong type", + inp_var_names[i]); + + std::string var_file = GenePath(dir, inp_var_names[i]); + auto *tensor = var->GetMutable(); + std::ifstream fin(var_file); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", + var_file); + framework::DeserializeFromStream(fin, tensor, dev_ctx); + fin.close(); + VLOG(3) << " load var: " << inp_var_names[i] << " finished"; + } +} + +static void LoadStringArgv(const framework::Scope &scope, + const platform::Place &place, + const std::string &argv, const std::string &dir) { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(place); + + for (size_t i = 0; i < argv.size(); i++) { + auto *var = scope.FindVar(inp_var_names[i]); + std::string *var_str = var->GetMutable(); + + std::string var_file = GenePath(dir, argv); + std::ifstream fin(var_file); + PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", + var_file); + std::getline(fin, var_str); + fin.close(); + VLOG(3) << " load String argv: " << argv << " value is: " << var_str; + } +} + class CheckpointLoadOp : public framework::OperatorBase { public: CheckpointLoadOp(const std::string &type, @@ -48,53 +110,33 @@ class CheckpointLoadOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { std::string dir = Attr("dir"); + int serial_num = Attr("Serial"); - VLOG(3) << "Load checkpoint from dir: " << dir; + auto *serial_var = scope.FindVar(SERIAL_VAR); + serial_var = serial_num; + VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER + << " value: " << serial_num; std::string success; - success.append(dir); - success.append("/"); - success.append(SUCCESS); - + = GenePath(dir, std::to_string(serial_num)); + VLOG(3) << "Load checkpoint from dir: " << success; + success = GenePath(success, SUCCESS); bool is_present = FileExists(success); if (!is_present) { - VLOG(3) << "can not find _SUCCESS from path: " << success; + VLOG(1) << "CheckpointLoadOp can not find " << SUCCESS + << " from: " << success; return; } + VLOG(3) << "Ready to load vars to scope"; auto inp_var_names = Inputs("X"); PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, "The number of input variables should be greater than 0"); - // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - - // todo (tangwei) made it async - for (size_t i = 0; i < inp_var_names.size(); i++) { - auto *var = scope.FindVar(inp_var_names[i]); - - PADDLE_ENFORCE(var != nullptr, - "Cannot find variable %s for save_combine_op", - inp_var_names[i]); - PADDLE_ENFORCE(var->IsType(), - "SaveCombineOp only supports LoDTensor, %s has wrong type", - inp_var_names[i]); - - std::string var_file; - var_file.append(dir); - var_file.append("/"); - var_file.append(inp_var_names[i]); - VLOG(3) << "ready to load var: " << inp_var_names[i]; - - auto *tensor = var->GetMutable(); - std::ifstream fin(var_file); - PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", - var_file); - framework::DeserializeFromStream(fin, tensor, dev_ctx); - fin.close(); - - VLOG(3) << " load var: " << inp_var_names[i] << " finished"; - } + LoadInputVars(scope, place, &inp_var_names); + + VLOG(3) << "Ready to load string argv to scope"; + auto argv = Inputs("Argv"); + LoadStringArgv(scope, place, &argv, &dir); } }; @@ -106,6 +148,10 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { "X", "(vector) Input LoDTensors that need to be saved together in a file.") .AsDuplicable(); + AddInput( + "Argv", + "(vector) Input LoDTensors that need to be saved together in a file.") + .AsDuplicable(); AddComment(R"DOC( CheckpointLoad operator @@ -113,6 +159,9 @@ This operator will serialize and write a list of input LoDTensor variables to a file on disk. )DOC"); + AddAttr("Serial", + "(int)" + "The serial number of the checkpoint will to be load."); AddAttr( "dir", "(string)" diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc index bea44b35cada2996a4c8aee0352c349beddf9c2b..75bfc3f840765b42b037912c98e9afc37ca69b61 100644 --- a/paddle/fluid/operators/checkpoint_op_test.cc +++ b/paddle/fluid/operators/checkpoint_op_test.cc @@ -44,8 +44,7 @@ TEST(CheckpointSaveOp, CPU) { attrs.insert({"dir", std::string("ckpt")}); auto save_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_save", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}}, - attrs); + "checkpoint_save", {{"X", {"test_var"}}}, attrs); save_op->Run(scope, place); } @@ -58,7 +57,8 @@ TEST(CheckpointLoadOp, CPU) { paddle::framework::AttributeMap attrs; attrs.insert({"dir", std::string("ckpt")}); - auto save_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_load", {{"X", {"test_var"}}}, {}, attrs); - save_op->Run(scope, place); + auto load_op = paddle::framework::OpRegistry::CreateOp( + "checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}}, + attrs); + load_op->Run(scope, place); } diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 1082bb4a345a2ede31642c228922be80dd075c36..54911fc054c213eade970b9604c82c89bba55ebf 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -17,6 +17,10 @@ limitations under the License. */ #include #include #include +#include + +#include + #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/framework.pb.h" @@ -30,6 +34,14 @@ namespace operators { constexpr char kSEP = '/'; // write empty file named _SUCCESS const char SUCCESS[] = "_SUCCESS"; +const char SERIAL_VAR[] = "SERIAL_NUMBER"; + +static std::string GenePath(const std::string &dir, const std::string &file) { + boost::filesystem::path dir(dir); + boost::filesystem::path file(file); + boost::filesystem::path full_path = dir / file; + return full_path; +} static bool FileExists(const std::string &filepath) { struct stat buffer; @@ -72,24 +84,20 @@ class CheckpointSaveOp : public framework::OperatorBase { auto dir = Attr("dir"); auto overwrite = Attr("overwrite"); + auto serial_num = scope.FindVar(SERIAL_VAR); + if (serial_num == nullptr) { + serial_num = scope.Var(SERIAL_VAR); + } + serial_num = serial_num + 1; + + dir = GenePath(dir, std::to_string(serial_num)); bool is_present = FileExists(dir); if (is_present && !overwrite) { - return; - // todo(tangwei) judge the folder is exist - // PADDLE_THROW("%s exists!, cannot save_combine to it when - // overwrite=false", - // dir, overwrite); + PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, + overwrite); } MkDirRecursively(dir.c_str()); - auto serial_var_name = Output("Serial"); - auto *serial_var = scope.FindVar(serial_var_name); - std::string *serial_num = serial_var->GetMutable(); - serial_num->append("0"); - dir.append("/"); - dir.append(serial_num->c_str()); - MkDirRecursively(dir.c_str()); - auto inp_var_names = Inputs("X"); PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, "The number of input variables should be greater than 0"); @@ -101,30 +109,24 @@ class CheckpointSaveOp : public framework::OperatorBase { // todo (tangwei) made it async for (size_t i = 0; i < inp_var_names.size(); i++) { auto *var = scope.FindVar(inp_var_names[i]); - std::string var_file; - var_file.append(dir); - var_file.append("/"); - var_file.append(inp_var_names[i]); PADDLE_ENFORCE(var != nullptr, - "Cannot find variable %s for save_combine_op", - inp_var_names[i]); - PADDLE_ENFORCE(var->IsType(), - "SaveCombineOp only supports LoDTensor, %s has wrong type", + "Cannot find variable %s for checkpoint save op", inp_var_names[i]); + PADDLE_ENFORCE( + var->IsType(), + "CheckpointSaveOp only supports LoDTensor, %s has wrong type", + inp_var_names[i]); auto &tensor = var->Get(); // Serialize tensors one by one - + std::string var_file = GenePath(dir, inp_var_names[i]); std::ofstream fout(var_file); framework::SerializeToStream(fout, tensor, dev_ctx); fout.close(); } - std::string success; - success.append(dir); - success.append("/"); - success.append(SUCCESS); + std::string success = GenePath(dir, SUCCESS); std::ofstream fout(success); fout.close(); } @@ -138,7 +140,6 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { "X", "(vector) Input LoDTensors that need to be saved together in a file.") .AsDuplicable(); - AddOutput("Serial", "the serial number"); AddComment(R"DOC( CheckpointSave operator @@ -150,30 +151,29 @@ to a file on disk. "Delete the output dir if it exists.") .SetDefault(false); - AddAttr( - "dir", - "(string)" - "The \"file_path\" where the LoDTensor variables will be saved.") + AddAttr("dir", + "(string)" + "The dir where the LoDTensor variables will be saved.") .AddCustomChecker( [](const std::string &path) { return !path.empty(); }); } }; -class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto out_var_name = op_desc.Output("Serial").front(); - auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - -class CheckpointSaveOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override {} -}; +// class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference { +// public: +// void operator()(const framework::OpDesc &op_desc, +// framework::BlockDesc *block) const override { +// auto out_var_name = op_desc.Output("Serial").front(); +// auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); +// auto var_type = framework::proto::VarType::RAW; +// out_var.SetType(var_type); +// } +// }; + +// class CheckpointSaveOpShapeInference : public framework::InferShapeBase { +// public: +// void operator()(framework::InferShapeContext *ctx) const override {} +// }; } // namespace operators } // namespace paddle @@ -181,7 +181,10 @@ class CheckpointSaveOpShapeInference : public framework::InferShapeBase { namespace ops = paddle::operators; REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, - paddle::framework::EmptyGradOpMaker, - ops::CheckpointSaveOpProtoMaker, - ops::CheckpointSaveOpVarTypeInference, - ops::CheckpointSaveOpShapeInference); + ops::CheckpointSaveOpProtoMaker); + +// REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, +// paddle::framework::EmptyGradOpMaker, +// ops::CheckpointSaveOpProtoMaker, +// ops::CheckpointSaveOpVarTypeInference, +// ops::CheckpointSaveOpShapeInference);