From 568a329c83312df89defe22f24dc9ef497ac0aca Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 9 May 2018 20:59:46 +0800 Subject: [PATCH] add checkpoint util class and implement --- paddle/fluid/operators/detail/checkpoint.cc | 54 +++++++++++++++++++++ paddle/fluid/operators/detail/checkpoint.h | 33 +++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 paddle/fluid/operators/detail/checkpoint.cc create mode 100644 paddle/fluid/operators/detail/checkpoint.h diff --git a/paddle/fluid/operators/detail/checkpoint.cc b/paddle/fluid/operators/detail/checkpoint.cc new file mode 100644 index 00000000000..78506a0a72e --- /dev/null +++ b/paddle/fluid/operators/detail/checkpoint.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/operators/detail/checkpoint.h" + +#include + +namespace paddle { +namespace framework { +namespace details { +Checkpoint::Save(const framework::Scope& scope, const platform::Place& place, + const std::string& save_dir, const std::string& var_name, + const bool overwrite) { + auto* var = scope.FindVar(var_name); + PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op", + var_name); + PADDLE_ENFORCE(var->IsType(), + "Checkpoint only supports LoDTensor, %s has wrong type", + var_name); + + bool is_present = FileExists(save_dir); + if (is_present && !overwrite) { + PADDLE_THROW("%s exists!, checkpoint cannot write it when overwrite=false", + save_dir, overwrite); + } + + MkDirRecursively(DirName(save_dir).c_str()); + std::ofstream fout(save_dir); + PADDLE_ENFORCE(static_cast(fout), "Cannot open %s to write", save_dir); + + // get device context from pool + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(place); + + auto& tensor = var->Get(); + // Serialize tensor + framework::SerializeToStream(fout, tensor, dev_ctx); + fout.close(); +} +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/detail/checkpoint.h b/paddle/fluid/operators/detail/checkpoint.h new file mode 100644 index 00000000000..0f0f450ce17 --- /dev/null +++ b/paddle/fluid/operators/detail/checkpoint.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +namespace details { +class Checkpoint { + public: + static void Save(const framework::Scope& scope, const platform::Place& place, + const std::string& save_dir, const std::string& var_name, + const bool overwrite); + + static void Load(); +} +} // namespace details +} // namespace framework +} // namespace paddle -- GitLab