From be050565241780003cef777e0b0ad0e49cd7f6b1 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 21 May 2018 19:11:23 +0800 Subject: [PATCH] delete old checkpoint code --- paddle/fluid/operators/CMakeLists.txt | 3 - paddle/fluid/operators/checkpoint_load_op.cc | 213 ------------------- paddle/fluid/operators/checkpoint_op_test.cc | 82 ------- paddle/fluid/operators/checkpoint_save_op.cc | 203 ------------------ python/paddle/fluid/framework.py | 3 +- python/paddle/fluid/io.py | 36 +++- 6 files changed, 32 insertions(+), 508 deletions(-) delete mode 100644 paddle/fluid/operators/checkpoint_load_op.cc delete mode 100644 paddle/fluid/operators/checkpoint_op_test.cc delete mode 100644 paddle/fluid/operators/checkpoint_save_op.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 2288987eaf9..ac1f3f44ae8 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -252,8 +252,6 @@ op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) -op_library(checkpoint_save_op DEPS lod_tensor) -op_library(checkpoint_load_op DEPS lod_tensor) op_library(concat_op DEPS concat) # FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency @@ -294,6 +292,5 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) -cc_test(checkpoint_op_test SRCS checkpoint_op_test.cc DEPS checkpoint_save_op checkpoint_load_op) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc deleted file mode 100644 index 18871e56c50..00000000000 --- a/paddle/fluid/operators/checkpoint_load_op.cc +++ /dev/null @@ -1,213 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/data_type_transform.h" -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace operators { - -constexpr char kSEP = '/'; -// write empty file named _SUCCESS -const char SUCCESS[] = "_SUCCESS"; -const char SERIAL_VAR[] = "SERIAL_NUMBER"; - -static bool FileExists(const std::string &filepath) { - struct stat buffer; - return (stat(filepath.c_str(), &buffer) == 0); -} - -static std::string GenePath(const std::string &dir, const std::string &file) { - std::string file_path; - file_path.append(file_path); - file_path.append("/"); - file_path.append(file); - return file_path; -} - -static bool IsNumber(const std::string &s) { - std::string::const_iterator it = s.begin(); - while (it != s.end() && std::isdigit(*it)) ++it; - return !s.empty() && it == s.end(); -} - -static void LoadInputVars(const framework::Scope &scope, - const platform::Place &place, - const std::vector &inp_var_names, - const std::string &dir) { - // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - - // todo (tangwei) made it async - for (size_t i = 0; i < inp_var_names.size(); i++) { - auto *var = scope.FindVar(inp_var_names[i]); - - PADDLE_ENFORCE(var != nullptr, - "Cannot find variable %s for save_combine_op", - inp_var_names[i]); - PADDLE_ENFORCE(var->IsType(), - "LoadCombineOp only supports LoDTensor, %s has wrong type", - inp_var_names[i]); - - std::string var_file = GenePath(dir, inp_var_names[i]); - auto *tensor = var->GetMutable(); - std::ifstream fin(var_file); - PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", - var_file); - framework::DeserializeFromStream(fin, tensor, dev_ctx); - fin.close(); - VLOG(3) << " load var: " << inp_var_names[i] << " finished"; - } -} - -static void LoadStringArgv(const framework::Scope &scope, - const platform::Place &place, - const std::vector &argv, - const std::string &dir) { - for (size_t i = 0; i < argv.size(); i++) { - auto *var = scope.FindVar(argv[i]); - std::string *var_str = var->GetMutable(); - std::string var_file = GenePath(dir, argv[i]); - std::ifstream fin(var_file); - PADDLE_ENFORCE(static_cast(fin), "Cannot open file %s for load op", - var_file); - std::getline(fin, *var_str); - fin.close(); - VLOG(3) << " load String argv: " << argv[i] << " value is: " << var_str; - } -} - -class CheckpointLoadOp : public framework::OperatorBase { - public: - CheckpointLoadOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - std::string dir = Attr("dir"); - std::string serial_num_attr = Attr("Serial"); - - VLOG(3) << "CheckpointLoadOp get Attr dir: " << dir; - VLOG(3) << "CheckpointLoadOp get Attr Serial: " << serial_num_attr; - - std::string serial_var_name = std::string(SERIAL_VAR); - auto *serial_var = scope.FindVar(serial_var_name); - PADDLE_ENFORCE(serial_var != nullptr, - "Cannot find variable %s for checkpoint_load_op", - serial_var_name); - - auto *serial_num = serial_var->GetMutable(); - serial_num->clear(); - serial_num->append(serial_num_attr); - - VLOG(1) << "CheckpointLoadOp set " << SERIAL_VAR - << " value: " << serial_num; - - std::string success = GenePath(dir, serial_num->c_str()); - VLOG(3) << "Load checkpoint from dir: " << success; - success = GenePath(success, SUCCESS); - bool is_present = FileExists(success); - if (!is_present) { - VLOG(1) << "CheckpointLoadOp can not find " << SUCCESS - << " from: " << success; - return; - } - - VLOG(3) << "Ready to load vars to scope"; - auto inp_var_names = Inputs("X"); - PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, - "The number of input variables should be greater than 0"); - LoadInputVars(scope, place, inp_var_names, dir); - - // VLOG(3) << "Ready to load string argv to scope"; - // auto argv = Output("Argv"); - // LoadStringArgv(scope, place, argv, dir); - } -}; - -class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "X", - "(vector) Input LoDTensors that need to be saved together in a file.") - .AsDuplicable(); - AddOutput( - "Argv", - "(vector) Input LoDTensors that need to be saved together in a file."); - AddComment(R"DOC( -CheckpointLoad operator - -This operator will serialize and write a list of input LoDTensor variables -to a file on disk. -)DOC"); - - AddAttr( - "Serial", - "(std::string)" - "The serial number of the checkpoint will to be load."); - AddAttr( - "dir", - "(string)" - "The \"dir\" where the checkpoint files will be loaded.") - .AddCustomChecker( - [](const std::string &path) { return !path.empty(); }); - } -}; - -class CheckpointLoadOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc &op_desc, - framework::BlockDesc *block) const override { - auto out_var_name = op_desc.Output("Argv").front(); - auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - -class CheckpointLoadOpShapeInference : public framework::InferShapeBase { - public: - void operator()(framework::InferShapeContext *ctx) const override {} -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp, - paddle::framework::EmptyGradOpMaker, - ops::CheckpointLoadOpProtoMaker, - ops::CheckpointLoadOpVarTypeInference, - ops::CheckpointLoadOpShapeInference); - -// REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp, -// ops::CheckpointLoadOpProtoMaker); diff --git a/paddle/fluid/operators/checkpoint_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc deleted file mode 100644 index 5312225e5f9..00000000000 --- a/paddle/fluid/operators/checkpoint_op_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/op_registry.h" - -USE_NO_KERNEL_OP(checkpoint_save) -USE_NO_KERNEL_OP(checkpoint_load) - -TEST(CheckpointSaveOp, CPU) { - paddle::framework::Scope scope; - paddle::platform::CPUPlace place; - - auto var = scope.Var("test_var"); - auto tensor = var->GetMutable(); - tensor->Resize({3, 10}); - paddle::framework::LoD expect_lod; - expect_lod.resize(1); - expect_lod[0].push_back(0); - expect_lod[0].push_back(1); - expect_lod[0].push_back(2); - expect_lod[0].push_back(3); - - tensor->set_lod(expect_lod); - float* expect = tensor->mutable_data(place); - for (int64_t i = 0; i < tensor->numel(); ++i) { - expect[i] = static_cast(paddle::platform::float16(i)); - } - - scope.Var("SERIAL_NUMBER"); - - paddle::framework::AttributeMap attrs; - attrs.insert({"dir", std::string("ckpt")}); - - auto save_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_save", {{"X", {"test_var"}}}, {}, attrs); - save_op->Run(scope, place); -} - -TEST(CheckpointLoadOp, CPU) { - paddle::framework::Scope scope; - paddle::platform::CPUPlace place; - - auto var = scope.Var("test_var"); - auto tensor = var->GetMutable(); - tensor->Resize({3, 10}); - paddle::framework::LoD expect_lod; - expect_lod.resize(1); - expect_lod[0].push_back(0); - expect_lod[0].push_back(1); - expect_lod[0].push_back(2); - expect_lod[0].push_back(3); - - tensor->set_lod(expect_lod); - float* expect = tensor->mutable_data(place); - for (int64_t i = 0; i < tensor->numel(); ++i) { - expect[i] = static_cast(paddle::platform::float16(i)); - } - - scope.Var("SERIAL_NUMBER"); - auto* serial_num = scope.FindVar("SERIAL_NUMBER")->GetMutable(); - serial_num->append("0"); - - paddle::framework::AttributeMap attrs; - attrs.insert({"dir", std::string("ckpt")}); - attrs.insert({"Serial", std::string("SERIAL_NUMBER")}); - - auto load_op = paddle::framework::OpRegistry::CreateOp( - "checkpoint_load", {{"X", {"test_var"}}}, {{"Argv", {}}}, attrs); - load_op->Run(scope, place); -} diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc deleted file mode 100644 index f904cdc8269..00000000000 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/data_type_transform.h" -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace operators { - -constexpr char kSEP = '/'; -// write empty file named _SUCCESS -const char SUCCESS[] = "_SUCCESS"; -const char SERIAL_VAR[] = "SERIAL_NUMBER"; - -static bool IsNumber(const std::string &s) { - std::string::const_iterator it = s.begin(); - while (it != s.end() && std::isdigit(*it)) ++it; - return !s.empty() && it == s.end(); -} - -static std::string GenePath(const std::string &dir, const std::string &file) { - std::string file_path; - file_path.append(dir); - file_path.append("/"); - file_path.append(file); - return file_path; -} - -static bool FileExists(const std::string &filepath) { - struct stat buffer; - return (stat(filepath.c_str(), &buffer) == 0); -} - -static std::string DirName(const std::string &filepath) { - auto pos = filepath.rfind(kSEP); - if (pos == std::string::npos) { - return ""; - } - return filepath.substr(0, pos); -} - -static void MkDir(const char *path) { - if (mkdir(path, 0755)) { - PADDLE_ENFORCE_EQ(errno, EEXIST, "%s mkdir failed!", path); - } -} - -static void MkDirRecursively(const char *fullpath) { - if (*fullpath == '\0') return; // empty string - if (FileExists(fullpath)) return; - - MkDirRecursively(DirName(fullpath).c_str()); - MkDir(fullpath); -} - -class CheckpointSaveOp : public framework::OperatorBase { - public: - CheckpointSaveOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - private: - void RunImpl(const framework::Scope &scope, - const platform::Place &place) const override { - auto ck_dir = Attr("dir"); - auto overwrite = Attr("overwrite"); - - std::string serial_var_name = std::string(SERIAL_VAR); - auto *serial_num = - scope.FindVar(serial_var_name)->GetMutable(); - VLOG(1) << "CheckpointSaveOp get " << SERIAL_VAR - << " value: " << serial_num; - - int serials = 0; - if (!serial_num->empty()) { - serials = std::stoi(serial_num->data()); - serials += 1; - } - - serial_num->clear(); - serial_num->append(std::to_string(serials)); - - std::string dir = GenePath(ck_dir, serial_num->c_str()); - VLOG(1) << "CheckpointSaveOp current dir: " << dir; - bool is_present = FileExists(dir); - if (is_present && !overwrite) { - PADDLE_THROW("%s exists!, checkpoint save cannot to overwrite it", dir, - overwrite); - } - MkDirRecursively(dir.c_str()); - - auto inp_var_names = Inputs("X"); - PADDLE_ENFORCE_GT(static_cast(inp_var_names.size()), 0, - "The number of input variables should be greater than 0"); - - // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); - - // todo (tangwei) made it async - for (size_t i = 0; i < inp_var_names.size(); i++) { - auto *var = scope.FindVar(inp_var_names[i]); - - PADDLE_ENFORCE(var != nullptr, - "Cannot find variable %s for checkpoint save op", - inp_var_names[i]); - PADDLE_ENFORCE( - var->IsType(), - "CheckpointSaveOp only supports LoDTensor, %s has wrong type", - inp_var_names[i]); - - auto &tensor = var->Get(); - // Serialize tensors one by one - std::string var_file = GenePath(dir, inp_var_names[i]); - std::ofstream fout(var_file); - framework::SerializeToStream(fout, tensor, dev_ctx); - fout.close(); - } - - std::string success = GenePath(dir, SUCCESS); - std::ofstream fout(success); - fout.close(); - } -}; - -class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "X", - "(vector) Input LoDTensors that need to be saved together in a file.") - .AsDuplicable(); - AddComment(R"DOC( -CheckpointSave operator - -This operator will serialize and write a list of input LoDTensor variables -to a file on disk. -)DOC"); - AddAttr("overwrite", - "(boolean, default false)" - "Delete the output dir if it exists.") - .SetDefault(false); - - AddAttr("dir", - "(string)" - "The dir where the LoDTensor variables will be saved.") - .AddCustomChecker( - [](const std::string &path) { return !path.empty(); }); - } -}; - -// class CheckpointSaveOpVarTypeInference : public framework::VarTypeInference { -// public: -// void operator()(const framework::OpDesc &op_desc, -// framework::BlockDesc *block) const override { -// auto out_var_name = op_desc.Output("Serial").front(); -// auto &out_var = block->FindRecursiveOrCreateVar(out_var_name); -// auto var_type = framework::proto::VarType::RAW; -// out_var.SetType(var_type); -// } -// }; - -// class CheckpointSaveOpShapeInference : public framework::InferShapeBase { -// public: -// void operator()(framework::InferShapeContext *ctx) const override {} -// }; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, - ops::CheckpointSaveOpProtoMaker); - -// REGISTER_OPERATOR(checkpoint_save, ops::CheckpointSaveOp, -// paddle::framework::EmptyGradOpMaker, -// ops::CheckpointSaveOpProtoMaker, -// ops::CheckpointSaveOpVarTypeInference, -// ops::CheckpointSaveOpShapeInference); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c5044a07c94..38c765938fe 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -489,8 +489,7 @@ class Operator(object): 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'ncclInit', 'channel_create', 'channel_close', - 'channel_send', 'channel_recv', 'select', 'gen_nccl_id', - 'checkpoint_save', 'checkpoint_load' + 'channel_send', 'channel_recv', 'select', 'gen_nccl_id' } if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 83c32fe9d6e..b1748f0ad0a 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -455,7 +455,7 @@ def get_parameter_value_by_name(name, executor, program=None): SUCCESS = "_SUCCESS" -BEGIN_SECS = time.time() +BEGIN_SECS = None def save_checkpoint(executor, @@ -478,13 +478,21 @@ def save_checkpoint(executor, os.makedirs(dirname) global BEGIN_SECS - if time.time() - BEGIN_SECS < save_secs: - return + if BEGIN_SECS is not None: + if time.time() - BEGIN_SECS < save_secs: + return BEGIN_SECS = time.time() serial = _get_lastest_checkpoint_dir(dirname) + 1 cur_dir = os.path.join(dirname, str(serial)) - save_persistables(executor, cur_dir, main_program) + # save_persistables(executor, cur_dir, main_program) + save_vars( + executor, + dirname=cur_dir, + main_program=main_program, + vars=None, + predicate=is_checkpoint_var, + filename=None) _write_success(cur_dir) _lru_delete(dirname, keep_max) @@ -505,7 +513,25 @@ def restore_checkpoint(dirname, executor, main_program=None): if serial < 0: return cur_dir = os.path.join(dirname, str(serial)) - load_persistables(executor, cur_dir, main_program) + # load_persistables(executor, cur_dir, main_program) + load_vars( + executor, + dirname=cur_dir, + main_program=main_program, + predicate=is_checkpoint_var, + filename=None) + + +def is_checkpoint_var(var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW: + return False + + if var.name.endswith("@GRAD"): + return False + + return var.persistable def _lru_delete(dirname, keep_max=3): -- GitLab