diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8c49a5f6de7706eadd6afc1e6e55701dbfe9304d..10d2c2c6c9172ef2025d72e1723d74c8423aed1d 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -122,6 +122,10 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor cc_library(attribute SRCS attribute.cc DEPS framework_proto boost) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc device_context) + +cc_library(op_version_registry SRCS op_version_registry.cc DEPS framework_proto boost) +cc_test(op_version_registry_test SRCS op_version_registry_test.cc DEPS op_version_registry) + cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(no_need_buffer_vars_inference SRCS no_need_buffer_vars_inference.cc DEPS attribute device_context) diff --git a/paddle/fluid/framework/op_version_registry.cc b/paddle/fluid/framework/op_version_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..11b7224e683402264573019e1541c5645a3a7514 --- /dev/null +++ b/paddle/fluid/framework/op_version_registry.cc @@ -0,0 +1,15 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_version_registry.h" diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..214b569fb58f444d668d599ceff0383334027ec9 --- /dev/null +++ b/paddle/fluid/framework/op_version_registry.h @@ -0,0 +1,123 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace compatible { + +struct OpUpdateRecord { + enum class Type { kInvalid = 0, kModifyAttr, kNewAttr }; + Type type_; + std::string remark_; +}; + +struct ModifyAttr : OpUpdateRecord { + ModifyAttr(const std::string& name, const std::string& remark, + boost::any default_value) + : OpUpdateRecord({Type::kModifyAttr, remark}), + name_(name), + default_value_(default_value) { + // TODO(Shixiaowei02): Check the data type with proto::OpDesc. + } + + private: + std::string name_; + boost::any default_value_; +}; +struct NewAttr : OpUpdateRecord { + NewAttr(const std::string& name, const std::string& remark) + : OpUpdateRecord({Type::kNewAttr, remark}), name_(name) {} + + private: + std::string name_; +}; + +class OpVersionDesc { + public: + OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark, + boost::any default_value) { + infos_.push_back(std::shared_ptr( + new compatible::ModifyAttr(name, remark, default_value))); + return *this; + } + + OpVersionDesc& NewAttr(const std::string& name, const std::string& remark) { + infos_.push_back( + std::shared_ptr(new compatible::NewAttr(name, remark))); + return *this; + } + + private: + std::vector> infos_; +}; + +class OpVersion { + public: + OpVersion& AddCheckpoint(const std::string& note, + const OpVersionDesc& op_version_desc) { + checkpoints_.push_back(Checkpoint({note, op_version_desc})); + return *this; + } + + private: + struct Checkpoint { + std::string note_; + OpVersionDesc op_version_desc_; + }; + std::vector checkpoints_; +}; + +class OpVersionRegistrar { + public: + static OpVersionRegistrar& GetInstance() { + static OpVersionRegistrar instance; + return instance; + } + OpVersion& Register(const std::string& op_type) { + if (op_version_map_.find(op_type) != op_version_map_.end()) { + PADDLE_THROW("'%s' is registered in operator version more than once.", + op_type); + } + op_version_map_.insert({op_type, OpVersion()}); + return op_version_map_[op_type]; + } + + private: + std::unordered_map op_version_map_; + + OpVersionRegistrar() = default; + OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete; +}; + +} // namespace compatible +} // namespace framework +} // namespace paddle + +#define REGISTER_OP_VERSION(op_type) \ + static paddle::framework::compatible::OpVersion \ + RegisterOpVersion__##op_type = \ + paddle::framework::compatible::OpVersionRegistrar::GetInstance() \ + .Register(#op_type) diff --git a/paddle/fluid/framework/op_version_registry_test.cc b/paddle/fluid/framework/op_version_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb7aba4ab79482342a29b54dddd5bcba48c18371 --- /dev/null +++ b/paddle/fluid/framework/op_version_registry_test.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace compatible { + +TEST(test_operator_version, test_operator_version) { + REGISTER_OP_VERSION(test__) + .AddCheckpoint( + R"ROC( + Upgrade reshape, modified one attribute [axis] and add a new attribute [size]. + )ROC", + framework::compatible::OpVersionDesc() + .ModifyAttr("axis", + "Increased from the original one method to two.", -1) + .NewAttr("size", + "In order to represent a two-dimensional rectangle, the " + "parameter size is added.")) + .AddCheckpoint( + R"ROC( + Add a new attribute [height] + )ROC", + framework::compatible::OpVersionDesc().NewAttr( + "height", + "In order to represent a two-dimensional rectangle, the " + "parameter height is added.")); +} +} // namespace compatible +} // namespace framework +} // namespace paddle