未验证 提交 85b398f1 编写于 作者: H hong 提交者: GitHub

Add op compatible information (#19910)

* add op compatible infomation; test=develop

* add enum type

* add enum type; test=develop
上级 3f021781
...@@ -223,6 +223,9 @@ endif (NOT WIN32) ...@@ -223,6 +223,9 @@ endif (NOT WIN32)
cc_library(dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack) cc_library(dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack)
cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog) cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
cc_library(op_compatible_info SRCS op_compatible_info DEPS string_helper)
cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatible_info string_helper glog)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
COMMAND git rev-parse --abbrev-ref HEAD COMMAND git rev-parse --abbrev-ref HEAD
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_compatible_info.h"
#include <iostream>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
inline std::vector<int> ConvertStr2Int(const std::string& str_text) {
auto vec_text = string::split_string<std::string>(str_text, ".");
PADDLE_ENFORCE((vec_text.size() == 2 || vec_text.size() == 3),
"Input[%s] is not a right version format [1.6 or 1.6.0]",
str_text);
std::vector<int> vec_res;
vec_res.reserve(3);
for (auto& val : vec_text) {
vec_res.emplace_back(atoi(val.c_str()));
}
if (vec_res.size() == 2) {
vec_res.emplace_back(0);
}
return vec_res;
}
/* first version >= second version return true */
inline bool CompareVersion(const std::string& str_first,
const std::string& str_second) {
auto vec_first_version = ConvertStr2Int(str_first);
auto vec_second_version = ConvertStr2Int(str_second);
// first version id
PADDLE_ENFORCE_EQ(
vec_first_version.size(), vec_second_version.size(),
"version information size not equal, first is [%d] second is [%d]",
vec_first_version.size(), vec_second_version.size());
for (size_t i = 0; i < vec_first_version.size() - 1; ++i) {
if (vec_first_version[i] != vec_second_version[i]) {
return vec_first_version[i] > vec_second_version[i];
}
}
return vec_first_version[2] >= vec_second_version[2];
}
void OpCompatibleMap::InitOpCompatibleMap() {
op_compatible_map_["sequence_pad"] = {"1.6.0", OpCompatibleType::DEFIN_NOT};
op_compatible_map_["sequence_unpad"] = {"1.6.0", OpCompatibleType::DEFIN_NOT};
op_compatible_map_["reshape2"] = {"1.6.0", OpCompatibleType::DEFIN_NOT};
op_compatible_map_["slice"] = {"1.6.0", OpCompatibleType::possible};
op_compatible_map_["expand"] = {"1.6.0", OpCompatibleType::possible};
op_compatible_map_["layer_norm"] = {"1.6.0", OpCompatibleType::bug_fix};
}
CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) {
auto it = op_compatible_map_.find(op_name);
if (it != op_compatible_map_.end()) {
return it->second;
} else {
return {default_required_version_, OpCompatibleType::DEFIN_NOT};
}
}
OpCompatibleType OpCompatibleMap::IsRequireMiniVersion(
std::string op_name, std::string str_current_version) {
auto it = op_compatible_map_.find(op_name);
if (it != op_compatible_map_.end()) {
if (CompareVersion(str_current_version, it->second.required_version_)) {
return OpCompatibleType::compatible;
} else {
return it->second.compatible_type_;
}
} else {
if (CompareVersion(str_current_version, default_required_version_)) {
return OpCompatibleType::compatible;
} else {
return OpCompatibleType::DEFIN_NOT;
}
}
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <string>
#pragma once
namespace paddle {
namespace framework {
enum class OpCompatibleType {
compatible = 0, // support previous version
DEFIN_NOT = 1, // definitely can't support previous version
possible = 2, // possible can support previous version, not sure
bug_fix = 3, // bug fix, can't support previous version
precision_change = 4 // precision change, may cause difference
};
struct CompatibleInfo {
CompatibleInfo(std::string required_version, OpCompatibleType compatible_type)
: required_version_(required_version),
compatible_type_(compatible_type) {}
CompatibleInfo() {}
// op required version, previous version not support
std::string required_version_;
OpCompatibleType compatible_type_;
};
class OpCompatibleMap {
public:
OpCompatibleMap() : default_required_version_("1.5.0") {}
void InitOpCompatibleMap();
CompatibleInfo GetOpCompatibleInfo(std::string op_name);
/* IsRequireMiniVersion
* return type OpCompatibleType */
OpCompatibleType IsRequireMiniVersion(std::string op_name,
std::string current_version);
void SerializeToStr(std::string& str) {} /* NOLINT */
void UnSerialize(const std::string& str) {}
const std::string& GetDefaultRequiredVersion() {
return default_required_version_;
}
private:
std::map<std::string, CompatibleInfo> op_compatible_map_;
std::string default_required_version_;
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_compatible_info.h"
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
TEST(test_op_compatible_info, test_op_compatible) {
auto comp_map = OpCompatibleMap();
comp_map.InitOpCompatibleMap();
auto default_req_version = comp_map.GetDefaultRequiredVersion();
auto seq_pad = comp_map.GetOpCompatibleInfo("sequence_pad");
auto reshape = comp_map.GetOpCompatibleInfo("reshape");
auto layer_norm = comp_map.GetOpCompatibleInfo("layer_norm");
auto deafult_info = comp_map.GetOpCompatibleInfo("layer_xx");
auto comp_1 = comp_map.IsRequireMiniVersion("sequence_pad", "1.5.0");
ASSERT_EQ(comp_1, OpCompatibleType::DEFIN_NOT);
auto comp_2 = comp_map.IsRequireMiniVersion("sequence_pad", "1.6.0");
ASSERT_EQ(comp_2, OpCompatibleType::compatible);
auto comp_3 = comp_map.IsRequireMiniVersion("sequence_pad", "1.6.1");
ASSERT_EQ(comp_3, OpCompatibleType::compatible);
auto comp_6 = comp_map.IsRequireMiniVersion("sequence_pad", "1.7.0");
ASSERT_EQ(comp_6, OpCompatibleType::compatible);
auto comp_7 = comp_map.IsRequireMiniVersion("sequence_pad", "0.7.0");
ASSERT_EQ(comp_7, OpCompatibleType::DEFIN_NOT);
auto comp_8 = comp_map.IsRequireMiniVersion("sequence_pad", "2.0.0");
ASSERT_EQ(comp_8, OpCompatibleType::compatible);
ASSERT_EQ(comp_map.IsRequireMiniVersion("unkop", "2.0.0"),
OpCompatibleType::compatible);
ASSERT_EQ(comp_map.IsRequireMiniVersion("unkop", "0.7.0"),
OpCompatibleType::DEFIN_NOT);
ASSERT_EQ(comp_map.IsRequireMiniVersion("slice", "0.7.0"),
OpCompatibleType::possible);
ASSERT_EQ(comp_map.IsRequireMiniVersion("slice", "1.6.0"),
OpCompatibleType::compatible);
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册