未验证 提交 21a63f6f 编写于 作者: 石晓伟 提交者: GitHub

enhance the op_version_registry, test=develop (#28347)

* enhance the op_version_registry, test=develop

* add unittests, test=develop

* enhance the op_version_registry, test=develop

* fix bugs, test=develop

* revert pybind_boost_headers.h, test=develop

* fix a attribute bug, test=develop
上级 c1c3e217
...@@ -23,9 +23,9 @@ function(pass_library TARGET DEST) ...@@ -23,9 +23,9 @@ function(pass_library TARGET DEST)
cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(pass_library_DIR) if(pass_library_DIR)
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS}) cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
else() else()
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS}) cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
endif() endif()
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically. # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
......
...@@ -13,3 +13,75 @@ See the License for the specific language governing permissions and ...@@ -13,3 +13,75 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace compatible {
namespace {
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
}
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kNewInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::NewOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kNewOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}
OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
const std::string& remark) {
infos_.emplace_back(new_update<OpUpdateType::kBugfixWithBehaviorChanged>(
OpBugfixInfo(remark)));
return std::move(*this);
}
OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(),
platform::errors::AlreadyExists(
"'%s' is registered in operator version more than once.", op_type));
op_version_map_.insert(
std::pair<std::string, OpVersion>{op_type, OpVersion()});
return op_version_map_[op_type];
}
uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const {
PADDLE_ENFORCE_NE(
op_version_map_.count(op_type), 0,
platform::errors::InvalidArgument(
"The version of operator type %s has not been registered.", op_type));
return op_version_map_.find(op_type)->second.version_id();
}
// Provide a fake registration item for pybind testing.
#include "paddle/fluid/framework/op_version_registry.inl"
} // namespace compatible
} // namespace framework
} // namespace paddle
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <boost/any.hpp> #include <boost/variant.hpp>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_version_proto.h" #include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -29,160 +29,173 @@ namespace paddle { ...@@ -29,160 +29,173 @@ namespace paddle {
namespace framework { namespace framework {
namespace compatible { namespace compatible {
struct OpUpdateRecord { using OpAttrVariantT =
enum class Type { boost::variant<bool, /* AttrType::BOOL */
kInvalid = 0, float, /* AttrType::FLOAT */
kModifyAttr, int32_t, /* AttrType::INT */
kNewAttr, int64_t, /* AttrType::LONG*/
kNewInput, std::string, /* AttrType::STRING */
kNewOutput, std::vector<bool>, /* AttrType::BOOLS */
kBugfixWithBehaviorChanged, std::vector<float>, /* AttrType::FLOATS */
}; std::vector<int32_t>, /* AttrType::INTS */
Type type_; std::vector<int64_t>, /* AttrType::LONGS */
std::string remark_; std::vector<std::string> /* AttrType::STRINGS */
>;
struct OpUpdateInfo {
virtual ~OpUpdateInfo() = default;
}; };
struct ModifyAttr : OpUpdateRecord { struct OpAttrInfo : OpUpdateInfo {
ModifyAttr(const std::string& name, const std::string& remark, OpAttrInfo(const std::string& name, const std::string& remark,
const boost::any& default_value) const OpAttrVariantT& default_value)
: OpUpdateRecord({Type::kModifyAttr, remark}), : name_{name}, default_value_{default_value}, remark_{remark} {}
name_(name),
default_value_(default_value) { const std::string& name() const { return name_; }
// TODO(Shixiaowei02): Check the data type with proto::OpDesc. const OpAttrVariantT& default_value() const { return default_value_; }
} const std::string& remark() const { return remark_; }
private: private:
std::string name_; std::string name_;
boost::any default_value_; OpAttrVariantT default_value_;
std::string remark_;
}; };
struct NewAttr : OpUpdateRecord { struct OpInputOutputInfo : OpUpdateInfo {
NewAttr(const std::string& name, const std::string& remark, OpInputOutputInfo(const std::string& name, const std::string& remark)
const boost::any& default_value) : name_{name}, remark_{remark} {}
: OpUpdateRecord({Type::kNewAttr, remark}),
name_(name), const std::string& name() const { return name_; }
default_value_(default_value) {} const std::string& remark() const { return remark_; }
private: private:
std::string name_; std::string name_;
boost::any default_value_; std::string remark_;
}; };
struct NewInput : OpUpdateRecord { struct OpBugfixInfo : OpUpdateInfo {
NewInput(const std::string& name, const std::string& remark) explicit OpBugfixInfo(const std::string& remark) : remark_{remark} {}
: OpUpdateRecord({Type::kNewInput, remark}), name_(name) {} const std::string& remark() const { return remark_; }
private: private:
std::string name_; std::string remark_;
}; };
struct NewOutput : OpUpdateRecord { enum class OpUpdateType {
NewOutput(const std::string& name, const std::string& remark) kInvalid = 0,
: OpUpdateRecord({Type::kNewOutput, remark}), name_(name) {} kModifyAttr,
kNewAttr,
kNewInput,
kNewOutput,
kBugfixWithBehaviorChanged,
};
private: class OpUpdateBase {
std::string name_; public:
virtual const OpUpdateInfo* info() const = 0;
virtual OpUpdateType type() const = 0;
virtual ~OpUpdateBase() = default;
}; };
struct BugfixWithBehaviorChanged : OpUpdateRecord { template <typename InfoType, OpUpdateType type__>
explicit BugfixWithBehaviorChanged(const std::string& remark) class OpUpdate : public OpUpdateBase {
: OpUpdateRecord({Type::kBugfixWithBehaviorChanged, remark}) {} public:
explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {}
const OpUpdateInfo* info() const override { return &info_; }
OpUpdateType type() const override { return type_; }
private:
InfoType info_;
OpUpdateType type_;
}; };
class OpVersionDesc { class OpVersionDesc {
public: public:
OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark, OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
boost::any default_value) { const OpAttrVariantT& default_value);
infos_.push_back(std::shared_ptr<OpUpdateRecord>( OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
new compatible::ModifyAttr(name, remark, default_value))); const OpAttrVariantT& default_value);
return *this; OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
const std::vector<std::unique_ptr<OpUpdateBase>>& infos() const {
return infos_;
} }
OpVersionDesc& NewAttr(const std::string& name, const std::string& remark, OpVersionDesc() = default;
boost::any default_value) { OpVersionDesc(OpVersionDesc&&) = default;
infos_.push_back(std::shared_ptr<OpUpdateRecord>( OpVersionDesc& operator=(OpVersionDesc&&) = default;
new compatible::NewAttr(name, remark, default_value)));
return *this;
}
OpVersionDesc& NewInput(const std::string& name, const std::string& remark) { private:
infos_.push_back(std::shared_ptr<OpUpdateRecord>( std::vector<std::unique_ptr<OpUpdateBase>> infos_;
new compatible::NewInput(name, remark))); };
return *this;
}
OpVersionDesc& NewOutput(const std::string& name, const std::string& remark) { class OpCheckpoint {
infos_.push_back(std::shared_ptr<OpUpdateRecord>( public:
new compatible::NewOutput(name, remark))); OpCheckpoint(const std::string& note, OpVersionDesc&& op_version_desc)
return *this; : note_{note},
} op_version_desc_{std::forward<OpVersionDesc>(op_version_desc)} {}
const std::string& note() const { return note_; }
const OpVersionDesc& version_desc() { return op_version_desc_; }
OpVersionDesc& BugfixWithBehaviorChanged(const std::string& remark) { OpCheckpoint() = default;
infos_.push_back(std::shared_ptr<OpUpdateRecord>( OpCheckpoint(OpCheckpoint&&) = default;
new compatible::BugfixWithBehaviorChanged(remark))); OpCheckpoint& operator=(OpCheckpoint&&) = default;
return *this;
}
private: private:
std::vector<std::shared_ptr<OpUpdateRecord>> infos_; std::string note_;
OpVersionDesc op_version_desc_;
}; };
class OpVersion { class OpVersion {
public: public:
OpVersion& AddCheckpoint(const std::string& note, OpVersion& AddCheckpoint(const std::string& note,
const OpVersionDesc& op_version_desc) { OpVersionDesc&& op_version_desc) {
checkpoints_.push_back(Checkpoint({note, op_version_desc})); checkpoints_.emplace_back(OpCheckpoint{note, std::move(op_version_desc)});
return *this; return *this;
} }
uint32_t GetVersionID() const { uint32_t version_id() const {
return static_cast<uint32_t>(checkpoints_.size()); return static_cast<uint32_t>(checkpoints_.size());
} }
const std::vector<OpCheckpoint>& checkpoints() const { return checkpoints_; }
OpVersion() = default;
OpVersion(OpVersion&&) = default;
OpVersion& operator=(OpVersion&&) = default;
private: private:
struct Checkpoint { std::vector<OpCheckpoint> checkpoints_;
std::string note_;
OpVersionDesc op_version_desc_;
};
std::vector<Checkpoint> checkpoints_;
}; };
class OpVersionRegistrar { class OpVersionRegistrar {
public: public:
OpVersionRegistrar() = default;
static OpVersionRegistrar& GetInstance() { static OpVersionRegistrar& GetInstance() {
static OpVersionRegistrar instance; static OpVersionRegistrar instance;
return instance; return instance;
} }
OpVersion& Register(const std::string& op_type) { OpVersion& Register(const std::string& op_type);
PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(),
platform::errors::AlreadyExists(
"'%s' is registered in operator version more than once.", op_type));
op_version_map_.insert({op_type, OpVersion()});
return op_version_map_[op_type];
}
const std::unordered_map<std::string, OpVersion>& GetVersionMap() { const std::unordered_map<std::string, OpVersion>& GetVersionMap() {
return op_version_map_; return op_version_map_;
} }
uint32_t GetVersionID(const std::string& op_type) const { bool Has(const std::string& op_type) const {
auto it = op_version_map_.find(op_type); return op_version_map_.count(op_type);
if (it == op_version_map_.end()) {
return 0;
}
return it->second.GetVersionID();
} }
uint32_t version_id(const std::string& op_type) const;
private: private:
std::unordered_map<std::string, OpVersion> op_version_map_; std::unordered_map<std::string, OpVersion> op_version_map_;
OpVersionRegistrar() = default;
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
}; };
inline const std::unordered_map<std::string, OpVersion>& get_op_version_map() {
return OpVersionRegistrar::GetInstance().GetVersionMap();
}
inline void SaveOpVersions( inline void SaveOpVersions(
const std::unordered_map<std::string, OpVersion>& src, const std::unordered_map<std::string, OpVersion>& src,
pb::OpVersionMap* dst) { pb::OpVersionMap* dst) {
for (const auto& pair : src) { for (const auto& pair : src) {
(*dst)[pair.first].SetVersionID(pair.second.GetVersionID()); (*dst)[pair.first].SetVersionID(pair.second.version_id());
} }
} }
...@@ -192,21 +205,24 @@ class OpVersionComparator { ...@@ -192,21 +205,24 @@ class OpVersionComparator {
virtual ~OpVersionComparator() = default; virtual ~OpVersionComparator() = default;
}; };
#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \ #define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \
class OpVersion##cmp_name##Comparator : public OpVersionComparator { \ class OpVersion##cmp_name##Comparator : public OpVersionComparator { \
public: \ public: \
explicit OpVersion##cmp_name##Comparator(const std::string op_name, \ explicit OpVersion##cmp_name##Comparator(const std::string op_name, \
uint32_t target_version) \ uint32_t target_version) \
: op_name_(op_name), target_version_(target_version) {} \ : op_name_(op_name), target_version_(target_version) {} \
virtual bool operator()() { \ virtual bool operator()() { \
return OpVersionRegistrar::GetInstance().GetVersionID(op_name_) \ uint32_t version_id = 0; \
cmp_math target_version_; \ if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \
} \ version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \
virtual ~OpVersion##cmp_name##Comparator() {} \ } \
\ return version_id cmp_math target_version_; \
private: \ } \
std::string op_name_; \ virtual ~OpVersion##cmp_name##Comparator() {} \
uint32_t target_version_; \ \
private: \
std::string op_name_; \
uint32_t target_version_; \
}; };
ADD_OP_VERSION_COMPARATOR(LE, <=); ADD_OP_VERSION_COMPARATOR(LE, <=);
...@@ -310,7 +326,7 @@ class PassVersionCheckerRegistrar { ...@@ -310,7 +326,7 @@ class PassVersionCheckerRegistrar {
} // namespace paddle } // namespace paddle
#define REGISTER_OP_VERSION(op_type) \ #define REGISTER_OP_VERSION(op_type) \
static paddle::framework::compatible::OpVersion \ UNUSED static paddle::framework::compatible::OpVersion& \
RegisterOpVersion__##op_type = \ RegisterOpVersion__##op_type = \
paddle::framework::compatible::OpVersionRegistrar::GetInstance() \ paddle::framework::compatible::OpVersionRegistrar::GetInstance() \
.Register(#op_type) .Register(#op_type)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
REGISTER_OP_VERSION(for_pybind_test__)
.AddCheckpoint("Note 0", framework::compatible::OpVersionDesc()
.BugfixWithBehaviorChanged(
"BugfixWithBehaviorChanged Remark"))
.AddCheckpoint("Note 1", framework::compatible::OpVersionDesc()
.ModifyAttr("BOOL", "bool", true)
.ModifyAttr("FLOAT", "float", 1.23f)
.ModifyAttr("INT", "int32", -1)
.ModifyAttr("STRING", "std::string",
std::string{"hello"}))
.AddCheckpoint("Note 2",
framework::compatible::OpVersionDesc()
.ModifyAttr("BOOLS", "std::vector<bool>",
std::vector<bool>{true, false})
.ModifyAttr("FLOATS", "std::vector<float>",
std::vector<float>{2.56f, 1.28f})
.ModifyAttr("INTS", "std::vector<int32>",
std::vector<int32_t>{10, 100})
.NewAttr("LONGS", "std::vector<int64>",
std::vector<int64_t>{10000001, -10000001}))
.AddCheckpoint("Note 3", framework::compatible::OpVersionDesc()
.NewAttr("STRINGS", "std::vector<std::string>",
std::vector<std::string>{"str1", "str2"})
.ModifyAttr("LONG", "int64", static_cast<int64_t>(10000001))
.NewInput("NewInput", "NewInput_")
.NewOutput("NewOutput", "NewOutput_")
.BugfixWithBehaviorChanged(
"BugfixWithBehaviorChanged_"));
...@@ -21,7 +21,7 @@ namespace framework { ...@@ -21,7 +21,7 @@ namespace framework {
namespace compatible { namespace compatible {
TEST(test_operator_version, test_operator_version) { TEST(test_operator_version, test_operator_version) {
REGISTER_OP_VERSION(test__) REGISTER_OP_VERSION(op_name__)
.AddCheckpoint( .AddCheckpoint(
R"ROC(Fix the bug of reshape op, support the case of axis < 0)ROC", R"ROC(Fix the bug of reshape op, support the case of axis < 0)ROC",
framework::compatible::OpVersionDesc().BugfixWithBehaviorChanged( framework::compatible::OpVersionDesc().BugfixWithBehaviorChanged(
...@@ -56,6 +56,7 @@ TEST(test_operator_version, test_operator_version) { ...@@ -56,6 +56,7 @@ TEST(test_operator_version, test_operator_version) {
} }
TEST(test_pass_op_version_checker, test_pass_op_version_checker) { TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
const std::string fake_op_name{"op_name__"};
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass")); "no_bind_pass"));
...@@ -90,7 +91,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { ...@@ -90,7 +91,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass4) REGISTER_PASS_CAPABILITY(test_pass4)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 5) .GE(fake_op_name, 5)
.EQ("fc", 0)); .EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass4")); "test_pass4"));
...@@ -98,7 +99,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { ...@@ -98,7 +99,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass5) REGISTER_PASS_CAPABILITY(test_pass5)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 4) .GE(fake_op_name, 4)
.EQ("fc", 0)); .EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass5")); "test_pass5"));
...@@ -106,7 +107,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { ...@@ -106,7 +107,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass6) REGISTER_PASS_CAPABILITY(test_pass6)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("test__", 4) .EQ(fake_op_name, 4)
.EQ("fc", 0)); .EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass6")); "test_pass6"));
...@@ -114,7 +115,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) { ...@@ -114,7 +115,7 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass7) REGISTER_PASS_CAPABILITY(test_pass7)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.NE("test__", 4) .NE(fake_op_name, 4)
.EQ("fc", 0)); .EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible( ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass7")); "test_pass7"));
......
...@@ -104,6 +104,7 @@ endif() ...@@ -104,6 +104,7 @@ endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} op_version_registry)
# FIXME(typhoonzero): operator deps may not needed. # FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
......
set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune
feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator) gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry)
if (WITH_NCCL) if (WITH_NCCL)
set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper)
......
...@@ -13,26 +13,136 @@ ...@@ -13,26 +13,136 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/pybind/compatible.h" #include "paddle/fluid/pybind/compatible.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace py = pybind11; namespace py = pybind11;
using paddle::framework::compatible::PassVersionCheckerRegistrar; using paddle::framework::compatible::OpAttrVariantT;
using paddle::framework::compatible::OpUpdateInfo;
using paddle::framework::compatible::OpAttrInfo;
using paddle::framework::compatible::OpInputOutputInfo;
using paddle::framework::compatible::OpBugfixInfo;
using paddle::framework::compatible::OpUpdateType;
using paddle::framework::compatible::OpUpdateBase;
using paddle::framework::compatible::OpVersionDesc;
using paddle::framework::compatible::OpCheckpoint;
using paddle::framework::compatible::OpVersion;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
void BindCompatible(py::module* m) { namespace {
using paddle::framework::compatible::PassVersionCheckerRegistrar;
void BindPassVersionChecker(py::module *m) {
py::class_<PassVersionCheckerRegistrar>(*m, "PassVersionChecker") py::class_<PassVersionCheckerRegistrar>(*m, "PassVersionChecker")
.def_static("IsCompatible", [](const std::string& name) -> bool { .def_static("IsCompatible", [](const std::string &name) -> bool {
auto instance = PassVersionCheckerRegistrar::GetInstance(); auto instance = PassVersionCheckerRegistrar::GetInstance();
return instance.IsPassCompatible(name); return instance.IsPassCompatible(name);
}); });
} }
void BindPassCompatible(py::module *m) { BindPassVersionChecker(m); }
void BindOpUpdateInfo(py::module *m) {
py::class_<OpUpdateInfo>(*m, "OpUpdateInfo").def(py::init<>());
}
void BindOpAttrInfo(py::module *m) {
py::class_<OpAttrInfo, OpUpdateInfo>(*m, "OpAttrInfo")
.def(py::init<const std::string &, const std::string &,
const OpAttrVariantT &>())
.def(py::init<const OpAttrInfo &>())
.def("name", &OpAttrInfo::name)
.def("default_value", &OpAttrInfo::default_value)
.def("remark", &OpAttrInfo::remark);
}
void BindOpInputOutputInfo(py::module *m) {
py::class_<OpInputOutputInfo, OpUpdateInfo>(*m, "OpInputOutputInfo")
.def(py::init<const std::string &, const std::string &>())
.def(py::init<const OpInputOutputInfo &>())
.def("name", &OpInputOutputInfo::name)
.def("remark", &OpInputOutputInfo::remark);
}
void BindOpBugfixInfo(py::module *m) {
py::class_<OpBugfixInfo, OpUpdateInfo>(*m, "OpBugfixInfo")
.def(py::init<const std::string &>())
.def(py::init<const OpBugfixInfo &>())
.def("remark", &OpBugfixInfo::remark);
}
void BindOpCompatible(py::module *m) {
BindOpUpdateInfo(m);
BindOpAttrInfo(m);
BindOpInputOutputInfo(m);
BindOpBugfixInfo(m);
}
void BindOpUpdateType(py::module *m) {
py::enum_<OpUpdateType>(*m, "OpUpdateType")
.value("kInvalid", OpUpdateType::kInvalid)
.value("kModifyAttr", OpUpdateType::kModifyAttr)
.value("kNewAttr", OpUpdateType::kNewAttr)
.value("kNewInput", OpUpdateType::kNewInput)
.value("kNewOutput", OpUpdateType::kNewOutput)
.value("kBugfixWithBehaviorChanged",
OpUpdateType::kBugfixWithBehaviorChanged);
}
void BindOpUpdateBase(py::module *m) {
py::class_<OpUpdateBase>(*m, "OpUpdateBase")
.def("info", [](const OpUpdateBase &obj) { return obj.info(); },
py::return_value_policy::reference)
.def("type", &OpUpdateBase::type);
}
void BindOpVersionDesc(py::module *m) {
py::class_<OpVersionDesc>(*m, "OpVersionDesc")
// Pybind11 does not yet support the transfer of `const
// std::vector<std::unique_ptr<T>>&` type objects.
.def("infos", [](const OpVersionDesc &obj) {
auto pylist = py::list();
for (const auto &ptr : obj.infos()) {
auto pyobj = py::cast(*ptr, py::return_value_policy::reference);
pylist.append(pyobj);
}
return pylist;
});
}
void BindOpCheckpoint(py::module *m) {
py::class_<OpCheckpoint>(*m, "OpCheckpoint")
.def("note", &OpCheckpoint::note, py::return_value_policy::reference)
.def("version_desc", &OpCheckpoint::version_desc,
py::return_value_policy::reference);
}
void BindOpVersion(py::module *m) {
py::class_<OpVersion>(*m, "OpVersion")
.def("version_id", &OpVersion::version_id,
py::return_value_policy::reference)
.def("checkpoints", &OpVersion::checkpoints,
py::return_value_policy::reference);
// At least pybind v2.3.0 is required because of bug #1603 of pybind11.
m->def("get_op_version_map", &framework::compatible::get_op_version_map,
py::return_value_policy::reference);
}
} // namespace
void BindCompatible(py::module *m) {
BindPassCompatible(m);
BindOpCompatible(m);
BindOpUpdateType(m);
BindOpUpdateBase(m);
BindOpVersionDesc(m);
BindOpCheckpoint(m);
BindOpVersion(m);
}
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.utils as utils
import paddle.fluid as fluid
class OpLastCheckpointCheckerTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(OpLastCheckpointCheckerTest, self).__init__(methodName)
self.checker = utils.OpLastCheckpointChecker()
self.fake_op = 'for_pybind_test__'
def test_op_attr_info(self):
update_type = fluid.core.OpUpdateType.kNewAttr
info_list = self.checker.filter_updates(self.fake_op, update_type,
'STRINGS')
self.assertTrue(info_list)
self.assertEqual(info_list[0].name(), 'STRINGS')
self.assertEqual(info_list[0].default_value(), ['str1', 'str2'])
self.assertEqual(info_list[0].remark(), 'std::vector<std::string>')
def test_op_input_output_info(self):
update_type = fluid.core.OpUpdateType.kNewInput
info_list = self.checker.filter_updates(self.fake_op, update_type,
'NewInput')
self.assertTrue(info_list)
self.assertEqual(info_list[0].name(), 'NewInput')
self.assertEqual(info_list[0].remark(), 'NewInput_')
def test_op_bug_fix_info(self):
update_type = fluid.core.OpUpdateType.kBugfixWithBehaviorChanged
info_list = self.checker.filter_updates(self.fake_op, update_type)
self.assertTrue(info_list)
self.assertEqual(info_list[0].remark(), 'BugfixWithBehaviorChanged_')
class OpVersionTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
super(OpVersionTest, self).__init__(methodName)
self.vmap = fluid.core.get_op_version_map()
self.fake_op = 'for_pybind_test__'
def test_checkpoints(self):
version_id = self.vmap[self.fake_op].version_id()
checkpoints = self.vmap[self.fake_op].checkpoints()
self.assertEqual(version_id, 4)
self.assertEqual(len(checkpoints), 4)
self.assertEqual(checkpoints[2].note(), 'Note 2')
desc_1 = checkpoints[1].version_desc().infos()
self.assertEqual(desc_1[0].info().default_value(), True)
self.assertAlmostEqual(desc_1[1].info().default_value(), 1.23, 2)
self.assertEqual(desc_1[2].info().default_value(), -1)
self.assertEqual(desc_1[3].info().default_value(), 'hello')
desc_2 = checkpoints[2].version_desc().infos()
self.assertEqual(desc_2[0].info().default_value(), [True, False])
true_l = [2.56, 1.28]
self.assertEqual(len(true_l), len(desc_2[1].info().default_value()))
for i in range(len(true_l)):
self.assertAlmostEqual(desc_2[1].info().default_value()[i],
true_l[i], 2)
self.assertEqual(desc_2[2].info().default_value(), [10, 100])
self.assertEqual(desc_2[3].info().default_value(),
[10000001, -10000001])
if __name__ == '__main__':
unittest.main()
...@@ -17,6 +17,7 @@ from .profiler import Profiler ...@@ -17,6 +17,7 @@ from .profiler import Profiler
from .profiler import get_profiler from .profiler import get_profiler
from .deprecated import deprecated from .deprecated import deprecated
from .lazy_import import try_import from .lazy_import import try_import
from .op_version import OpLastCheckpointChecker
from .install_check import run_check from .install_check import run_check
from ..fluid.framework import unique_name from ..fluid.framework import unique_name
from ..fluid.framework import load_op_library from ..fluid.framework import load_op_library
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..fluid import core
__all__ = ['OpLastCheckpointChecker']
def Singleton(cls):
_instance = {}
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
return _singleton
class OpUpdateInfoHelper(object):
def __init__(self, info):
self._info = info
def verify_key_value(self, name=''):
result = False
key_funcs = {
core.OpAttrInfo: 'name',
core.OpInputOutputInfo: 'name',
}
if name == '':
result = True
elif type(self._info) in key_funcs:
if getattr(self._info, key_funcs[type(self._info)])() == name:
result = True
return result
@Singleton
class OpLastCheckpointChecker(object):
def __init__(self):
self.raw_version_map = core.get_op_version_map()
self.checkpoints_map = {}
self._construct_map()
def _construct_map(self):
for op_name in self.raw_version_map:
last_checkpoint = self.raw_version_map[op_name].checkpoints()[-1]
infos = last_checkpoint.version_desc().infos()
self.checkpoints_map[op_name] = infos
def filter_updates(self, op_name, type=core.OpUpdateType.kInvalid, key=''):
updates = []
if op_name in self.checkpoints_map:
for update in self.checkpoints_map[op_name]:
if (update.type() == type) or (
type == core.OpUpdateType.kInvalid):
if OpUpdateInfoHelper(update.info()).verify_key_value(key):
updates.append(update.info())
return updates
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册