提交 4313d870 编写于 作者: X Xin Pan

refine

上级 c69cf6dd
......@@ -56,9 +56,9 @@ else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif()
if (NOT WIN32)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
else()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
endif (NOT WIN32)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......@@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
{ // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
os.write(reinterpret_cast<const char *>(&kCurTensorVersion),
sizeof(kCurTensorVersion));
}
{
// the 2st field, LoD information
......@@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
// the 1st field, unit32_t version for LoDTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE(framework::IsTensorVersionSupported(version),
"tensor version %u is not supported.", version);
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
}
{
......
......@@ -17,12 +17,20 @@ limitations under the License. */
namespace paddle {
namespace framework {
bool IsProgramVersionSupported(int version) {
bool IsProgramVersionSupported(int64_t version) {
static int num_supported =
sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]);
return std::find(kSupportedProgramVersion,
kSupportedProgramVersion + num_supported,
version) != kSupportedProgramVersion + num_supported;
}
bool IsTensorVersionSupported(uint32_t version) {
static int num_supported =
sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]);
return std::find(kSupportedTensorVersion,
kSupportedTensorVersion + num_supported,
version) != kSupportedTensorVersion + num_supported;
}
} // namespace framework
} // namespace paddle
......@@ -12,19 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdint>
#pragma once
namespace paddle {
namespace framework {
// The program version the current codes generate.
constexpr int kCurProgramVersion = 0;
constexpr int64_t kCurProgramVersion = 0;
// The program version that was generated by previous or current codes
// and supported by current codes.
constexpr int kSupportedProgramVersion[] = {0};
constexpr int64_t kSupportedProgramVersion[] = {0};
// Due to historical reasons, tensor version use uint32_t.
constexpr uint32_t kCurTensorVersion = 0;
constexpr uint32_t kSupportedTensorVersion[] = {0};
bool IsProgramVersionSupported(int64_t version);
bool IsProgramVersionSupported(int version);
bool IsTensorVersionSupported(uint32_t version);
} // namespace framework
} // namespace paddle
......@@ -17,10 +17,14 @@
namespace paddle {
namespace framework {
TEST(Variable, GetMutable) {
TEST(Version, Basic) {
EXPECT_TRUE(IsProgramVersionSupported(0));
EXPECT_FALSE(IsProgramVersionSupported(1));
EXPECT_FALSE(IsProgramVersionSupported(-1));
EXPECT_TRUE(IsTensorVersionSupported(0));
EXPECT_FALSE(IsTensorVersionSupported(1));
EXPECT_FALSE(IsTensorVersionSupported(-1));
}
} // namespace framework
} // namespace paddle
......@@ -126,7 +126,8 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %d is not supported.", main_program->Version());
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, dirname, "");
return main_program;
......@@ -142,7 +143,8 @@ std::unique_ptr<framework::ProgramDesc> Load(
std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %d is not supported.", main_program->Version());
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_filename);
return main_program;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册