提交 4313d870 编写于 作者: X Xin Pan

refine

上级 c69cf6dd
...@@ -56,9 +56,9 @@ else() ...@@ -56,9 +56,9 @@ else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif() endif()
if (NOT WIN32) if (NOT WIN32)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
else() else()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
endif (NOT WIN32) endif (NOT WIN32)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
...@@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) { ...@@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor, void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) { const platform::DeviceContext &dev_ctx) {
{ // the 1st field, uint32_t version for LoDTensor { // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0; os.write(reinterpret_cast<const char *>(&kCurTensorVersion),
os.write(reinterpret_cast<const char *>(&version), sizeof(version)); sizeof(kCurTensorVersion));
} }
{ {
// the 2st field, LoD information // the 2st field, LoD information
...@@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor, ...@@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
// the 1st field, unit32_t version for LoDTensor // the 1st field, unit32_t version for LoDTensor
uint32_t version; uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version)); is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE(framework::IsTensorVersionSupported(version),
"tensor version %u is not supported.", version);
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported"); PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
} }
{ {
......
...@@ -17,12 +17,20 @@ limitations under the License. */ ...@@ -17,12 +17,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
bool IsProgramVersionSupported(int version) { bool IsProgramVersionSupported(int64_t version) {
static int num_supported = static int num_supported =
sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]); sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]);
return std::find(kSupportedProgramVersion, return std::find(kSupportedProgramVersion,
kSupportedProgramVersion + num_supported, kSupportedProgramVersion + num_supported,
version) != kSupportedProgramVersion + num_supported; version) != kSupportedProgramVersion + num_supported;
} }
bool IsTensorVersionSupported(uint32_t version) {
static int num_supported =
sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]);
return std::find(kSupportedTensorVersion,
kSupportedTensorVersion + num_supported,
version) != kSupportedTensorVersion + num_supported;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,19 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,19 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cstdint>
#pragma once #pragma once
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// The program version the current codes generate. // The program version the current codes generate.
constexpr int kCurProgramVersion = 0; constexpr int64_t kCurProgramVersion = 0;
// The program version that was generated by previous or current codes // The program version that was generated by previous or current codes
// and supported by current codes. // and supported by current codes.
constexpr int kSupportedProgramVersion[] = {0}; constexpr int64_t kSupportedProgramVersion[] = {0};
// Due to historical reasons, tensor version use uint32_t.
constexpr uint32_t kCurTensorVersion = 0;
constexpr uint32_t kSupportedTensorVersion[] = {0};
bool IsProgramVersionSupported(int64_t version);
bool IsProgramVersionSupported(int version); bool IsTensorVersionSupported(uint32_t version);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,10 +17,14 @@ ...@@ -17,10 +17,14 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(Variable, GetMutable) { TEST(Version, Basic) {
EXPECT_TRUE(IsProgramVersionSupported(0)); EXPECT_TRUE(IsProgramVersionSupported(0));
EXPECT_FALSE(IsProgramVersionSupported(1)); EXPECT_FALSE(IsProgramVersionSupported(1));
EXPECT_FALSE(IsProgramVersionSupported(-1)); EXPECT_FALSE(IsProgramVersionSupported(-1));
EXPECT_TRUE(IsTensorVersionSupported(0));
EXPECT_FALSE(IsTensorVersionSupported(1));
EXPECT_FALSE(IsTensorVersionSupported(-1));
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -126,7 +126,8 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -126,7 +126,8 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
std::unique_ptr<framework::ProgramDesc> main_program( std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str)); new framework::ProgramDesc(program_desc_str));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()), PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %d is not supported.", main_program->Version()); "model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, dirname, ""); LoadPersistables(executor, scope, *main_program, dirname, "");
return main_program; return main_program;
...@@ -142,7 +143,8 @@ std::unique_ptr<framework::ProgramDesc> Load( ...@@ -142,7 +143,8 @@ std::unique_ptr<framework::ProgramDesc> Load(
std::unique_ptr<framework::ProgramDesc> main_program( std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str)); new framework::ProgramDesc(program_desc_str));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()), PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %d is not supported.", main_program->Version()); "model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_filename); LoadPersistables(executor, scope, *main_program, "", param_filename);
return main_program; return main_program;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册