未验证 提交 0d275916 编写于 作者: 石晓伟 提交者: GitHub

save operator version infomation to program desc, test=develop (#27668)

上级 00d401ec
......@@ -123,7 +123,9 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
cc_library(op_version_registry SRCS op_version_registry.cc DEPS framework_proto boost)
cc_library(op_version_proto SRCS op_version_proto.cc DEPS framework_proto boost)
cc_library(op_version_registry SRCS op_version_registry.cc DEPS op_version_proto framework_proto boost)
cc_test(op_version_registry_test SRCS op_version_registry_test.cc DEPS op_version_registry)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
......
......@@ -179,29 +179,15 @@ message BlockDesc {
optional int32 forward_block_idx = 5 [ default = -1 ];
}
// CompatibleInfo is used to determine if a feature is compatible and
// provides the information.
message CompatibleInfo {
enum Type {
COMPATIBLE = 0;
DEFINITELY_NOT = 1;
POSSIBLE = 2;
BUG_FIX = 3;
PRECISION_CHANGE = 4;
}
required string version = 1;
required Type type = 2;
}
// In some cases, Paddle Fluid may perform operator definition iterations,
// and the operator uses OpCompatibleMap for compatibility testing.
message OpCompatibleMap {
message OpCompatiblePair {
// In some cases, Paddle may perform operator definition iterations,
// and the operator uses OpVersionMap for compatibility testing.
message OpVersion { required int32 version = 1; }
message OpVersionMap {
message OpVersionPair {
required string op_name = 1;
required CompatibleInfo compatible_info = 2;
required OpVersion op_version = 2;
}
repeated OpCompatiblePair pair = 1;
optional string default_required_version = 2;
repeated OpVersionPair pair = 1;
}
// Please refer to
......@@ -210,8 +196,8 @@ message OpCompatibleMap {
// TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name?
message ProgramDesc {
reserved 2; // For backward compatibility.
reserved 2, 3; // For backward compatibility.
repeated BlockDesc blocks = 1;
optional Version version = 4;
optional OpCompatibleMap op_compatible_map = 3;
optional OpVersionMap op_version_map = 5;
}
......@@ -182,40 +182,5 @@ OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
}
}
bool OpCompatibleMap::ConvertToProto(proto::OpCompatibleMap* desc) const {
desc->Clear();
desc->set_default_required_version(default_required_version_);
for (auto pair : op_compatible_map_) {
const CompatibleInfo& info = pair.second;
auto* pair_desc = desc->add_pair();
pair_desc->set_op_name(pair.first);
auto* info_desc = pair_desc->mutable_compatible_info();
info_desc->set_version(info.required_version_);
info_desc->set_type(
static_cast<proto::CompatibleInfo_Type>(info.compatible_type_));
}
return true;
}
bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
std::string version = desc.default_required_version();
if (version.empty()) {
LOG(INFO) << "The default operator required version is missing."
" Please update the model version.";
return false;
}
op_compatible_map_.clear();
default_required_version_ = desc.default_required_version();
for (int i = 0; i < desc.pair_size(); ++i) {
const auto& pair_desc = desc.pair(i);
auto info_desc = pair_desc.compatible_info();
CompatibleInfo info(info_desc.version(),
static_cast<OpCompatibleType>(info_desc.type()));
std::pair<std::string, CompatibleInfo> pair(pair_desc.op_name(), info);
op_compatible_map_.insert(pair);
}
return true;
}
} // namespace framework
} // namespace paddle
......@@ -58,14 +58,6 @@ class OpCompatibleMap {
OpCompatibleType IsRequireMiniVersion(std::string op_name,
std::string current_version) const;
// Convert the entire OpCompatibleMap to Proto, which can be serialized
// to the model file as part of the ProgramDesc.
bool ConvertToProto(proto::OpCompatibleMap* desc) const;
// Read and reset the entire object from proto, which can be read from
// the model file as part of the program.
bool ReadFromProto(const proto::OpCompatibleMap& desc);
const std::string& GetDefaultRequiredVersion() const {
return default_required_version_;
}
......
......@@ -28,12 +28,6 @@ TEST(test_op_compatible_info, test_op_compatible) {
auto comp_map = OpCompatibleMap();
comp_map.InitOpCompatibleMap();
// Ensure save-load consistency.
auto program_desc = ProgramDesc();
proto::OpCompatibleMap* proto_map = program_desc.OpCompatibleMap();
comp_map.ConvertToProto(proto_map);
comp_map.ReadFromProto(*proto_map);
ASSERT_NE(comp_map.GetDefaultRequiredVersion(), std::string());
ASSERT_NE(comp_map.GetOpCompatibleInfo("sequence_pad").required_version_,
std::string());
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_proto.h"
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace framework {
namespace compatible {
namespace pb {
class OpVersion {
public:
explicit OpVersion(proto::OpVersion* desc) : desc_{desc} {}
void SetVersionID(uint32_t version) { desc_->set_version(version); }
private:
proto::OpVersion* desc_;
};
class OpVersionMap {
public:
explicit OpVersionMap(proto::OpVersionMap* desc) : desc_{desc} {}
OpVersion operator[](const std::string& key) {
for (int i = 0; i < desc_->pair_size(); ++i) {
if (desc_->pair(i).op_name() == key) {
return OpVersion(desc_->mutable_pair(i)->mutable_op_version());
}
}
auto* pair = desc_->add_pair();
pair->set_op_name(key);
return OpVersion(pair->mutable_op_version());
}
private:
proto::OpVersionMap* desc_;
};
} // namespace pb
} // namespace compatible
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <boost/any.hpp>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -159,12 +160,14 @@ class OpVersionRegistrar {
op_version_map_.insert({op_type, OpVersion()});
return op_version_map_[op_type];
}
const std::unordered_map<std::string, OpVersion>& GetVersionMap() {
return op_version_map_;
}
uint32_t GetVersionID(const std::string& op_type) const {
auto it = op_version_map_.find(op_type);
if (it == op_version_map_.end()) {
return 0;
}
return it->second.GetVersionID();
}
......@@ -175,6 +178,14 @@ class OpVersionRegistrar {
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
};
inline void SaveOpVersions(
const std::unordered_map<std::string, OpVersion>& src,
pb::OpVersionMap* dst) {
for (const auto& pair : src) {
(*dst)[pair.first].SetVersionID(pair.second.GetVersionID());
}
}
class OpVersionComparator {
public:
virtual bool operator()() = 0;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......
......@@ -39,8 +39,8 @@ proto::ProgramDesc *ProgramDesc::Proto() {
return &desc_;
}
proto::OpCompatibleMap *ProgramDesc::OpCompatibleMap() {
return desc_.mutable_op_compatible_map();
proto::OpVersionMap *ProgramDesc::OpVersionMap() {
return desc_.mutable_op_version_map();
}
int64_t ProgramDesc::Version() const { return desc_.version().version(); }
......
......@@ -58,7 +58,7 @@ class ProgramDesc {
proto::ProgramDesc *Proto();
proto::OpCompatibleMap *OpCompatibleMap();
proto::OpVersionMap *OpVersionMap();
int64_t Version() const;
......
......@@ -192,11 +192,6 @@ bool AnalysisPredictor::PrepareProgram(
// If config_.ir_optim() is False, parameters is loaded in LoadParameters(),
// still need to create other persistable variables.
// So in both case, create persistable variables at first.
if (!CheckOperatorCompatible()) {
LOG(WARNING) << "WARNING: Results may be DIFF! "
"Please use the corresponding version of the model and "
"prediction library, and do not use the develop branch.";
}
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
// if enable_ir_optim_ is false,
......@@ -998,40 +993,6 @@ std::string AnalysisPredictor::GetSerializedProgram() const {
return inference_program_->Proto()->SerializeAsString();
}
bool AnalysisPredictor::CheckOperatorCompatible() {
if (!inference_program_) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Inference program version check failed because the program does not "
"exist."));
return false;
}
bool res = true;
op_compatible_map_.ReadFromProto(*inference_program_->OpCompatibleMap());
const auto &version = framework::DumpVersion(framework::kCurProgramVersion);
LOG(INFO) << "MODEL VERSION: "
<< framework::DumpVersion(inference_program_->Version());
LOG(INFO) << "PREDICTOR VERSION: " << version;
std::set<std::string> op_types;
for (size_t i = 0; i < inference_program_->Size(); ++i) {
const auto &block = inference_program_->Block(i);
for (const auto *op : block.AllOps()) {
op_types.insert(op->Type());
}
}
for (const auto type : op_types) {
auto compatible_type =
op_compatible_map_.IsRequireMiniVersion(type, version);
if (compatible_type != framework::OpCompatibleType::compatible) {
if (!framework::kCurProgramVersion) {
LOG(WARNING) << " - Version incompatible ("
<< static_cast<int>(compatible_type) << ") " << type;
}
res = false;
}
}
return res;
}
// Add SaveOptimModel
void AnalysisPredictor::SaveOptimModel(const std::string &dir) {
// save model
......
......@@ -335,13 +335,6 @@ class AnalysisPredictor : public PaddlePredictor {
/// AnalysisPredictor::ZeroCopyRun() now.
///
void MkldnnPostReset();
///
/// \brief Compute compatibility based on model version information and
/// operator version information
///
/// \return Compatible information
///
bool CheckOperatorCompatible();
#if PADDLE_WITH_TENSORRT
///
......
......@@ -36,9 +36,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_compatible_info.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
......@@ -432,10 +432,12 @@ PYBIND11_MODULE(core_noavx, m) {
return map_output;
});
m.def("save_op_compatible_info", [](framework::ProgramDesc &desc) {
framework::OpCompatibleMap op_compatible_map;
op_compatible_map.InitOpCompatibleMap();
return op_compatible_map.ConvertToProto(desc.OpCompatibleMap());
m.def("save_op_version_info", [](framework::ProgramDesc &desc) {
framework::compatible::pb::OpVersionMap pb_vmap{desc.OpVersionMap()};
framework::compatible::SaveOpVersions(
framework::compatible::OpVersionRegistrar::GetInstance()
.GetVersionMap(),
&pb_vmap);
});
m.def(
......
......@@ -1346,7 +1346,7 @@ def save_inference_model(dirname,
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(main_program.desc)
paddle.fluid.core.save_op_version_info(main_program.desc)
with open(model_basename, "wb") as f:
f.write(main_program.desc.serialize_to_string())
else:
......@@ -1720,7 +1720,7 @@ def save(program, model_path):
main_program = program.clone()
program.desc.flush()
main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(program.desc)
paddle.fluid.core.save_op_version_info(program.desc)
with open(model_path + ".pdmodel", "wb") as f:
f.write(program.desc.serialize_to_string())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册