From 4313d870a2a4e99c3a039949224fff41750b1e52 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 27 Aug 2018 20:56:38 +0800 Subject: [PATCH] refine --- paddle/fluid/framework/CMakeLists.txt | 4 ++-- paddle/fluid/framework/lod_tensor.cc | 7 +++++-- paddle/fluid/framework/version.cc | 10 +++++++++- paddle/fluid/framework/version.h | 15 ++++++++++++--- paddle/fluid/framework/version_test.cc | 6 +++++- paddle/fluid/inference/io.cc | 6 ++++-- 6 files changed, 37 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 1c9130305..d998109df 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -56,9 +56,9 @@ else() cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) endif() if (NOT WIN32) -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version) else() -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version) endif (NOT WIN32) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index adeb26e4e..1e7da9a69 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/framework/version.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" @@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) { void SerializeToStream(std::ostream &os, const LoDTensor &tensor, const platform::DeviceContext &dev_ctx) { { // the 1st field, uint32_t version for LoDTensor - constexpr uint32_t version = 0; - os.write(reinterpret_cast(&version), sizeof(version)); + os.write(reinterpret_cast(&kCurTensorVersion), + sizeof(kCurTensorVersion)); } { // the 2st field, LoD information @@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, // the 1st field, unit32_t version for LoDTensor uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE(framework::IsTensorVersionSupported(version), + "tensor version %u is not supported.", version); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); } { diff --git a/paddle/fluid/framework/version.cc b/paddle/fluid/framework/version.cc index b0d5c26a3..3d559e26e 100644 --- a/paddle/fluid/framework/version.cc +++ b/paddle/fluid/framework/version.cc @@ -17,12 +17,20 @@ limitations under the License. */ namespace paddle { namespace framework { -bool IsProgramVersionSupported(int version) { +bool IsProgramVersionSupported(int64_t version) { static int num_supported = sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]); return std::find(kSupportedProgramVersion, kSupportedProgramVersion + num_supported, version) != kSupportedProgramVersion + num_supported; } + +bool IsTensorVersionSupported(uint32_t version) { + static int num_supported = + sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]); + return std::find(kSupportedTensorVersion, + kSupportedTensorVersion + num_supported, + version) != kSupportedTensorVersion + num_supported; +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/version.h b/paddle/fluid/framework/version.h index 2960ac978..bf07fc288 100644 --- a/paddle/fluid/framework/version.h +++ b/paddle/fluid/framework/version.h @@ -12,19 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include + #pragma once namespace paddle { namespace framework { // The program version the current codes generate. -constexpr int kCurProgramVersion = 0; +constexpr int64_t kCurProgramVersion = 0; // The program version that was generated by previous or current codes // and supported by current codes. -constexpr int kSupportedProgramVersion[] = {0}; +constexpr int64_t kSupportedProgramVersion[] = {0}; + +// Due to historical reasons, tensor version use uint32_t. +constexpr uint32_t kCurTensorVersion = 0; + +constexpr uint32_t kSupportedTensorVersion[] = {0}; + +bool IsProgramVersionSupported(int64_t version); -bool IsProgramVersionSupported(int version); +bool IsTensorVersionSupported(uint32_t version); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/version_test.cc b/paddle/fluid/framework/version_test.cc index cc57f713d..e8c5f2560 100644 --- a/paddle/fluid/framework/version_test.cc +++ b/paddle/fluid/framework/version_test.cc @@ -17,10 +17,14 @@ namespace paddle { namespace framework { -TEST(Variable, GetMutable) { +TEST(Version, Basic) { EXPECT_TRUE(IsProgramVersionSupported(0)); EXPECT_FALSE(IsProgramVersionSupported(1)); EXPECT_FALSE(IsProgramVersionSupported(-1)); + + EXPECT_TRUE(IsTensorVersionSupported(0)); + EXPECT_FALSE(IsTensorVersionSupported(1)); + EXPECT_FALSE(IsTensorVersionSupported(-1)); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 1d20643ce..e246a06fd 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -126,7 +126,8 @@ std::unique_ptr Load(framework::Executor* executor, std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()), - "model version %d is not supported.", main_program->Version()); + "model version %ld is not supported.", + main_program->Version()); LoadPersistables(executor, scope, *main_program, dirname, ""); return main_program; @@ -142,7 +143,8 @@ std::unique_ptr Load( std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()), - "model version %d is not supported.", main_program->Version()); + "model version %ld is not supported.", + main_program->Version()); LoadPersistables(executor, scope, *main_program, "", param_filename); return main_program; -- GitLab