未验证 提交 0d275916 编写于 作者: 石晓伟 提交者: GitHub

save operator version information to program desc, test=develop (#27668)

上级 00d401ec
...@@ -123,7 +123,9 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce) ...@@ -123,7 +123,9 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context) device_context)
cc_library(op_version_registry SRCS op_version_registry.cc DEPS framework_proto boost) cc_library(op_version_proto SRCS op_version_proto.cc DEPS framework_proto boost)
cc_library(op_version_registry SRCS op_version_registry.cc DEPS op_version_proto framework_proto boost)
cc_test(op_version_registry_test SRCS op_version_registry_test.cc DEPS op_version_registry) cc_test(op_version_registry_test SRCS op_version_registry_test.cc DEPS op_version_registry)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
......
...@@ -179,29 +179,15 @@ message BlockDesc { ...@@ -179,29 +179,15 @@ message BlockDesc {
optional int32 forward_block_idx = 5 [ default = -1 ]; optional int32 forward_block_idx = 5 [ default = -1 ];
} }
// CompatibleInfo is used to determine if a feature is compatible and // In some cases, Paddle may perform operator definition iterations,
// provides the information. // and the operator uses OpVersionMap for compatibility testing.
message CompatibleInfo { message OpVersion { required int32 version = 1; }
enum Type { message OpVersionMap {
COMPATIBLE = 0; message OpVersionPair {
DEFINITELY_NOT = 1;
POSSIBLE = 2;
BUG_FIX = 3;
PRECISION_CHANGE = 4;
}
required string version = 1;
required Type type = 2;
}
// In some cases, Paddle Fluid may perform operator definition iterations,
// and the operator uses OpCompatibleMap for compatibility testing.
message OpCompatibleMap {
message OpCompatiblePair {
required string op_name = 1; required string op_name = 1;
required CompatibleInfo compatible_info = 2; required OpVersion op_version = 2;
} }
repeated OpCompatiblePair pair = 1; repeated OpVersionPair pair = 1;
optional string default_required_version = 2;
} }
// Please refer to // Please refer to
...@@ -210,8 +196,8 @@ message OpCompatibleMap { ...@@ -210,8 +196,8 @@ message OpCompatibleMap {
// TODO(panyx0718): A model can have multiple programs. Need a // TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name? // way to distinguish them. Maybe ID or name?
message ProgramDesc { message ProgramDesc {
reserved 2; // For backward compatibility. reserved 2, 3; // For backward compatibility.
repeated BlockDesc blocks = 1; repeated BlockDesc blocks = 1;
optional Version version = 4; optional Version version = 4;
optional OpCompatibleMap op_compatible_map = 3; optional OpVersionMap op_version_map = 5;
} }
...@@ -182,40 +182,5 @@ OpCompatibleType OpCompatibleMap::IsRequireMiniVersion( ...@@ -182,40 +182,5 @@ OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
} }
} }
// Serializes this OpCompatibleMap into the given proto message so it can be
// written out as part of a ProgramDesc. Any existing content in `desc` is
// discarded first. Always returns true (kept for interface compatibility).
bool OpCompatibleMap::ConvertToProto(proto::OpCompatibleMap* desc) const {
  desc->Clear();
  desc->set_default_required_version(default_required_version_);
  // NOTE: iterate by const reference; the original copied each
  // (name, CompatibleInfo) pair on every iteration.
  for (const auto& pair : op_compatible_map_) {
    const CompatibleInfo& info = pair.second;
    auto* pair_desc = desc->add_pair();
    pair_desc->set_op_name(pair.first);
    auto* info_desc = pair_desc->mutable_compatible_info();
    info_desc->set_version(info.required_version_);
    info_desc->set_type(
        static_cast<proto::CompatibleInfo_Type>(info.compatible_type_));
  }
  return true;
}
// Resets this object from a serialized proto::OpCompatibleMap (e.g. one read
// from a model file). Returns false without modifying the map when the proto
// carries no default required version, which indicates a stale model.
bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
  std::string version = desc.default_required_version();
  if (version.empty()) {
    LOG(INFO) << "The default operator required version is missing."
                 " Please update the model version.";
    return false;
  }
  op_compatible_map_.clear();
  default_required_version_ = desc.default_required_version();
  for (int i = 0; i < desc.pair_size(); ++i) {
    const auto& pair_desc = desc.pair(i);
    // Bind by const reference; the original copied the CompatibleInfo
    // sub-message into a local before reading it.
    const auto& info_desc = pair_desc.compatible_info();
    CompatibleInfo info(info_desc.version(),
                        static_cast<OpCompatibleType>(info_desc.type()));
    // Construct the map entry in place instead of building a temporary pair.
    op_compatible_map_.emplace(pair_desc.op_name(), info);
  }
  return true;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -58,14 +58,6 @@ class OpCompatibleMap { ...@@ -58,14 +58,6 @@ class OpCompatibleMap {
OpCompatibleType IsRequireMiniVersion(std::string op_name, OpCompatibleType IsRequireMiniVersion(std::string op_name,
std::string current_version) const; std::string current_version) const;
// Convert the entire OpCompatibleMap to Proto, which can be serialized
// to the model file as part of the ProgramDesc.
bool ConvertToProto(proto::OpCompatibleMap* desc) const;
// Read and reset the entire object from proto, which can be read from
// the model file as part of the program.
bool ReadFromProto(const proto::OpCompatibleMap& desc);
const std::string& GetDefaultRequiredVersion() const { const std::string& GetDefaultRequiredVersion() const {
return default_required_version_; return default_required_version_;
} }
......
...@@ -28,12 +28,6 @@ TEST(test_op_compatible_info, test_op_compatible) { ...@@ -28,12 +28,6 @@ TEST(test_op_compatible_info, test_op_compatible) {
auto comp_map = OpCompatibleMap(); auto comp_map = OpCompatibleMap();
comp_map.InitOpCompatibleMap(); comp_map.InitOpCompatibleMap();
// Ensure save-load consistency.
auto program_desc = ProgramDesc();
proto::OpCompatibleMap* proto_map = program_desc.OpCompatibleMap();
comp_map.ConvertToProto(proto_map);
comp_map.ReadFromProto(*proto_map);
ASSERT_NE(comp_map.GetDefaultRequiredVersion(), std::string()); ASSERT_NE(comp_map.GetDefaultRequiredVersion(), std::string());
ASSERT_NE(comp_map.GetOpCompatibleInfo("sequence_pad").required_version_, ASSERT_NE(comp_map.GetOpCompatibleInfo("sequence_pad").required_version_,
std::string()); std::string());
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_proto.h"
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace framework {
namespace compatible {
namespace pb {
// Thin non-owning wrapper over proto::OpVersion that exposes a typed setter
// for the version field. The caller retains ownership of `desc`.
class OpVersion {
 public:
  explicit OpVersion(proto::OpVersion* desc) : desc_{desc} {}
  // Writes `version` into the wrapped proto's `version` field.
  void SetVersionID(uint32_t version) { desc_->set_version(version); }
 private:
  // Borrowed pointer to the underlying proto message; never owned or freed.
  proto::OpVersion* desc_;
};
// Non-owning wrapper over proto::OpVersionMap providing map-like access by
// operator name. The caller retains ownership of `desc`.
class OpVersionMap {
 public:
  explicit OpVersionMap(proto::OpVersionMap* desc) : desc_{desc} {}
  // Upsert accessor: returns a wrapper for the OpVersion entry whose op_name
  // equals `key`, appending a new pair to the proto if none exists.
  // NOTE: lookup is a linear scan over the repeated field, so this is O(n)
  // per access — acceptable for the expected small number of operators.
  OpVersion operator[](const std::string& key) {
    for (int i = 0; i < desc_->pair_size(); ++i) {
      if (desc_->pair(i).op_name() == key) {
        return OpVersion(desc_->mutable_pair(i)->mutable_op_version());
      }
    }
    // Key not present: append a fresh pair and hand back its version slot.
    auto* pair = desc_->add_pair();
    pair->set_op_name(key);
    return OpVersion(pair->mutable_op_version());
  }
 private:
  // Borrowed pointer to the underlying proto message; never owned or freed.
  proto::OpVersionMap* desc_;
};
} // namespace pb
} // namespace compatible
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include <boost/any.hpp> #include <boost/any.hpp>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -159,12 +160,14 @@ class OpVersionRegistrar { ...@@ -159,12 +160,14 @@ class OpVersionRegistrar {
op_version_map_.insert({op_type, OpVersion()}); op_version_map_.insert({op_type, OpVersion()});
return op_version_map_[op_type]; return op_version_map_[op_type];
} }
const std::unordered_map<std::string, OpVersion>& GetVersionMap() {
return op_version_map_;
}
uint32_t GetVersionID(const std::string& op_type) const { uint32_t GetVersionID(const std::string& op_type) const {
auto it = op_version_map_.find(op_type); auto it = op_version_map_.find(op_type);
if (it == op_version_map_.end()) { if (it == op_version_map_.end()) {
return 0; return 0;
} }
return it->second.GetVersionID(); return it->second.GetVersionID();
} }
...@@ -175,6 +178,14 @@ class OpVersionRegistrar { ...@@ -175,6 +178,14 @@ class OpVersionRegistrar {
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete; OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
}; };
// Copies every (op_type, version) entry from the in-memory registry map `src`
// into the proto-backed map `dst`, so the version info can be serialized with
// the ProgramDesc. Existing entries in `dst` with matching names are
// overwritten via OpVersionMap::operator[]'s upsert behavior.
inline void SaveOpVersions(
    const std::unordered_map<std::string, OpVersion>& src,
    pb::OpVersionMap* dst) {
  for (const auto& pair : src) {
    (*dst)[pair.first].SetVersionID(pair.second.GetVersionID());
  }
}
class OpVersionComparator { class OpVersionComparator {
public: public:
virtual bool operator()() = 0; virtual bool operator()() = 0;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
......
...@@ -39,8 +39,8 @@ proto::ProgramDesc *ProgramDesc::Proto() { ...@@ -39,8 +39,8 @@ proto::ProgramDesc *ProgramDesc::Proto() {
return &desc_; return &desc_;
} }
proto::OpCompatibleMap *ProgramDesc::OpCompatibleMap() { proto::OpVersionMap *ProgramDesc::OpVersionMap() {
return desc_.mutable_op_compatible_map(); return desc_.mutable_op_version_map();
} }
int64_t ProgramDesc::Version() const { return desc_.version().version(); } int64_t ProgramDesc::Version() const { return desc_.version().version(); }
......
...@@ -58,7 +58,7 @@ class ProgramDesc { ...@@ -58,7 +58,7 @@ class ProgramDesc {
proto::ProgramDesc *Proto(); proto::ProgramDesc *Proto();
proto::OpCompatibleMap *OpCompatibleMap(); proto::OpVersionMap *OpVersionMap();
int64_t Version() const; int64_t Version() const;
......
...@@ -192,11 +192,6 @@ bool AnalysisPredictor::PrepareProgram( ...@@ -192,11 +192,6 @@ bool AnalysisPredictor::PrepareProgram(
// If config_.ir_optim() is False, parameters is loaded in LoadParameters(), // If config_.ir_optim() is False, parameters is loaded in LoadParameters(),
// still need to create other persistable variables. // still need to create other persistable variables.
// So in both case, create persistable variables at first. // So in both case, create persistable variables at first.
if (!CheckOperatorCompatible()) {
LOG(WARNING) << "WARNING: Results may be DIFF! "
"Please use the corresponding version of the model and "
"prediction library, and do not use the develop branch.";
}
executor_->CreateVariables(*inference_program_, 0, true, sub_scope_); executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);
// if enable_ir_optim_ is false, // if enable_ir_optim_ is false,
...@@ -998,40 +993,6 @@ std::string AnalysisPredictor::GetSerializedProgram() const { ...@@ -998,40 +993,6 @@ std::string AnalysisPredictor::GetSerializedProgram() const {
return inference_program_->Proto()->SerializeAsString(); return inference_program_->Proto()->SerializeAsString();
} }
// Checks whether every operator used by the loaded inference program is
// compatible with the current predictor version, by comparing the model's
// embedded OpCompatibleMap against kCurProgramVersion.
// Returns true only when all operator types pass the minimum-version check;
// throws if no inference program has been loaded.
bool AnalysisPredictor::CheckOperatorCompatible() {
  if (!inference_program_) {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "Inference program version check failed because the program does not "
        "exist."));
    return false;
  }
  bool res = true;
  op_compatible_map_.ReadFromProto(*inference_program_->OpCompatibleMap());
  const auto &version = framework::DumpVersion(framework::kCurProgramVersion);
  LOG(INFO) << "MODEL VERSION: "
            << framework::DumpVersion(inference_program_->Version());
  LOG(INFO) << "PREDICTOR VERSION: " << version;
  // Collect the distinct operator types across all blocks so each type is
  // checked exactly once.
  std::set<std::string> op_types;
  for (size_t i = 0; i < inference_program_->Size(); ++i) {
    const auto &block = inference_program_->Block(i);
    for (const auto *op : block.AllOps()) {
      op_types.insert(op->Type());
    }
  }
  // Iterate by const reference; the original declared `const auto type`,
  // copying every std::string out of the set.
  for (const auto &type : op_types) {
    auto compatible_type =
        op_compatible_map_.IsRequireMiniVersion(type, version);
    if (compatible_type != framework::OpCompatibleType::compatible) {
      // NOTE(review): warning is emitted only when kCurProgramVersion is
      // falsy (i.e. version 0 / dev builds) — looks inverted; confirm intent.
      if (!framework::kCurProgramVersion) {
        LOG(WARNING) << " - Version incompatible ("
                     << static_cast<int>(compatible_type) << ") " << type;
      }
      res = false;
    }
  }
  return res;
}
// Add SaveOptimModel // Add SaveOptimModel
void AnalysisPredictor::SaveOptimModel(const std::string &dir) { void AnalysisPredictor::SaveOptimModel(const std::string &dir) {
// save model // save model
......
...@@ -335,13 +335,6 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -335,13 +335,6 @@ class AnalysisPredictor : public PaddlePredictor {
/// AnalysisPredictor::ZeroCopyRun() now. /// AnalysisPredictor::ZeroCopyRun() now.
/// ///
void MkldnnPostReset(); void MkldnnPostReset();
///
/// \brief Compute compatibility based on model version information and
/// operator version information
///
/// \return Compatible information
///
bool CheckOperatorCompatible();
#if PADDLE_WITH_TENSORRT #if PADDLE_WITH_TENSORRT
/// ///
......
...@@ -36,9 +36,9 @@ limitations under the License. */ ...@@ -36,9 +36,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_compatible_info.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h" #include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
...@@ -432,10 +432,12 @@ PYBIND11_MODULE(core_noavx, m) { ...@@ -432,10 +432,12 @@ PYBIND11_MODULE(core_noavx, m) {
return map_output; return map_output;
}); });
m.def("save_op_compatible_info", [](framework::ProgramDesc &desc) { m.def("save_op_version_info", [](framework::ProgramDesc &desc) {
framework::OpCompatibleMap op_compatible_map; framework::compatible::pb::OpVersionMap pb_vmap{desc.OpVersionMap()};
op_compatible_map.InitOpCompatibleMap(); framework::compatible::SaveOpVersions(
return op_compatible_map.ConvertToProto(desc.OpCompatibleMap()); framework::compatible::OpVersionRegistrar::GetInstance()
.GetVersionMap(),
&pb_vmap);
}); });
m.def( m.def(
......
...@@ -1346,7 +1346,7 @@ def save_inference_model(dirname, ...@@ -1346,7 +1346,7 @@ def save_inference_model(dirname,
append_fetch_ops(main_program, fetch_var_names) append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version() main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(main_program.desc) paddle.fluid.core.save_op_version_info(main_program.desc)
with open(model_basename, "wb") as f: with open(model_basename, "wb") as f:
f.write(main_program.desc.serialize_to_string()) f.write(main_program.desc.serialize_to_string())
else: else:
...@@ -1720,7 +1720,7 @@ def save(program, model_path): ...@@ -1720,7 +1720,7 @@ def save(program, model_path):
main_program = program.clone() main_program = program.clone()
program.desc.flush() program.desc.flush()
main_program.desc._set_version() main_program.desc._set_version()
paddle.fluid.core.save_op_compatible_info(program.desc) paddle.fluid.core.save_op_version_info(program.desc)
with open(model_path + ".pdmodel", "wb") as f: with open(model_path + ".pdmodel", "wb") as f:
f.write(program.desc.serialize_to_string()) f.write(program.desc.serialize_to_string())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册