From f04b23adf96651185bd0b47d90f8b5f1fee77706 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 15 May 2018 16:13:41 +0800 Subject: [PATCH] add checkpoint_load, update checkpoint save --- paddle/fluid/operators/CMakeLists.txt | 3 +- paddle/fluid/operators/checkpoint_load_op.cc | 87 +++++++++++++++++++ ..._save_op_test.cc => checkpoint_op_test.cc} | 0 paddle/fluid/operators/checkpoint_save_op.cc | 21 +++-- 4 files changed, 103 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/operators/checkpoint_load_op.cc rename paddle/fluid/operators/{checkpoint_save_op_test.cc => checkpoint_op_test.cc} (100%) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 34ec82c294b..df0292d902f 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -243,6 +243,7 @@ op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) op_library(checkpoint_save_op DEPS lod_tensor) +op_library(checkpoint_load_op DEPS lod_tensor) op_library(concat_op DEPS concat) # FIXME(thuan): Move CSP operators to paddle/fluid/framework/operators/concurrency @@ -278,6 +279,6 @@ cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_sea cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) -cc_test(checkpoint_save_op_test SRCS checkpoint_save_op_test.cc DEPS checkpoint_save_op) +cc_test(checkpoint_op_test SRCS checkpoint_op_test.cc DEPS checkpoint_save_op checkpoint_load_op) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) diff --git a/paddle/fluid/operators/checkpoint_load_op.cc b/paddle/fluid/operators/checkpoint_load_op.cc new file mode 100644 index 00000000000..b2ca59f2b5b --- /dev/null +++ b/paddle/fluid/operators/checkpoint_load_op.cc @@ -0,0 +1,87 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/data_type_transform.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { + +constexpr char kSEP = '/'; +// write empty file named _SUCCESS +const char SUCCESS[] = "_SUCCESS"; + +static bool FileExists(const std::string &filepath) { + struct stat buffer; + return (stat(filepath.c_str(), &buffer) == 0); +} + +static std::string DirName(const std::string &filepath) { + auto pos = filepath.rfind(kSEP); + if (pos == std::string::npos) { + return ""; + } + return filepath.substr(0, pos); +} + +class CheckpointLoadOp : public framework::OperatorBase { + public: + CheckpointLoadOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dir = Attr("dir"); + bool is_present = FileExists(dir); + if (!is_present) { + return; + } + + // UPDATE LATER ... + } +}; + +class CheckpointLoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + CheckpointLoadOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr( + "dir", + "(string)" + "The \"file_path\" where the LoDTensor variables will be saved.") + .AddCustomChecker( + [](const std::string &path) { return !path.empty(); }); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(checkpoint_load, ops::CheckpointLoadOp, + ops::CheckpointLoadOpProtoMaker); diff --git a/paddle/fluid/operators/checkpoint_save_op_test.cc b/paddle/fluid/operators/checkpoint_op_test.cc similarity index 100% rename from paddle/fluid/operators/checkpoint_save_op_test.cc rename to paddle/fluid/operators/checkpoint_op_test.cc diff --git a/paddle/fluid/operators/checkpoint_save_op.cc b/paddle/fluid/operators/checkpoint_save_op.cc index 94a1cc05c76..7007ab9e1a1 100644 --- a/paddle/fluid/operators/checkpoint_save_op.cc +++ b/paddle/fluid/operators/checkpoint_save_op.cc @@ -27,8 +27,6 @@ limitations under the License. */ namespace paddle { namespace operators { -// TODO(sidgoyal78): These function are needed by other files (save_op), move -// them to paddle::filesystem namespace. (as noted by yuyang18 in save_op). constexpr char kSEP = '/'; // write empty file named _SUCCESS const char SUCCESS[] = "_SUCCESS"; @@ -82,7 +80,14 @@ class CheckpointSaveOp : public framework::OperatorBase { // overwrite=false", // dir, overwrite); } + MkDirRecursively(dir.c_str()); + auto serial_var_name = Output("Serial"); + auto *serial_var = scope.FindVar(serial_var_name); + std::string *serial_num = serial_var->GetMutable(); + serial_num->append("0"); + dir.append("/"); + dir.append(serial_num); MkDirRecursively(dir.c_str()); auto inp_var_names = Inputs("X"); @@ -93,6 +98,7 @@ class CheckpointSaveOp : public framework::OperatorBase { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); + // todo (tangwei) made it async for (size_t i = 0; i < inp_var_names.size(); i++) { auto *var = scope.FindVar(inp_var_names[i]); std::string var_file; @@ -132,19 +138,20 @@ class CheckpointSaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { "X", "(vector) Input LoDTensors that need to be saved together in a file.") .AsDuplicable(); + AddOutput("Serial", "the serial number"); AddComment(R"DOC( -SaveCombine operator +CheckpointSave operator This operator will serialize and write a list of input LoDTensor variables to a file on disk. )DOC"); AddAttr("overwrite", - "(boolean, default true)" - "Overwrite the output file if it exists.") - .SetDefault(true); + "(boolean, default false)" + "Delete the output dir if it exists.") + .SetDefault(false); AddAttr( - "file_path", + "dir", "(string)" "The \"file_path\" where the LoDTensor variables will be saved.") .AddCustomChecker( -- GitLab