From fd2eb55071199df6bb564ee0b30e35b3868c7371 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 24 Oct 2017 14:12:38 -0700 Subject: [PATCH] "Serialize LoDTensor, Save/Restore model" (#4602) * "add model format design doc" * "add restore function" * "add parse protobuf" * "move necessary information to saver.proto" * "format code" * "add gpu option" * "add lod info" * "add saveop python test wrapper" * "checkpoint reuse save operator" * "rewrite model format design doc" * "async support needed" * "fix run once" * "fix doc based on comments" * "refine based on comments" * "fix based comments" * "remove persistable flag from framework.proto" * "add IndicateDataType to restore op" * "add save test" * "modify save restore code" * "modified the restore logic" * rm checkpoint_op.cc * rm test_checkpoint_op.py * "get inputs outputs name from execution context" * Saving each variable to a independent file * Fix bugs * Rewrite save_restore_op_test with new Python framework * Move `SaveOp` and `RestoreOp` from OpWithKernel to OpBase * Refine unit test of SaveOp and RestoreOp * fix compile errorwq --- doc/design/model_format.md | 36 +++++ paddle/framework/CMakeLists.txt | 8 +- paddle/framework/lod_tensor.cc | 144 +++++++++++++++++ paddle/framework/lod_tensor.h | 22 +++ paddle/framework/lod_tensor_test.cc | 24 ++- paddle/framework/lod_tensor_test.cu | 27 ++++ paddle/framework/saver.proto | 39 +++++ paddle/framework/scope.cc | 17 ++ paddle/framework/scope.h | 4 + paddle/framework/scope_test.cc | 15 ++ paddle/framework/tensor.h | 11 +- paddle/operators/CMakeLists.txt | 7 + paddle/operators/save_restore_op.cc | 147 ++++++++++++++++++ python/paddle/v2/framework/framework.py | 3 +- .../framework/tests/test_save_restore_op.py | 71 +++++++++ 15 files changed, 569 insertions(+), 6 deletions(-) create mode 100644 doc/design/model_format.md create mode 100644 paddle/framework/saver.proto create mode 100644 paddle/operators/save_restore_op.cc create mode 100644 
python/paddle/v2/framework/tests/test_save_restore_op.py diff --git a/doc/design/model_format.md b/doc/design/model_format.md new file mode 100644 index 0000000000..db8c36e5f5 --- /dev/null +++ b/doc/design/model_format.md @@ -0,0 +1,36 @@ +# Design Doc: Model Format + +## Motivation + +The model is the output of the training process. One complete model consists of two parts, namely, the **topology** and the **parameters**. To support industrial deployment, the model format must be self-contained and must not expose any training source code. + +As a result, in PaddlePaddle, the **topology** is represented as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model; we must support large parameters and efficient serialization/deserialization. + +## Implementation + +The topology is saved as plain text, specifically as a self-contained protobuf file. + +The parameters are saved as a binary file. As we all know, a protobuf message has a size limit of [64M](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We ran a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), and its results show that protobuf is not a good fit for this scenario. + +As a result, we design a particular format for tensor serialization. By default, any tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and it has a description proto, [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99).
We save the DescProto as the byte-string header; it contains the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores its values in a contiguous memory buffer; for speed, we dump the raw memory to disk and save it as the byte-string content. So, the binary format of one tensor is: + +|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**| + +In detail, the tensor's byte view is shown in the table below. Note that all signed values are written in little-endian. + +```text +[offset] [type] [description] +0004 4 bytes integer HeaderLength, the length of LoDTensorDesc +0008 4 bytes integer ContentLength, the length of LoDTensor Buffer +0009 1 bytes char TensorDesc +00010 1 bytes char TensorDesc +... +00100 1 bytes char TensorValue +00101 1 bytes char TensorValue +00102 1 bytes char TensorValue .. +... +``` + +## Summary + +We introduce the model format: the `ProgramDesc` describes the **topology**, and a set of binary tensors in a dedicated format describes the **parameters**.
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index dbe76a8eaf..85374a476d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,4 +1,7 @@ # ddim lib +proto_library(framework_proto SRCS framework.proto) +proto_library(saver_proto SRCS framework.proto saver.proto) + cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) @@ -7,8 +10,8 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor) -cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto) +cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_test(variable_test SRCS variable_test.cc) @@ -16,7 +19,6 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc) cc_test(scope_test SRCS scope_test.cc DEPS scope) -proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc index 7c0ea0df78..f53dd1c185 100644 --- a/paddle/framework/lod_tensor.cc +++ b/paddle/framework/lod_tensor.cc @@ -13,6 +13,15 @@ limitations under the License. 
*/ #include "paddle/framework/lod_tensor.h" +#include "paddle/framework/saver.pb.h" + +#include "paddle/memory/memcpy.h" +#include "paddle/memory/memory.h" + +#include +#include +#include +#include #include @@ -112,5 +121,140 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin, lod_ = new_lod; } +std::string LoDTensor::SerializeToString() const { + LoDTensorProto desc; + + // set data_type + if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL); + if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16); + if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32); + if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64); + // FIXME(dzh): there is no fp16 in standard c++ + + if (this->type() == typeid(float)) // NOLINT + desc.set_data_type(DataType::FP32); + if (this->type() == typeid(double)) // NOLINT + desc.set_data_type(DataType::FP64); + + for (int i = 0; i < dims().size(); ++i) { + desc.add_dims(dims()[i]); + } + + // set lod information + desc.set_lod_level(this->NumLevels()); + for (size_t i = 0; i < this->NumLevels(); ++i) { + LoDInfo* lod = desc.add_levels(); + for (size_t j = 0; j < lod_[i].size(); ++j) { + lod->add_level(lod_[i][j]); + } + } + + desc.set_version(0); + + std::string desc_bytes = desc.SerializeAsString(); + + // FIXME(dzh) : implement fix chunk size buffer. + size_t DESC_SIZE = desc_bytes.size(); + size_t DATA_SIZE = holder_->size() - offset_; + + const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t); + char* buffer = + static_cast(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE)); + + // format: desc_size data_size, desc_bytes, data_bytes. 
+ platform::CPUPlace src_place; + platform::CPUPlace dst_place; + + memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t)); + memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE, + sizeof(size_t)); + memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place, + desc_bytes.c_str(), desc_bytes.size()); + + PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!"); + + platform::Place place = holder_->place(); + int element_width = holder_->size() / this->numel(); + + if (platform::is_cpu_place(place)) { + memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), + boost::get(place), + static_cast(holder_->ptr()) + offset_ / element_width, + DATA_SIZE); + } +#ifdef PADDLE_WITH_GPU + if (platform::is_gpu_place(place)) { + memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(), + boost::get(place), + static_cast(holder_->ptr()) + offset_ / element_width, + DATA_SIZE); + } +#endif + + std::string ret(buffer, BUFFER_SIZE); + memory::Free(platform::CPUPlace(), buffer); + return ret; +} + +void LoDTensor::DeserializeFromString(const std::string& s, + const platform::Place& dst_place) { + size_t DESC_SIZE, BUFFER_SIZE; + platform::CPUPlace src_place; + + memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t)); + memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t), + sizeof(size_t)); + + const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2; + + // parse LoDTensorDesc + LoDTensorProto desc; + desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE); + + std::vector dims; + std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims)); + this->Resize(make_ddim(dims)); + + // parse data type + void* ptr = nullptr; + if (desc.data_type() == DataType::BOOL) + ptr = this->mutable_data(dst_place); + if (desc.data_type() == DataType::INT16) + ptr = this->mutable_data(dst_place); + if (desc.data_type() == DataType::INT32) + ptr = 
this->mutable_data(dst_place); + if (desc.data_type() == DataType::INT64) + ptr = this->mutable_data(dst_place); + // FIXME(dzh): there is no fp16 in standard c++ + + if (desc.data_type() == DataType::FP32) + ptr = this->mutable_data(dst_place); + if (desc.data_type() == DataType::FP64) + ptr = this->mutable_data(dst_place); + + LoD lod; + std::vector levels; + for (int i = 0; i < desc.levels().size(); ++i) { + auto current_level = desc.levels()[i].level(); + std::copy(current_level.begin(), current_level.end(), + std::back_inserter(levels)); + lod.emplace_back(levels); + levels.clear(); + } + + this->set_lod(lod); + + if (platform::is_cpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), ptr, src_place, + s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE); + } +#ifdef PADDLE_WITH_GPU + if (platform::is_gpu_place(dst_place)) { + memory::Copy(boost::get(dst_place), ptr, src_place, + s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE); + } +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index dec59a5750..f78a751c53 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -25,6 +25,7 @@ #include "paddle/framework/ddim.h" #include "paddle/framework/tensor.h" #include "paddle/platform/enforce.h" +#include "paddle/platform/place.h" namespace paddle { namespace framework { @@ -132,6 +133,27 @@ class LoDTensor : public Tensor { */ void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end); + /** + * @brief Serialize tensor to char bytes. + * Please check model_format.md for the format detail. + * NOTE: GPUTensor will copy data to cpu implicitly. + * @return return string + */ + + // FIXME(dzh) : Currently, this interface should only be used in + // save/restore model and checkpoint. 
ParameterServer do not use shape + // information to do the optimization, as a result, when we serialize + // parameter/gradient to string, we should serialize the tensor + // to string in the ps trainer instead of LoDTensor. + std::string SerializeToString() const; + + /** + * @brief Deserialize char bytes to tensor. + * @return return string + */ + void DeserializeFromString(const std::string& s, + const platform::Place& dst_place); + private: LoD lod_; }; diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc index e1e15abecf..b984d62071 100644 --- a/paddle/framework/lod_tensor_test.cc +++ b/paddle/framework/lod_tensor_test.cc @@ -17,10 +17,13 @@ #include #include #include +#include namespace paddle { namespace framework { +const int kLodTensorSize = 20 * 128; + class LoDTensorTester : public ::testing::Test { public: virtual void SetUp() override { @@ -38,7 +41,10 @@ class LoDTensorTester : public ::testing::Test { lod_tensor_.Resize({20 /*batch size*/, 128 /*dim*/}); // malloc memory - lod_tensor_.mutable_data(place); + float* dst_ptr = lod_tensor_.mutable_data(place); + for (int i = 0; i < kLodTensorSize; ++i) { + dst_ptr[i] = i; + } lod_tensor_.set_lod(lod); } @@ -101,5 +107,21 @@ TEST_F(LoDTensorTester, ShrinkInLevel) { ASSERT_EQ(new_lod_tensor.data(), lod_tensor_.data()); } +TEST_F(LoDTensorTester, SerializeDeserialize) { + LoDTensor new_lod_tensor = lod_tensor_; + float* src_ptr = lod_tensor_.data(); + std::string s = lod_tensor_.SerializeToString(); + LoDTensor dst; + dst.DeserializeFromString(s, platform::CPUPlace()); + float* dst_ptr = dst.data(); + for (int i = 0; i < kLodTensorSize; ++i) { + EXPECT_EQ(dst_ptr[i], src_ptr[i]); + } + + ASSERT_EQ(dst.NumElements(0), 2UL); + ASSERT_EQ(dst.NumElements(1), 3UL); + ASSERT_EQ(dst.NumElements(2), 8UL); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index 25041024cb..11659be02a 100644 --- 
a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -48,3 +48,30 @@ TEST(LoDTensor, LoDInGPU) { CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); } } + +TEST(LoDTensor, SerializeDeserialize) { + paddle::framework::LoDTensor lod_tensor; + paddle::platform::GPUPlace place(0); + + paddle::framework::LoD src_lod; + src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14}); + + lod_tensor.Resize({14, 16}); + lod_tensor.mutable_data(place); + + lod_tensor.set_lod(src_lod); + CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL); + CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL); + + test<<<1, 8>>>(src_lod[0].data(), src_lod[0].size()); + cudaDeviceSynchronize(); + + std::string s = lod_tensor.SerializeToString(); + paddle::framework::LoDTensor dst; + dst.DeserializeFromString(s, place); + paddle::framework::LoD dst_lod = dst.lod(); + + for (size_t i = 0; i < dst_lod[0].size(); ++i) { + CHECK_EQ(src_lod[0].data()[i], dst_lod[0].data()[i] * 2); + } +} diff --git a/paddle/framework/saver.proto b/paddle/framework/saver.proto new file mode 100644 index 0000000000..90a191a6a7 --- /dev/null +++ b/paddle/framework/saver.proto @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto2"; +option optimize_for = LITE_RUNTIME; +package paddle.framework; + +import "framework.proto"; + +/** + * This file contains necessary information for model, checkpoint. + * etc. 
+ */ + +message LoDInfo { repeated int64 level = 1; } + +/** + * Save the LoDTensorDesc information through LoDTensorProto, its data memory + * is copyed to c buffer immediately. See model_format.md for details. + */ + +message LoDTensorProto { + optional DataType data_type = 1; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] + repeated LoDInfo levels = 3; + optional int32 lod_level = 4 [ default = 0 ]; + optional int32 version = 5; +} diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index ac3ac649f9..19e25fba05 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -65,6 +65,23 @@ void Scope::DropKids() { kids_.clear(); } +std::vector Scope::GetAllNames(bool recursive) const { + std::vector known_vars(vars_.size()); + + if (recursive) { + for (auto& kid : kids_) { + auto kid_vars = kid->GetAllNames(); + for (auto& p : kid_vars) { + known_vars.emplace_back(p); + } + } + } + for (auto& p : vars_) { + known_vars.emplace_back(p.first); + } + return known_vars; +} + void Scope::DeleteScope(Scope* scope) { auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 7206b53068..ac334da5ef 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/framework/variable.h" #include "paddle/platform/macros.h" @@ -64,6 +65,9 @@ class Scope { /// Drop all kids scopes belonged to this scope. void DropKids(); + // enumerate all the variables current contains. + std::vector GetAllNames(bool recursive = false) const; + private: // Call Scope::NewScope for a sub-scope. 
explicit Scope(Scope const* parent) : parent_(parent) {} diff --git a/paddle/framework/scope_test.cc b/paddle/framework/scope_test.cc index 7cc5e3510d..f738d5ba9e 100644 --- a/paddle/framework/scope_test.cc +++ b/paddle/framework/scope_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/scope.h" +#include "glog/logging.h" #include "gtest/gtest.h" using paddle::framework::Scope; @@ -54,3 +55,17 @@ TEST(Scope, FindScope) { EXPECT_EQ(&s, s.FindScope(v)); EXPECT_EQ(&s, ss.FindScope(v)); } + +TEST(Scope, GetAllNames) { + Scope s; + Variable* v = s.Var("a"); + EXPECT_EQ(&s, s.FindScope(v)); + + std::vector ans = s.GetAllNames(); + std::string str; + for (auto& var : ans) { + str += var; + } + + EXPECT_STREQ("a", str.c_str()); +} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 3a2bdaf086..e31472327d 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -31,6 +31,8 @@ namespace paddle { namespace framework { +class LoDTensor; + class Tensor { public: template @@ -134,6 +136,8 @@ class Tensor { inline void check_memory_size() const; private: + friend class LoDTensor; + /** * @note Placeholder hides type T, so it doesn't appear as a template * parameter of Variable. @@ -181,7 +185,12 @@ class Tensor { /*! holds the memory block if allocated. */ std::shared_ptr holder_; - /*! points to dimensions of memory block. */ + /** + * @brief points to elements dimensions. + * + * @note dims_ do not indicate the memory block size. 
+ */ + DDim dims_; /** diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f97bc837dc..d2d70d8be7 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -69,6 +69,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n") endif() + # save_restore_op contains several operators + if ("${TARGET}" STREQUAL "save_restore_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(save);\n") + endif() + # activation_op contains several operators if ("${TARGET}" STREQUAL "activation_op") set(pybind_flag 1) diff --git a/paddle/operators/save_restore_op.cc b/paddle/operators/save_restore_op.cc new file mode 100644 index 0000000000..314e4e9279 --- /dev/null +++ b/paddle/operators/save_restore_op.cc @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; +using framework::LoDTensor; + +inline static std::string VarToFileName(const std::string& folder_path, + const std::string& var_name) { + return folder_path + "/__" + var_name + "__"; +} + +class SaveOp : public framework::OperatorBase { + public: + SaveOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + const auto& var_names = this->Inputs("X"); + for (const auto& name : var_names) { + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), + "Can not find variable '%s' in the scope.", name); + } + std::string folder_path = this->Attr("folderPath"); + PADDLE_ENFORCE(!folder_path.empty(), + "'folderPath' of SaveOp shouldn't be empty."); + + VLOG(1) << "Save variables to folder: " << folder_path; + for (const auto& name : var_names) { + std::string file_name = VarToFileName(folder_path, name); + std::ofstream fout(file_name, std::ofstream::out); + PADDLE_ENFORCE(fout.is_open(), "Fail to create file %s.", file_name); + const LoDTensor& tensor = scope.FindVar(name)->Get(); + std::string bytes = tensor.SerializeToString(); + fout << bytes; + fout.close(); + } + VLOG(1) << "Compelete saving variables. 
Items count: " << var_names.size(); + } +}; + +class SaveOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SaveOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(tensor), the tensor count can be 1~INT_MAX, tensors names which " + "values will be saved.") + .AsDuplicable(); + AddAttr("folderPath", "the folderPath for save model."); + AddComment(R"DOC( +Save the input tensors to a binary file based on input tensor names and absolute path. + +All the inputs can carry the LoD (Level of Details) information, +or not. +)DOC"); + } +}; + +class RestoreOp : public framework::OperatorBase { + public: + RestoreOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope& scope, + const platform::DeviceContext& dev_ctx) const override { + const auto& var_names = this->Outputs("Out"); + for (const auto& name : var_names) { + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), + "Can not find variable '%s' in the scope.", name); + } + std::string folder_path = this->Attr("folderPath"); + PADDLE_ENFORCE(!folder_path.empty(), + "'folderPath' of RestoreOp shouldn't be empty."); + + VLOG(1) << "Try loading variables from folder: " << folder_path; + + for (const auto& name : var_names) { + std::string file_name = VarToFileName(folder_path, name); + std::ifstream fin(file_name, std::ifstream::in); + PADDLE_ENFORCE(fin.is_open(), "Fail to open file %s.", file_name); + const size_t kBufferSize = 4096; // equal to linux page size + char buffer[kBufferSize]; + std::string cache; + while (!fin.eof()) { + fin.read(buffer, kBufferSize); + cache.append(buffer, fin.gcount()); + } + LoDTensor* tensor = scope.FindVar(name)->GetMutable(); + tensor->DeserializeFromString(cache, dev_ctx.GetPlace()); + fin.close(); + } + 
VLOG(1) << "Complete loading variables."; + } +}; + +class RestoreOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RestoreOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Out", + "(tensor), the tensor count can be 1~INT_MAX, tensors which " + "values will be restores.") + .AsDuplicable(); + AddAttr("folderPath", "the folderPath for model file."); + AddAttr("data_type", "output tensor data type") + .SetDefault(framework::DataType::FP32); + AddComment(R"DOC( +Restore the tensors from model file based on absolute path. + +All the tensors outputs may carry the LoD (Level of Details) information, +or not. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR(save, paddle::operators::SaveOp, + paddle::framework::EmptyGradOpMaker, + paddle::operators::SaveOpMaker); + +REGISTER_OPERATOR(restore, paddle::operators::RestoreOp, + paddle::framework::EmptyGradOpMaker, + paddle::operators::RestoreOpMaker); diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 40b9008d67..b3f8be8be9 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -261,7 +261,8 @@ class Operator(object): self.desc.set_attr(attr_name, attrs[attr_name]) self.desc.check_attrs() - if type not in {'feed', 'fetch'}: + no_kernel_op_set = {'feed', 'fetch', 'save', 'restore'} + if type not in no_kernel_op_set: self.desc.infer_var_type(self.block.desc) self.desc.infer_shape(self.block.desc) diff --git a/python/paddle/v2/framework/tests/test_save_restore_op.py b/python/paddle/v2/framework/tests/test_save_restore_op.py new file mode 100644 index 0000000000..3a36d03f62 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_save_restore_op.py @@ -0,0 +1,71 @@ +import paddle.v2.framework.core as core +import paddle.v2.framework.framework as framework +import paddle.v2.framework.executor as 
executor + +import numpy as np +import unittest +import os +import sys +import shutil + +FOLDER_PATH = "./tmp_test_dir" + + +class TestSaveRestoreOp(unittest.TestCase): + def test_save_restore_op(self): + tensor_1_val = np.random.rand(3, 9).astype("float32") + tensor_2_val = np.random.randint(0, 20, size=(4, 2)).astype("int32") + place = core.CPUPlace() + + program = framework.Program() + block = program.global_block() + v_a = block.create_var( + dtype="float32", shape=[3, 9], lod_level=0, name="tensor_1") + v_b = block.create_var( + dtype="int32", shape=[4, 2], lod_level=0, name="tensor_2") + + t_1 = core.LoDTensor() + t_1.set(tensor_1_val, place) + t_2 = core.LoDTensor() + t_2.set(tensor_2_val, place) + block.append_op( + type="save", + inputs={"X": [v_a, v_b]}, + attrs={"folderPath": FOLDER_PATH}) + block.append_op( + type="fill_constant", + outputs={"Out": [v_a]}, + attrs={"shape": [2, 2], + "value": 0.0}) + block.append_op( + type="fill_constant", + outputs={"Out": [v_b]}, + attrs={"shape": [2, 2], + "value": 0.0}) + block.append_op( + type="restore", + outputs={"Out": [v_a, v_b]}, + attrs={"folderPath": FOLDER_PATH}) + + if os.path.exists(FOLDER_PATH): + shutil.rmtree(FOLDER_PATH) + os.makedirs(FOLDER_PATH) + + exe = executor.Executor(place) + out = exe.run(program, + feed={"tensor_1": t_1, + "tensor_2": t_2}, + fetch_list=[v_a, v_b]) + + self.assertTrue(os.path.isdir(FOLDER_PATH)) + self.assertTrue(os.path.isfile(FOLDER_PATH + "/__tensor_1__")) + self.assertTrue(os.path.isfile(FOLDER_PATH + "/__tensor_2__")) + + self.assertTrue(np.array_equal(np.array(out[0]), tensor_1_val)) + self.assertTrue(np.array_equal(np.array(out[1]), tensor_2_val)) + + shutil.rmtree(FOLDER_PATH) + + +if __name__ == "__main__": + unittest.main() -- GitLab