diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 1c9130305cc188bf1794ba5d76e1ed41650dbe8e..d998109df21f585bc4905e00e59fe07247fd3f5e 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -56,9 +56,9 @@ else() cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) endif() if (NOT WIN32) -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version) else() -cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) +cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version) endif (NOT WIN32) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index adeb26e4e78693eb9760ec1e12e4b71ba3115d5b..1e7da9a69c7cbf8c13306656599a759515802b76 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/framework/version.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" @@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) { void SerializeToStream(std::ostream &os, const LoDTensor &tensor, const platform::DeviceContext &dev_ctx) { { // the 1st field, uint32_t version for LoDTensor - constexpr uint32_t version = 0; - os.write(reinterpret_cast(&version), sizeof(version)); + os.write(reinterpret_cast(&kCurTensorVersion), + sizeof(kCurTensorVersion)); } { // the 2st field, LoD information @@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, // the 1st field, unit32_t version for LoDTensor uint32_t version; is.read(reinterpret_cast(&version), sizeof(version)); + PADDLE_ENFORCE(framework::IsTensorVersionSupported(version), + "tensor version %u is not supported.", version); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); } { diff --git a/paddle/fluid/framework/version.cc b/paddle/fluid/framework/version.cc index b0d5c26a31426222124776f0131df25a00694700..3d559e26e0b629848a002a9a7ac3c9adf5047d12 100644 --- a/paddle/fluid/framework/version.cc +++ b/paddle/fluid/framework/version.cc @@ -17,12 +17,20 @@ limitations under the License. */ namespace paddle { namespace framework { -bool IsProgramVersionSupported(int version) { +bool IsProgramVersionSupported(int64_t version) { static int num_supported = sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]); return std::find(kSupportedProgramVersion, kSupportedProgramVersion + num_supported, version) != kSupportedProgramVersion + num_supported; } + +bool IsTensorVersionSupported(uint32_t version) { + static int num_supported = + sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]); + return std::find(kSupportedTensorVersion, + kSupportedTensorVersion + num_supported, + version) != kSupportedTensorVersion + num_supported; +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/version.h b/paddle/fluid/framework/version.h index 2960ac9782c074f2568023829de6ced6b301f43e..bf07fc288d827ee560c54ef3edc7cd39c9594d95 100644 --- a/paddle/fluid/framework/version.h +++ b/paddle/fluid/framework/version.h @@ -12,19 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include + #pragma once namespace paddle { namespace framework { // The program version the current codes generate. -constexpr int kCurProgramVersion = 0; +constexpr int64_t kCurProgramVersion = 0; // The program version that was generated by previous or current codes // and supported by current codes. -constexpr int kSupportedProgramVersion[] = {0}; +constexpr int64_t kSupportedProgramVersion[] = {0}; + +// Due to historical reasons, tensor version use uint32_t. +constexpr uint32_t kCurTensorVersion = 0; + +constexpr uint32_t kSupportedTensorVersion[] = {0}; + +bool IsProgramVersionSupported(int64_t version); -bool IsProgramVersionSupported(int version); +bool IsTensorVersionSupported(uint32_t version); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/version_test.cc b/paddle/fluid/framework/version_test.cc index cc57f713d8bcde239d1043f6fac79e2dd2df8951..e8c5f256000522af976bbf487741a586f1abc439 100644 --- a/paddle/fluid/framework/version_test.cc +++ b/paddle/fluid/framework/version_test.cc @@ -17,10 +17,14 @@ namespace paddle { namespace framework { -TEST(Variable, GetMutable) { +TEST(Version, Basic) { EXPECT_TRUE(IsProgramVersionSupported(0)); EXPECT_FALSE(IsProgramVersionSupported(1)); EXPECT_FALSE(IsProgramVersionSupported(-1)); + + EXPECT_TRUE(IsTensorVersionSupported(0)); + EXPECT_FALSE(IsTensorVersionSupported(1)); + EXPECT_FALSE(IsTensorVersionSupported(-1)); } } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 1d20643ce01389b65085f6a215ec648b293c1001..e246a06fd079d837ac321197914c9f70b528f2c8 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -126,7 +126,8 @@ std::unique_ptr Load(framework::Executor* executor, std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()), - "model version %d is not supported.", main_program->Version()); + "model version %ld is not supported.", + main_program->Version()); LoadPersistables(executor, scope, *main_program, dirname, ""); return main_program; @@ -142,7 +143,8 @@ std::unique_ptr Load( std::unique_ptr main_program( new framework::ProgramDesc(program_desc_str)); PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()), - "model version %d is not supported.", main_program->Version()); + "model version %ld is not supported.", + main_program->Version()); LoadPersistables(executor, scope, *main_program, "", param_filename); return main_program;