Unverified commit 01b9d079, authored by 石晓伟, committed by GitHub

update operator compatible info, test=develop (#19978)

* update operator compatible info, test=develop

* revert cmake/version.cmake, test=develop

* add unit_tests and fix bugs, test=develop

* update ../paddle/fluid/framework/framework.proto, test=develop

* fix bug of paddle/fluid/inference/api/analysis_predictor.cc, test=develop

* update paddle/fluid/framework/version_test.cc, test=develop

* add comments and rename interfaces, test=develop
Parent 20f0878f
cmake/version.cmake
@@ -59,5 +59,13 @@ while ("${PADDLE_VERSION}" STREQUAL "")
   endif()
 endwhile()

+string(REPLACE "." ";" PADDLE_VER_LIST ${PADDLE_VERSION})
+list(GET PADDLE_VER_LIST 0 PADDLE_MAJOR_VER)
+list(GET PADDLE_VER_LIST 1 PADDLE_MINOR_VER)
+list(GET PADDLE_VER_LIST 2 PADDLE_PATCH_VER)
+math(EXPR PADDLE_VERSION_INTEGER "${PADDLE_MAJOR_VER} * 1000000
+            + ${PADDLE_MINOR_VER} * 1000 + ${PADDLE_PATCH_VER}")
+
 add_definitions(-DPADDLE_VERSION=${PADDLE_VERSION})
+add_definitions(-DPADDLE_VERSION_INTEGER=${PADDLE_VERSION_INTEGER})
 message(STATUS "Paddle version is ${PADDLE_VERSION}")
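The math(EXPR ...) rule above packs major, minor, and patch into a single integer (major * 1000000 + minor * 1000 + patch). A minimal standalone sketch of the encoding, assuming a hypothetical 1.6.0 release:

// Hypothetical check of the encoding produced by the CMake rule above:
// "1.6.0"  ->  1 * 1000000 + 6 * 1000 + 0  ==  1006000
static_assert(1 * 1000000 + 6 * 1000 + 0 == 1006000,
              "1.6.0 should encode to 1006000");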
paddle/fluid/framework/CMakeLists.txt
@@ -224,7 +224,7 @@ cc_library(dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack)
 cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
 cc_library(op_compatible_info SRCS op_compatible_info DEPS string_helper)
-cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatible_info string_helper glog)
+cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatible_info proto_desc string_helper glog)

 # Get the current working branch
 execute_process(
......
paddle/fluid/framework/framework.proto
@@ -179,13 +179,39 @@ message BlockDesc {
   optional int32 forward_block_idx = 5 [ default = -1 ];
 }

+// CompatibleInfo is used to determine whether a feature is compatible and
+// to provide the related information.
+message CompatibleInfo {
+  enum Type {
+    COMPATIBLE = 0;
+    DEFINITELY_NOT = 1;
+    POSSIBLE = 2;
+    BUG_FIX = 3;
+    PRECISION_CHANGE = 4;
+  }
+  required string version = 1;
+  required Type type = 2;
+}
+
+// In some cases, Paddle Fluid may iterate on operator definitions;
+// OpCompatibleMap is used to check operator compatibility across versions.
+message OpCompatibleMap {
+  message OpCompatiblePair {
+    required string op_name = 1;
+    required CompatibleInfo compatible_info = 2;
+  }
+  repeated OpCompatiblePair pair = 1;
+  optional string default_required_version = 2;
+}
+
 // Please refer to
 // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md
 // for more details.
 // TODO(panyx0718): A model can have multiple programs. Need a
 // way to distinguish them. Maybe ID or name?
 message ProgramDesc {
+  reserved 2; // For backward compatibility.
   repeated BlockDesc blocks = 1;
-  optional Version version = 2;
+  optional Version version = 4;
+  optional OpCompatibleMap op_compatible_map = 3;
 }
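For orientation, a minimal sketch of filling one entry of the new OpCompatibleMap message through the C++ classes protoc generates for framework.proto. The generated header path is an assumption; the accessors match those used in ConvertToProto further down, and the layer_norm / 1.6.0 / bug-fix entry mirrors the map initialization in op_compatible_info.cc:

#include "paddle/fluid/framework/framework.pb.h"  // assumed generated header

// Record that layer_norm requires at least model version 1.6.0 (a bug fix).
void FillOneEntry(paddle::framework::proto::OpCompatibleMap* map) {
  map->set_default_required_version("1.5.0");
  auto* pair = map->add_pair();
  pair->set_op_name("layer_norm");
  auto* info = pair->mutable_compatible_info();
  info->set_version("1.6.0");
  info->set_type(paddle::framework::proto::CompatibleInfo::BUG_FIX);
}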
paddle/fluid/framework/op_compatible_info.cc
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/op_compatible_info.h"
 #include <iostream>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/string/string_helper.h"
@@ -72,7 +73,7 @@ void OpCompatibleMap::InitOpCompatibleMap() {
   op_compatible_map_["layer_norm"] = {"1.6.0", OpCompatibleType::bug_fix};
 }

-CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) {
+CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) const {
   auto it = op_compatible_map_.find(op_name);
   if (it != op_compatible_map_.end()) {
     return it->second;
@@ -82,7 +83,7 @@ CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) {
 }

 OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
-    std::string op_name, std::string str_current_version) {
+    std::string op_name, std::string str_current_version) const {
   auto it = op_compatible_map_.find(op_name);
   if (it != op_compatible_map_.end()) {
     if (CompareVersion(str_current_version, it->second.required_version_)) {
@@ -100,5 +101,40 @@ OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
   }
 }

+bool OpCompatibleMap::ConvertToProto(proto::OpCompatibleMap* desc) const {
+  desc->Clear();
+  desc->set_default_required_version(default_required_version_);
+  for (auto pair : op_compatible_map_) {
+    const CompatibleInfo& info = pair.second;
+    auto* pair_desc = desc->add_pair();
+    pair_desc->set_op_name(pair.first);
+    auto* info_desc = pair_desc->mutable_compatible_info();
+    info_desc->set_version(info.required_version_);
+    info_desc->set_type(
+        static_cast<proto::CompatibleInfo_Type>(info.compatible_type_));
+  }
+  return true;
+}
+
+bool OpCompatibleMap::ReadFromProto(const proto::OpCompatibleMap& desc) {
+  std::string version = desc.default_required_version();
+  if (version.empty()) {
+    LOG(INFO) << "The default operator required version is missing."
+                 " Please update the model version.";
+    return false;
+  }
+  op_compatible_map_.clear();
+  default_required_version_ = desc.default_required_version();
+  for (int i = 0; i < desc.pair_size(); ++i) {
+    const auto& pair_desc = desc.pair(i);
+    auto info_desc = pair_desc.compatible_info();
+    CompatibleInfo info(info_desc.version(),
+                        static_cast<OpCompatibleType>(info_desc.type()));
+    std::pair<std::string, CompatibleInfo> pair(pair_desc.op_name(), info);
+    op_compatible_map_.insert(pair);
+  }
+  return true;
+}
+
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/op_compatible_info.h
@@ -14,6 +14,7 @@
 #include <map>
 #include <string>
+#include "paddle/fluid/framework/program_desc.h"

 #pragma once
@@ -44,24 +45,28 @@ class OpCompatibleMap {
   OpCompatibleMap() : default_required_version_("1.5.0") {}

   void InitOpCompatibleMap();

-  CompatibleInfo GetOpCompatibleInfo(std::string op_name);
+  CompatibleInfo GetOpCompatibleInfo(std::string op_name) const;

   /* IsRequireMiniVersion
    * return type OpCompatibleType */
   OpCompatibleType IsRequireMiniVersion(std::string op_name,
-                                        std::string current_version);
+                                        std::string current_version) const;

-  void SerializeToStr(std::string& str) {}  /* NOLINT */
-  void UnSerialize(const std::string& str) {}
+  // Convert the entire OpCompatibleMap to Proto, which can be serialized
+  // to the model file as part of the ProgramDesc.
+  bool ConvertToProto(proto::OpCompatibleMap* desc) const;

-  const std::string& GetDefaultRequiredVersion() {
+  // Read and reset the entire object from proto, which can be read from
+  // the model file as part of the program.
+  bool ReadFromProto(const proto::OpCompatibleMap& desc);
+
+  const std::string& GetDefaultRequiredVersion() const {
     return default_required_version_;
   }

 private:
   std::map<std::string, CompatibleInfo> op_compatible_map_;
   std::string default_required_version_;
 };
......
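A short usage sketch of the reworked interface, mirroring what the unit test and the predictor below do (names are taken from this diff):

#include "paddle/fluid/framework/op_compatible_info.h"
#include "paddle/fluid/framework/program_desc.h"

// Build the map, round-trip it through a ProgramDesc, then query it.
void OpCompatibleMapSketch() {
  paddle::framework::OpCompatibleMap comp_map;
  comp_map.InitOpCompatibleMap();

  paddle::framework::ProgramDesc program_desc;
  comp_map.ConvertToProto(program_desc.OpCompatibleMap());
  comp_map.ReadFromProto(*program_desc.OpCompatibleMap());

  // "layer_norm" is registered with required version 1.6.0 in
  // InitOpCompatibleMap(); compare the result against
  // OpCompatibleType::compatible as needed.
  auto type = comp_map.IsRequireMiniVersion("layer_norm", "1.6.0");
  (void)type;
}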
paddle/fluid/framework/op_compatible_info_test.cc
@@ -15,21 +15,31 @@
 #include "paddle/fluid/framework/op_compatible_info.h"
 #include <iostream>
 #include "gtest/gtest.h"
+#include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/macros.h"

 namespace paddle {
 namespace framework {

 TEST(test_op_compatible_info, test_op_compatible) {
   auto comp_map = OpCompatibleMap();
   comp_map.InitOpCompatibleMap();

-  auto default_req_version = comp_map.GetDefaultRequiredVersion();
-  auto seq_pad = comp_map.GetOpCompatibleInfo("sequence_pad");
-  auto reshape = comp_map.GetOpCompatibleInfo("reshape");
-  auto layer_norm = comp_map.GetOpCompatibleInfo("layer_norm");
-  auto deafult_info = comp_map.GetOpCompatibleInfo("layer_xx");
+  // Ensure save-load consistency.
+  auto program_desc = ProgramDesc();
+  proto::OpCompatibleMap* proto_map = program_desc.OpCompatibleMap();
+  comp_map.ConvertToProto(proto_map);
+  comp_map.ReadFromProto(*proto_map);
+
+  ASSERT_NE(comp_map.GetDefaultRequiredVersion(), std::string());
+  ASSERT_NE(comp_map.GetOpCompatibleInfo("sequence_pad").required_version_,
+            std::string());
+  ASSERT_NE(comp_map.GetOpCompatibleInfo("reshape").required_version_,
+            std::string());
+  ASSERT_NE(comp_map.GetOpCompatibleInfo("layer_norm").required_version_,
+            std::string());
+  ASSERT_NE(comp_map.GetOpCompatibleInfo("layer_xx").required_version_,
+            std::string());

   auto comp_1 = comp_map.IsRequireMiniVersion("sequence_pad", "1.5.0");
   ASSERT_EQ(comp_1, OpCompatibleType::DEFIN_NOT);
@@ -54,5 +64,6 @@ TEST(test_op_compatible_info, test_op_compatible) {
   ASSERT_EQ(comp_map.IsRequireMiniVersion("slice", "1.6.0"),
             OpCompatibleType::compatible);
 }
+
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/program_desc.cc
@@ -39,10 +39,18 @@ proto::ProgramDesc *ProgramDesc::Proto() {
   return &desc_;
 }

+proto::OpCompatibleMap *ProgramDesc::OpCompatibleMap() {
+  return desc_.mutable_op_compatible_map();
+}
+
 int64_t ProgramDesc::Version() const { return desc_.version().version(); }

+void ProgramDesc::SetVersion(const int64_t version) {
+  desc_.mutable_version()->set_version(version);
+}
+
 ProgramDesc::ProgramDesc() {
-  desc_.mutable_version()->set_version(kCurProgramVersion);
+  SetVersion(kCurProgramVersion);
   auto *block = desc_.mutable_blocks()->Add();
   block->set_idx(kRootBlockIndex);
   block->set_parent_idx(kNoneBlockIndex);
......
paddle/fluid/framework/program_desc.h
@@ -57,8 +57,12 @@ class ProgramDesc {
   proto::ProgramDesc *Proto();

+  proto::OpCompatibleMap *OpCompatibleMap();
+
   int64_t Version() const;

+  void SetVersion(const int64_t version);
+
   // The output variable of feed_op is referenced as feed_target.
   // This function is used to collect the output variable's name of all
   // feed_ops.
......
paddle/fluid/framework/version.cc
@@ -14,23 +14,38 @@ limitations under the License. */
 #include "paddle/fluid/framework/version.h"
 #include <algorithm>
+#include <sstream>

 namespace paddle {
 namespace framework {

 bool IsProgramVersionSupported(int64_t version) {
-  static int num_supported =
-      sizeof(kSupportedProgramVersion) / sizeof(kSupportedProgramVersion[0]);
-  return std::find(kSupportedProgramVersion,
-                   kSupportedProgramVersion + num_supported,
-                   version) != kSupportedProgramVersion + num_supported;
+  /* So far, all old program versions are supported in the
+   * new version. The compatibility judgment cannot be made only
+   * by the version number. Please do not use this interface,
+   * it may be discarded because of backward compatibility.
+   */
+  return true;
 }

 bool IsTensorVersionSupported(uint32_t version) {
-  static int num_supported =
-      sizeof(kSupportedTensorVersion) / sizeof(kSupportedTensorVersion[0]);
-  return std::find(kSupportedTensorVersion,
-                   kSupportedTensorVersion + num_supported,
-                   version) != kSupportedTensorVersion + num_supported;
+  /* So far, all old versions of Tensor are supported in the
+   * new version. The compatibility judgment cannot be made only
+   * by the version number. Please do not use this interface,
+   * it may be discarded because of backward compatibility.
+   */
+  return true;
 }

+std::string DumpVersion(const int64_t version) {
+  std::stringstream buffer;
+  const int major = version / MAJOR_COEFF;
+  const int minor = (version - major * MAJOR_COEFF) / MINOR_COEFF;
+  const int patch =
+      (version - major * MAJOR_COEFF - minor * MINOR_COEFF) / PATCH_COEFF;
+  buffer << major << "." << minor << "." << patch;
+  return buffer.str();
+}
+
 }  // namespace framework
 }  // namespace paddle
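A standalone sketch of the decode DumpVersion performs, assuming the coefficients declared in version.h (1000000, 1000, 1) and a hypothetical encoded value of 1006000:

#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>

int main() {
  const int64_t version = 1006000;  // hypothetical encoding of 1.6.0
  const int major = version / 1000000;                         // 1
  const int minor = (version - major * 1000000) / 1000;        // 6
  const int patch = version - major * 1000000 - minor * 1000;  // 0
  std::stringstream buffer;
  buffer << major << "." << minor << "." << patch;
  assert(buffer.str() == "1.6.0");
  return 0;
}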
paddle/fluid/framework/version.h
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include <cstdint>
-
 #pragma once

+#include <cstdint>
+#include <string>
+
 namespace paddle {
 namespace framework {
@@ -24,8 +25,16 @@ namespace framework {
 // be supported by the current codes. Otherwise, it's a compatibility
 // bug.

+constexpr int MAJOR_COEFF = 1000000;
+constexpr int MINOR_COEFF = 1000;
+constexpr int PATCH_COEFF = 1;
+
 // The program version the current codes generate.
+#ifdef PADDLE_VERSION_INTEGER
+constexpr int64_t kCurProgramVersion = PADDLE_VERSION_INTEGER;
+#else
 constexpr int64_t kCurProgramVersion = 0;
+#endif

 // The program version that was generated by previous or current codes
 // and supported by current codes.
@@ -39,9 +48,12 @@ constexpr uint32_t kCurTensorVersion = 0;
 // and supported by current codes.
 constexpr uint32_t kSupportedTensorVersion[] = {0};

+// WARNING: DO NOT use this interface, it may be discarded.
 bool IsProgramVersionSupported(int64_t version);

+// WARNING: DO NOT use this interface, it may be discarded.
 bool IsTensorVersionSupported(uint32_t version);

+std::string DumpVersion(const int64_t version);
+
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/version_test.cc
@@ -19,12 +19,12 @@ namespace paddle {
 namespace framework {
 TEST(Version, Basic) {
   EXPECT_TRUE(IsProgramVersionSupported(0));
-  EXPECT_FALSE(IsProgramVersionSupported(1));
-  EXPECT_FALSE(IsProgramVersionSupported(-1));
+  EXPECT_TRUE(IsProgramVersionSupported(1));
+  EXPECT_TRUE(IsProgramVersionSupported(-1));

   EXPECT_TRUE(IsTensorVersionSupported(0));
-  EXPECT_FALSE(IsTensorVersionSupported(1));
-  EXPECT_FALSE(IsTensorVersionSupported(-1));
+  EXPECT_TRUE(IsTensorVersionSupported(1));
+  EXPECT_TRUE(IsTensorVersionSupported(-1));
 }
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/inference/api/CMakeLists.txt
@@ -50,7 +50,7 @@ else(WITH_NGRAPH)
   cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
 endif(WITH_NGRAPH)
 cc_library(analysis_predictor SRCS analysis_predictor.cc ${mkldnn_quantizer_src} DEPS paddle_inference_api zero_copy_tensor
-  reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager ${inference_deps})
+  reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager op_compatible_info ${inference_deps})
 cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS
   lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config
   paddle_pass_builder zero_copy_tensor
......
paddle/fluid/inference/api/analysis_predictor.cc
@@ -17,6 +17,7 @@
 #include <algorithm>
 #include <fstream>
 #include <memory>
+#include <set>
 #include <string>
 #include <utility>
 #include <vector>
@@ -27,6 +28,7 @@
 #include "paddle/fluid/framework/naive_executor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/var_type_traits.h"
+#include "paddle/fluid/framework/version.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
 #include "paddle/fluid/inference/api/helper.h"
@@ -142,6 +144,10 @@ bool AnalysisPredictor::PrepareProgram(
   // If config_.ir_optim() is False, parameters is loaded in LoadParameters(),
   // still need to create other persistable variables.
   // So in both case, create persistable variables at first.
+  if (!CheckOperatorCompatible()) {
+    LOG(WARNING) << "WARNING: Results may be DIFF! "
+                    "Please use the same versions of the model and the lib.";
+  }
   executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);

   // if enable_ir_optim_ is false,
@@ -823,6 +829,37 @@ std::string AnalysisPredictor::GetSerializedProgram() const {
   return inference_program_->Proto()->SerializeAsString();
 }

+bool AnalysisPredictor::CheckOperatorCompatible() {
+  if (!inference_program_) {
+    LOG(FATAL) << "Inference program version check failed because the program "
+                  "does not exist.";
+    return false;
+  }
+  bool res = true;
+  op_compatible_map_.ReadFromProto(*inference_program_->OpCompatibleMap());
+  const auto &version = framework::DumpVersion(framework::kCurProgramVersion);
+  LOG(INFO) << "MODEL VERSION: "
+            << framework::DumpVersion(inference_program_->Version());
+  LOG(INFO) << "PREDICTOR VERSION: " << version;
+  std::set<std::string> op_types;
+  for (size_t i = 0; i < inference_program_->Size(); ++i) {
+    const auto &block = inference_program_->Block(i);
+    for (const auto *op : block.AllOps()) {
+      op_types.insert(op->Type());
+    }
+  }
+  for (const auto type : op_types) {
+    auto compatible_type =
+        op_compatible_map_.IsRequireMiniVersion(type, version);
+    if (compatible_type != framework::OpCompatibleType::compatible) {
+      LOG(WARNING) << " - Version incompatible ("
+                   << static_cast<int>(compatible_type) << ") " << type;
+      res = false;
+    }
+  }
+  return res;
+}
+
 // Add SaveOptimModel
 void AnalysisPredictor::SaveOptimModel(const std::string &dir) {
   // save model
......
paddle/fluid/inference/api/analysis_predictor.h
@@ -19,6 +19,7 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/naive_executor.h"
+#include "paddle/fluid/framework/op_compatible_info.h"
 #include "paddle/fluid/inference/analysis/analyzer.h"
 #include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/inference/api/details/reset_tensor_array.h"
@@ -111,6 +112,7 @@ class AnalysisPredictor : public PaddlePredictor {
   // AnalysisPredictor::ZeroRun() now.
   void MkldnnPreSet(const std::vector<PaddleTensor> &inputs);
   void MkldnnPostReset();
+  bool CheckOperatorCompatible();

 #if PADDLE_WITH_TENSORRT
   // When we use Paddle-TRT INT8 engine, we need to generate calibration table
@@ -143,6 +145,7 @@ class AnalysisPredictor : public PaddlePredictor {
   std::shared_ptr<framework::Scope> scope_;
   framework::Scope *sub_scope_{nullptr};
   std::shared_ptr<framework::ProgramDesc> inference_program_;
+  framework::OpCompatibleMap op_compatible_map_;
   std::vector<framework::OpDesc *> feeds_;
   std::map<std::string, size_t> feed_names_;
   // Sorted according to the idx.
......
paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -71,6 +71,9 @@ TEST(AnalysisPredictor, analysis_on) {
   auto _predictor = CreatePaddlePredictor<AnalysisConfig>(config);
   auto* predictor = static_cast<AnalysisPredictor*>(_predictor.get());

+  if (predictor->inference_program_->Version() == 0) {
+    ASSERT_FALSE(predictor->CheckOperatorCompatible());
+  }
   ASSERT_TRUE(predictor->scope_);
   ASSERT_TRUE(predictor->sub_scope_);
   ASSERT_EQ(predictor->scope_->parent(), nullptr);
......
paddle/fluid/pybind/protobuf.cc
@@ -22,6 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/framework/version.h"
 #include "paddle/fluid/pybind/pybind_boost_headers.h"
@@ -69,9 +70,13 @@ void BindProgramDesc(pybind11::module *m) {
                        "Fail to parse ProgramDesc from string. This could "
                        "be a bug of Paddle.");
       })
-      .def("_version", [](pd::ProgramDesc &self) -> int64_t {
-        return self.Proto()->version().version();
-      });
+      .def("_set_version",
+           [](pd::ProgramDesc &self, int64_t version) {
+             return self.SetVersion(version);
+           },
+           pybind11::arg("version") = pd::kCurProgramVersion)
+      .def("_version",
+           [](pd::ProgramDesc &self) -> int64_t { return self.Version(); });
 }

 void BindBlockDesc(pybind11::module *m) {
......
paddle/fluid/pybind/pybind.cc
@@ -32,6 +32,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
+#include "paddle/fluid/framework/op_compatible_info.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/parallel_executor.h"
@@ -171,6 +172,12 @@ PYBIND11_MODULE(core_noavx, m) {
   m.def("set_num_threads", &platform::SetNumThreads);

+  m.def("save_op_compatible_info", [](framework::ProgramDesc &desc) {
+    framework::OpCompatibleMap op_compatible_map;
+    op_compatible_map.InitOpCompatibleMap();
+    return op_compatible_map.ConvertToProto(desc.OpCompatibleMap());
+  });
+
   m.def(
       "_append_python_callable_object_and_return_id",
       [](py::object py_obj) -> size_t {
......
python/paddle/fluid/io.py
@@ -1103,6 +1103,8 @@ def save_inference_model(dirname,
         prepend_feed_ops(main_program, feeded_var_names)
         append_fetch_ops(main_program, fetch_var_names)

+        main_program.desc._set_version()
+        paddle.fluid.core.save_op_compatible_info(main_program.desc)
         with open(model_basename, "wb") as f:
             f.write(main_program.desc.serialize_to_string())
     else:
......