未验证 提交 17bf8713 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #12988 from panyx0718/ir2

program and tensor versioning support
......@@ -56,9 +56,9 @@ else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif()
if (NOT WIN32)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio version)
else()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
endif (NOT WIN32)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
......@@ -116,7 +116,11 @@ cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope gl
endif(NOT WIN32)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......
......@@ -16,6 +16,13 @@ syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
// raise the version defined version.h.
//
// Serailization and Deserialization codes should be modified in a way
// that supports old versions following the version and compatibility policy.
message Version { optional int64 version = 1 [ default = 0 ]; }
enum AttrType {
INT = 0;
FLOAT = 1;
......@@ -180,4 +187,8 @@ message BlockDesc {
// for more details.
// TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name?
message ProgramDesc { repeated BlockDesc blocks = 1; }
message ProgramDesc {
repeated BlockDesc blocks = 1;
optional Version version = 2;
}
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
......@@ -251,8 +252,8 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
{ // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
os.write(reinterpret_cast<const char *>(&kCurTensorVersion),
sizeof(kCurTensorVersion));
}
{
// the 2st field, LoD information
......@@ -281,6 +282,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
// the 1st field, unit32_t version for LoDTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE(framework::IsTensorVersionSupported(version),
"tensor version %u is not supported.", version);
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
}
{
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/version.h"
namespace paddle {
namespace framework {
......@@ -38,7 +39,10 @@ proto::ProgramDesc *ProgramDesc::Proto() {
return &desc_;
}
int64_t ProgramDesc::Version() const { return desc_.version().version(); }
ProgramDesc::ProgramDesc() {
desc_.mutable_version()->set_version(kCurProgramVersion);
auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex);
......
......@@ -57,6 +57,8 @@ class ProgramDesc {
proto::ProgramDesc *Proto();
int64_t Version() const;
// The output variable of feed_op is referenced as feed_target.
// This function is used to collect the output variable's name of all
// feed_ops.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/version.h"
#include <algorithm>
namespace paddle {
namespace framework {
bool IsProgramVersionSupported(int64_t version) {
static int num_supported =
sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]);
return std::find(kSupportedProgramVersion,
kSupportedProgramVersion + num_supported,
version) != kSupportedProgramVersion + num_supported;
}
bool IsTensorVersionSupported(uint32_t version) {
static int num_supported =
sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]);
return std::find(kSupportedTensorVersion,
kSupportedTensorVersion + num_supported,
version) != kSupportedTensorVersion + num_supported;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdint>
#pragma once
namespace paddle {
namespace framework {
// Note:
// Program and Tensor that pass the IsXXXVersionSupported should
// be supported by the current codes. Otherwise, it's a compatibility
// bug.
// The program version the current codes generate.
constexpr int64_t kCurProgramVersion = 0;
// The program version that was generated by previous or current codes
// and supported by current codes.
constexpr int64_t kSupportedProgramVersion[] = {0};
// Due to historical reasons, tensor version use uint32_t.
// The tensor version the current codes generate.
constexpr uint32_t kCurTensorVersion = 0;
// The tensor version that was generated by previous or current codes
// and supported by current codes.
constexpr uint32_t kSupportedTensorVersion[] = {0};
bool IsProgramVersionSupported(int64_t version);
bool IsTensorVersionSupported(uint32_t version);
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/version.h"
#include "gtest/gtest.h"
namespace paddle {
namespace framework {
TEST(Version, Basic) {
EXPECT_TRUE(IsProgramVersionSupported(0));
EXPECT_FALSE(IsProgramVersionSupported(1));
EXPECT_FALSE(IsProgramVersionSupported(-1));
EXPECT_TRUE(IsTensorVersionSupported(0));
EXPECT_FALSE(IsTensorVersionSupported(1));
EXPECT_FALSE(IsTensorVersionSupported(-1));
}
} // namespace framework
} // namespace paddle
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/pybind/pybind.h"
......@@ -124,6 +125,9 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, dirname, "");
return main_program;
......@@ -138,6 +142,9 @@ std::unique_ptr<framework::ProgramDesc> Load(
std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str));
PADDLE_ENFORCE(framework::IsProgramVersionSupported(main_program->Version()),
"model version %ld is not supported.",
main_program->Version());
LoadPersistables(executor, scope, *main_program, "", param_filename);
return main_program;
......
......@@ -137,7 +137,10 @@ void BindProgramDesc(pybind11::module *m) {
PADDLE_ENFORCE(desc->ParseFromString(data),
"Fail to parse ProgramDesc from string. This could "
"be a bug of Paddle.");
});
})
.def("_version", [](pd::ProgramDesc &self) -> int64_t {
return self.Proto()->version().version();
});
}
void BindBlockDesc(pybind11::module *m) {
......
......@@ -33,6 +33,7 @@ limitations under the License. */
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -530,6 +531,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
m.def("_is_program_version_supported", IsProgramVersionSupported);
BindProgramDesc(&m);
BindBlockDesc(&m);
BindVarDsec(&m);
......
......@@ -1564,6 +1564,9 @@ class Program(object):
"""
return self.desc
def _version(self):
return self.desc._version()
def clone(self, for_test=False):
"""
Create a new, duplicated program.
......
......@@ -750,6 +750,10 @@ def load_inference_model(dirname,
program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str)
if not core._is_program_version_supported(program._version()):
raise ValueError("Unsupported program version: %d\n" %
program._version())
# Binary data also need versioning.
load_persistables(executor, dirname, program, params_filename)
if pserver_endpoints:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册