From a4fd3756bbd95fb8c676af9aab7a22cfe87d9cc5 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 18 May 2018 09:46:14 +0800 Subject: [PATCH] bug fix --- paddle/fluid/operators/checkpoint_load_op.cc | 85 +++++++++++++------- paddle/fluid/operators/checkpoint_op_test.cc | 24 +++++- paddle/fluid/operators/checkpoint_save_op.cc | 36 +++++---- 3 files changed, 95 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc index 5fd3a7af9cf..d24c7819990 100644 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type_transform.h" @@ -43,7 +44,13 @@ static std::string GenePath(const std::string &dir, const std::string &file) { file_path.append(file_path); file_path.append("/"); file_path.append(file); - return full_path; + return file_path; +} + +static bool IsNumber(const std::string &s) { + std::string::const_iterator it = s.begin(); + while (it != s.end() && std::isdigit(*it)) ++it; + return !s.empty() && it == s.end(); } static void LoadInputVars(const framework::Scope &scope, @@ -62,7 +69,7 @@ static void LoadInputVars(const framework::Scope &scope, "Cannot find variable %s for save_combine_op", inp_var_names[i]); PADDLE_ENFORCE(var->IsType(), - "SaveCombineOp only supports LoDTensor, %s has wrong type", + "LoadCombineOp only supports LoDTensor, %s has wrong type", inp_var_names[i]); std::string var_file = GenePath(dir, inp_var_names[i]); @@ -78,21 +85,18 @@ static void LoadInputVars(const framework::Scope &scope, static void LoadStringArgv(const framework::Scope &scope, const platform::Place &place, - const std::string &argv, const std::string &dir) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - + const std::vector &argv, + const std::string &dir) { for (size_t i = 0; i < argv.size(); i++) { - auto *var = scope.FindVar(inp_var_names[i]); + auto *var = scope.FindVar(argv[i]); std::string *var_str = var->GetMutable(); - - std::string var_file = GenePath(dir, argv); + std::string var_file = GenePath(dir, argv[i]); std::ifstream fin(var_file); PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", var_file); - std::getline(fin, var_str); + std::getline(fin, *var_str); fin.close(); - VLOG(3) << " load String argv: " << argv << " value is: " << var_str; + VLOG(3) << " load String argv: " << argv[i] << " value is: " << var_str; } } @@ -108,22 +112,24 @@ class CheckpointLoadOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { std::string dir = Attr("dir"); - std::string serial_num = Attr("Serial"); + std::string serial_num_attr = Attr("Serial"); + + PADDLE_ENFORCE(IsNumber(serial_num_attr), + "Checkpoint Serial must be a number"); std::string serial_var_name = std::string(SERIAL_VAR); auto *serial_var = scope.FindVar(serial_var_name); - - if (serial_var == nullptr) { - *serial_var = scope.Var(serial_var_name); - auto *serial_tmp = serial_var->GetMutable(); - serial_tmp->append("0"); - } + PADDLE_ENFORCE(serial_var != nullptr, + "Cannot find variable %s for checkpoint_load_op", + serial_var_name); auto *serial_num = serial_var->GetMutable(); - VLOG(1) << "CheckpointLoadOp set " << SERIAL_NUMBER + serial_num = serial_num_attr; + + VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR << " value: " << serial_num; - std::string success = GenePath(dir, serial_num); + std::string success = GenePath(dir, serial_num->c_str()); VLOG(3) << "Load checkpoint from dir: " << success; success = GenePath(success, SUCCESS); bool is_present = FileExists(success); @@ -137,11 +143,11 @@ class CheckpointLoadOp : public framework::OperatorBase { auto inp_var_names = Inputs("X"); PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, "The number of input variables should be greater than 0"); - LoadInputVars(scope, place, &inp_var_names); + LoadInputVars(scope, place, inp_var_names, dir); - VLOG(3) << "Ready to load string argv to scope"; - auto argv = Inputs("Argv"); - LoadStringArgv(scope, place, &argv, &dir); + // VLOG(3) << "Ready to load string argv to scope"; + // auto argv = Output("Argv"); + // LoadStringArgv(scope, place, argv, dir); } }; @@ -153,14 +159,13 @@ class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { "X", "(vector) Input LoDTensors that need to be saved together in a file.") .AsDuplicable(); - AddInput( + AddOutput( "Argv", - "(vector) Input LoDTensors that need to be saved together in a file.") - .AsDuplicable(); + "(vector) Input LoDTensors that need to be saved together in a file."); AddComment(R"DOC( CheckpointLoad operator -This operator will serialize and write a list of input LoDTensor variables +This operator will serialize and write a list of input LoDTensor variables to a file on disk. )DOC"); @@ -177,10 +182,32 @@ to a file on disk. } }; +class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { + auto out_var_name = op_desc.Output("Argv").front(); + auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto var_type = framework::proto::VarType::RAW; + out_var.SetType(var_type); + } +}; + +class CheckpointLoadOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override {} +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp, - ops::CheckpointLoadOpProtoMaker); + paddle::framework::EmptyGradOpMaker, + ops::CheckpointLoadOpProtoMaker, + ops::CheckpointLoadOpVarTypeInference, + ops::CheckpointLoadOpShapeInference); + +// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp, +// ops::CheckpointLoadOpProtoMaker); diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc index 75bfc3f8407..2acce227d23 100644 --- a/paddle/fluid/operators/checkpoint_op_test.cc +++ b/paddle/fluid/operators/checkpoint_op_test.cc @@ -44,7 +44,7 @@ TEST(CheckpointSaveOp, CPU) { attrs.insert({"dir", std::string("ckpt")}); auto save_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_save", {{"X", {"test_var"}}}, attrs); + "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs); save_op->Run(scope, place); } @@ -52,13 +52,29 @@ TEST(CheckpointLoadOp, CPU) { paddle::framework::Scope scope; paddle::platform::CPUPlace place; - scope.Var("test_var"); + auto var = scope.Var("test_var"); + auto tensor = var->GetMutable(); + tensor->Resize({3, 10}); + paddle::framework::LoD expect_lod; + expect_lod.resize(1); + expect_lod[0].push_back(0); + expect_lod[0].push_back(1); + expect_lod[0].push_back(2); + expect_lod[0].push_back(3); + + tensor->set_lod(expect_lod); + float* expect = tensor->mutable_data(place); + for (int64_t i = 0; i < tensor->numel(); ++i) { + expect[i] = static_cast(paddle::platform::float16(i)); + } + + scope.Var("SERIAL_NUMBER"); paddle::framework::AttributeMap attrs; attrs.insert({"dir", std::string("ckpt")}); + attrs.insert({"Serial", std::string("SERIAL_NUMBER")}); auto load_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_load", {{"X", {"test_var"}}}, {{"Serial", {"SERIAL_NUMBER"}}}, - attrs); + "checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs); load_op->Run(scope, place); } diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 5fccefeed25..bab979e4074 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -33,12 +33,18 @@ constexpr char kSEP = '/'; const char SUCCESS[] = "_SUCCESS"; const char SERIAL_VAR[] = "SERIAL_NUMBER"; +static bool IsNumber(const std::string &s) { + std::string::const_iterator it = s.begin(); + while (it != s.end() && std::isdigit(*it)) ++it; + return !s.empty() && it == s.end(); +} + static std::string GenePath(const std::string &dir, const std::string &file) { std::string file_path; - file_path.append(file_path); + file_path.append(dir); file_path.append("/"); file_path.append(file); - return full_path; + return file_path; } static bool FileExists(const std::string &filepath) { @@ -79,28 +85,24 @@ class CheckpointSaveOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - auto dir = Attr("dir"); + auto ck_dir = Attr("dir"); auto overwrite = Attr("overwrite"); std::string serial_var_name = std::string(SERIAL_VAR); - auto *serial_var = scope.FindVar(serial_var_name); - - if (serial_var == nullptr) { - *serial_var = scope.Var(serial_var_name); - *serial_tmp = serial_var->GetMutable(); - serial_tmp->append("0"); - } - auto *serial_num = serial_var->GetMutable(); - VLOG(1) << "CheckpointSaveOp get " << SERIAL_NUMBER + auto *serial_num = + scope.FindVar(serial_var_name)->GetMutable(); + VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR << " value: " << serial_num; - auto *serial_num = serial_var->GetMutable(); - serial_num->append("1"); + if (!IsNumber(serial_num)) { + serial_num = "0"; + } - dir = GenePath(dir, serial_num); + std::string dir = GenePath(ck_dir, serial_num->c_str()); + VLOG(1) << "CheckpointSaveOp current dir: " << dir; bool is_present = FileExists(dir); if (is_present && !overwrite) { - PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, + PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, overwrite); } MkDirRecursively(dir.c_str()); @@ -150,7 +152,7 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( CheckpointSave operator -This operator will serialize and write a list of input LoDTensor variables +This operator will serialize and write a list of input LoDTensor variables to a file on disk. )DOC"); AddAttr("overwrite", -- GitLab