diff --git a/cmake/cfg.cmake b/cmake/cfg.cmake index 702bc84686373a4d47e23bb2df8faa6c4fa7f4c3..6d7f610d6cc62ae065218f0cd757870c68e59474 100644 --- a/cmake/cfg.cmake +++ b/cmake/cfg.cmake @@ -60,6 +60,7 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) oneflow/core/common/cfg_reflection_test.proto oneflow/core/common/data_type.proto oneflow/core/common/device_type.proto + oneflow/core/common/demo.proto ) set(of_cfg_proto_python_dir "${PROJECT_BINARY_DIR}/of_cfg_proto_python") diff --git a/oneflow/core/common/cfg_demo_test.cpp b/oneflow/core/common/cfg_demo_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76f4220cdf436ffc1e385f92fe29a4746d378bbc --- /dev/null +++ b/oneflow/core/common/cfg_demo_test.cpp @@ -0,0 +1,108 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/demo.cfg.h" +#include "oneflow/core/common/demo.pb.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/common/cfg.h" +#include +using namespace std; +namespace oneflow { +namespace test { + +void CfgPassByValue(oneflow::cfg::Foo foo) { + ASSERT_EQ(foo.freeze(), true); + ASSERT_EQ(foo.bar().freeze(), true); + ASSERT_EQ(foo.bars().freeze(), true); + ASSERT_EQ(foo.bar().doo().freeze(), true); + ASSERT_EQ(foo.bar().doo().noo().freeze(), true); + ASSERT_EQ(foo.bar().doo().noos().freeze(), true); + ASSERT_EQ(foo.bar().moo().freeze(), true); + ASSERT_EQ(foo.map_bar().freeze(), true); + ASSERT_EQ(foo.bar().doo().map_noo().freeze(), true); + ASSERT_EQ(foo.bar().doo().of_noo().freeze(), true); + ASSERT_EQ(foo.bar().moo().of_noo2().freeze(), true); +} + +TEST(Demo, freeze_root_node) { + cfg::Foo foo; + cfg::Bar bar; + + foo.set_name("foo"); + foo.mutable_bar()->set_name("bar"); + foo.mutable_bar()->mutable_doo()->set_name("doo"); + foo.mutable_bar()->mutable_doo()->mutable_of_noo()->set_name("of_noo"); + foo.mutable_bar()->mutable_moo()->set_name("moo"); + foo.mutable_bar()->mutable_doo()->mutable_noo()->set_name("noo"); + foo.mutable_bar()->mutable_moo()->mutable_of_noo2()->set_name("of_noo"); + foo.set_freeze(); + + ASSERT_EQ(foo.freeze(), true); + ASSERT_EQ(foo.bar().freeze(), true); + ASSERT_EQ(foo.bars().freeze(), true); + ASSERT_EQ(foo.bar().doo().freeze(), true); + ASSERT_EQ(foo.bar().doo().noo().freeze(), true); + ASSERT_EQ(foo.bar().doo().noos().freeze(), true); + ASSERT_EQ(foo.bar().moo().freeze(), true); + ASSERT_EQ(foo.map_bar().freeze(), true); + ASSERT_EQ(foo.bar().doo().map_noo().freeze(), true); + ASSERT_EQ(foo.bar().doo().of_noo().freeze(), true); + ASSERT_EQ(foo.bar().moo().of_noo2().freeze(), true); +} + +TEST(Demo, pass_by_value) { + cfg::Foo foo; + cfg::Bar bar; + + foo.set_name("foo"); + foo.mutable_bar()->set_name("bar"); + foo.mutable_bar()->mutable_doo()->set_name("doo"); + foo.mutable_bar()->mutable_doo()->mutable_of_noo()->set_name("of_noo"); + foo.mutable_bar()->mutable_moo()->set_name("moo"); + foo.mutable_bar()->mutable_doo()->mutable_noo()->set_name("noo"); + foo.mutable_bar()->mutable_moo()->mutable_of_noo2()->set_name("of_noo"); + foo.set_freeze(); + CfgPassByValue(foo); +} + +TEST(Demo, freeze_mid_node) { + cfg::Foo foo; + cfg::Bar bar; + + foo.set_name("foo"); + foo.mutable_bar()->set_name("bar"); + foo.mutable_bar()->mutable_doo()->set_name("doo"); + foo.mutable_bar()->mutable_doo()->mutable_of_noo()->set_name("of_noo"); + foo.mutable_bar()->mutable_moo()->set_name("moo"); + foo.mutable_bar()->mutable_doo()->mutable_noo()->set_name("noo"); + foo.mutable_bar()->mutable_moo()->mutable_of_noo2()->set_name("of_noo"); + foo.mutable_bar()->mutable_doo()->set_freeze(); + + ASSERT_EQ(foo.freeze(), false); + ASSERT_EQ(foo.bar().freeze(), false); + ASSERT_EQ(foo.bars().freeze(), false); + ASSERT_EQ(foo.bar().doo().freeze(), true); + ASSERT_EQ(foo.bar().doo().noo().freeze(), true); + ASSERT_EQ(foo.bar().doo().noos().freeze(), true); + ASSERT_EQ(foo.bar().moo().freeze(), false); + ASSERT_EQ(foo.map_bar().freeze(), false); + ASSERT_EQ(foo.bar().doo().map_noo().freeze(), true); + ASSERT_EQ(foo.bar().doo().of_noo().freeze(), true); + ASSERT_EQ(foo.bar().moo().of_noo2().freeze(), false); +} + +} // namespace test +} // namespace oneflow diff --git a/oneflow/core/common/demo.proto b/oneflow/core/common/demo.proto new file mode 100644 index 0000000000000000000000000000000000000000..c15f2a68d1e648d72069b8767996e0a292d6c610 --- /dev/null +++ b/oneflow/core/common/demo.proto @@ -0,0 +1,89 @@ +syntax="proto2"; + +package oneflow; + +enum Enum { + kInvalidEnum = 0; + kEnum0 = 2; +} + +message Foo { + enum Type { + H2D = 0; + D2H = 1; + } + enum Data { + D = 3; + H = 57; + } + required Type type = 38; + required Data data = 39; + optional string name = 21 [default="unnamed"]; + required int64 int_value = 10; + optional int64 opt_int_value = 11; + required string string_value = 50; + optional string opt_string_value = 51; + repeated int64 int_values = 12; + repeated Enum enum_values = 13; + repeated string string_values = 14; + repeated Bar bars = 15; + required Bar bar = 100; + optional Bar optional_bar = 101; + + required Enum enum_value = 103; + optional Enum opt_enum_value = 104; + + oneof oneof_type { + Bar of_bar = 2; + string of_string_value = 4; + bytes of_bytes_value = 5; + int64 of_int_value = 6; + Enum of_enum_value = 7; + } + oneof oneof_expermental_type { + Bar of_expermental_bar = 8; + string of_expermental_string_value = 18; + bytes of_expermental_bytes_value = 19; + } + map map_int_int = 106; + map map_temp = 107; + map map_bar = 108; + map map_enum = 109; +} + +message Bar { + required string name = 1 [default = "undefined-name"]; + optional string nickname = 2 [default = "undefined-nickname"]; + required Doo doo = 3; + required Moo moo = 6; + oneof of_bar { + int64 of_int_value = 4; + float of_float_value = 5; + } +} + +message Doo { + required string name = 1 [default = "undefined-name"]; + optional string nickname = 2 [default = "undefined-nickname"]; + required Noo noo = 3; + repeated Noo noos = 6; + oneof oneof_doo { + Noo of_noo = 4; + int32 of_int = 5; + } + map map_noo = 108; +} + +message Moo { + required string name = 1 [default = "undefined-name"]; + optional string nickname = 2 [default = "undefined-nickname"]; + oneof oneof_noo { + Noo of_noo2 = 4; + int32 of_int = 5; + } +} + +message Noo { + required string name = 1 [default = "undefined-name"]; + optional string nickname = 2 [default = "undefined-nickname"]; +} diff --git a/tools/cfg/include/oneflow/cfg/map_field.h b/tools/cfg/include/oneflow/cfg/map_field.h index 743252ff712ed7ddc23c7fa40007f3cf2b92d901..31f3bf6aebcb24a0bbf33ed59ef3ce6e6438d74e 100644 --- a/tools/cfg/include/oneflow/cfg/map_field.h +++ b/tools/cfg/include/oneflow/cfg/map_field.h @@ -73,7 +73,11 @@ class _MapField_ { void Clear() { data_->clear(); } void CopyFrom(const _MapField_& other) { - *data_ = *other.data_; + if (other.freeze()) { + data_ = other.data_; + } else { + *data_ = *other.data_; + } } _MapField_& operator=(const _MapField_& other) { @@ -85,8 +89,19 @@ class _MapField_ { const std::shared_ptr>& __SharedPtr__() { return data_; } + bool freeze() const { + return freeze_; + } + + void set_freeze() { + if (!(std::is_scalar::value || std::is_same::value)) { + freeze_ = true; + } + } + private: std::shared_ptr> data_; + bool freeze_ = false; }; } diff --git a/tools/cfg/include/oneflow/cfg/repeated_field.h b/tools/cfg/include/oneflow/cfg/repeated_field.h index 140b2cd464000053fb805f4a4bcbfb8313633e1a..66b8185007d26ba1f9815c0eecd7e1c394721d76 100644 --- a/tools/cfg/include/oneflow/cfg/repeated_field.h +++ b/tools/cfg/include/oneflow/cfg/repeated_field.h @@ -65,8 +65,12 @@ class _RepeatedField_ { if (std::is_scalar::value || std::is_same::value) { *data_ = *other.data_; } else { - data_->clear(); - for (const auto& elem : other) { *Add() = elem; } + if (other.freeze()) { + data_ = other.data_; + } else { + data_->clear(); + for (const auto& elem : other) { *Add() = elem; } + } } } @@ -90,8 +94,17 @@ class _RepeatedField_ { const std::shared_ptr>& __SharedPtr__() { return data_; } + bool freeze() const { + return freeze_; + } + + void set_freeze() { + freeze_ = true; + } + private: std::shared_ptr> data_; + bool freeze_ = false; }; } diff --git a/tools/cfg/template/template.cfg.cpp b/tools/cfg/template/template.cfg.cpp index 918633739400ef61dbb14b09bdaee01417f15cdd..e2919f2060f2f27e0f743245202ff67d7e426eff 100644 --- a/tools/cfg/template/template.cfg.cpp +++ b/tools/cfg/template/template.cfg.cpp @@ -26,7 +26,7 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::InitFromProt // required_or_optional field: {{ util.field_name(field) }} if (proto_{{ util.class_name(cls).lower() }}.has_{{ util.field_name(field) }}()) { {%if util.field_is_message_type(field)%} - *mutable_{{ util.field_name(field) }}() = {{ util.field_message_type_name_with_cfg_namespace(field) }}(proto_{{ util.class_name(cls).lower() }}.{{ util.field_name(field) }}()); + *mutable_{{ util.field_name(field) }}() = {{ util.field_message_type_name_with_cfg_namespace(field) }}(proto_{{ util.class_name(cls).lower() }}.{{ util.field_name(field) }}()); {% elif util.field_is_enum_type(field) %} set_{{ util.field_name(field) }}(static_cast::type>::type>(proto_{{ util.class_name(cls).lower() }}.{{ util.field_name(field) }}())); {% else %} @@ -86,7 +86,7 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::InitFromProt break; } } - {% endfor %}{# oneofs #} + {% endfor %}{# oneofs #} } void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::ToProto({{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}* proto_{{ util.class_name(cls).lower() }}) const { @@ -233,6 +233,7 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::clear_{{ uti if (!{{ util.field_name(field) }}_) { {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_type_name_with_cfg_namespace(field) }}>(); } + assert(!{{ util.field_name(field) }}_->freeze()); has_{{ util.field_name(field) }}_ = true; return {{ util.field_name(field) }}_.get(); } @@ -294,31 +295,45 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::clear_{{ uti } return {{ util.field_name(field) }}_->Clear(); } +{% if util.field_is_message_type(field) %} +{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}() { + if (!{{ util.field_name(field) }}_) { + {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>(); + } + return {{ util.field_name(field) }}_->Add(); +} {{ util.field_repeated_container_name(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() { if (!{{ util.field_name(field) }}_) { {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>(); } + assert(!{{ util.field_name(field) }}_->freeze()); return {{ util.field_name(field) }}_.get(); } {{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}(::std::size_t index) { if (!{{ util.field_name(field) }}_) { {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>(); } + assert(!{{ util.field_name(field) }}_->freeze()); return {{ util.field_name(field) }}_->Mutable(index); } -{% if util.field_is_message_type(field) %} -{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}() { +{% else %} +void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) { if (!{{ util.field_name(field) }}_) { {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>(); } - return {{ util.field_name(field) }}_->Add(); + return {{ util.field_name(field) }}_->Add(value); } -{% else %} -void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) { +{{ util.field_repeated_container_name(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() { if (!{{ util.field_name(field) }}_) { {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>(); } - return {{ util.field_name(field) }}_->Add(value); + return {{ util.field_name(field) }}_.get(); +} +{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}(::std::size_t index) { + if (!{{ util.field_name(field) }}_) { + {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>(); + } + return {{ util.field_name(field) }}_->Mutable(index); } {% endif %}{# field message type #} {% elif util.field_has_oneof_label(field) %} @@ -376,6 +391,7 @@ const {{ util.field_type_name_with_cfg_namespace(field) }}& Const{{ util.class_n } {{ util.field_oneof_name(field) }}_case_ = {{ util.oneof_type_field_enum_value_name(field) }}; ::std::shared_ptr<{{ util.field_type_name_with_cfg_namespace(field) }}>* __attribute__((__may_alias__)) ptr = reinterpret_cast<::std::shared_ptr<{{ util.field_type_name_with_cfg_namespace(field) }}>*>(&({{ util.field_oneof_name(field) }}_.{{ util.field_name(field) }}_)); + assert(!(*ptr)->freeze()); return (*ptr).get(); } {% else %} @@ -428,13 +444,6 @@ const {{ util.field_map_container_name(field) }}& Const{{ util.class_name(cls) } return *({{ util.field_name(field) }}_.get()); } -{{ util.field_map_container_name(field) }} * Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() { - if (!{{ util.field_name(field) }}_) { - {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>(); - } - return {{ util.field_name(field) }}_.get(); -} - const {{ util.field_map_value_type_name_with_cfg_namespace(field) }}& Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::{{ util.field_name(field) }}({{ util.field_map_key_type_name(field) }} key) const { if (!{{ util.field_name(field) }}_) { static const ::std::shared_ptr<{{ util.field_map_container_name(field) }}> default_static_value = @@ -451,14 +460,27 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::clear_{{ uti return {{ util.field_name(field) }}_->Clear(); } -{% if util.field_is_message_type(field) %} +{% if util.field_map_value_type_is_message(field) %} +{{ util.field_map_container_name(field) }} * Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() { + if (!{{ util.field_name(field) }}_) { + {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>(); + } + assert(!{{ util.field_name(field) }}_->freeze()); + return {{ util.field_name(field) }}_.get(); +} {% else %} -void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) { +{{ util.field_map_container_name(field) }} * Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() { if (!{{ util.field_name(field) }}_) { {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>(); } - return {{ util.field_name(field) }}_->Add(value); + return {{ util.field_name(field) }}_.get(); } +// void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) { +// if (!{{ util.field_name(field) }}_) { +// {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>(); +// } +// return {{ util.field_name(field) }}_->Add(value); +// } {% endif %}{# field message type #} {% endif %}{# label #} {% endfor %}{# field #} @@ -556,7 +578,7 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator==(c return true {% for field in util.message_type_fields(cls) %} {% if util.field_has_required_or_optional_label(field) %} - && has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}() + && has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}() && {{ util.field_name(field) }}() == other.{{ util.field_name(field) }}() {% elif util.field_has_repeated_label(field) or util.field_has_map_label(field) %} && {{ util.field_name(field) }}() == other.{{ util.field_name(field) }}() @@ -565,7 +587,7 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator==(c {% for oneof in util.message_type_oneofs(cls) %} && {{ util.oneof_name(oneof) }}_case() == other.{{ util.oneof_name(oneof) }}_case() {% for field in util.oneof_type_fields(oneof) %} - && {{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }} ? + && {{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }} ? {{ util.field_name(field) }}() == other.{{ util.field_name(field) }}() : true {% endfor %}{# oneof_field #} {% endfor %}{# oneofs #} @@ -576,21 +598,21 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator<(co return false {% for field in util.message_type_fields(cls) %} {% if util.field_has_required_or_optional_label(field) %} - || !(has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}()) ? + || !(has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}()) ? has_{{ util.field_name(field) }}() < other.has_{{ util.field_name(field) }}() : false - || !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ? + || !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ? {{ util.field_name(field) }}() < other.{{ util.field_name(field) }}() : false {% elif util.field_has_repeated_label(field) or util.field_has_map_label(field) %} - || !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ? + || !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ? {{ util.field_name(field) }}() < other.{{ util.field_name(field) }}() : false {% endif %}{# field_label #} {% endfor %}{# fields #} {% for oneof in util.message_type_oneofs(cls) %} - || !({{ util.oneof_name(oneof) }}_case() == other.{{ util.oneof_name(oneof) }}_case()) ? + || !({{ util.oneof_name(oneof) }}_case() == other.{{ util.oneof_name(oneof) }}_case()) ? {{ util.oneof_name(oneof) }}_case() < other.{{ util.oneof_name(oneof) }}_case() : false {% for field in util.oneof_type_fields(oneof) %} - || (({{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }}) && - !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}())) ? + || (({{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }}) && + !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}())) ? {{ util.field_name(field) }}() < other.{{ util.field_name(field) }}() : false {% endfor %}{# oneof_field #} {% endfor %}{# oneofs #} @@ -610,7 +632,7 @@ Const{{ util.class_name(cls) }}::~Const{{ util.class_name(cls) }}() = default; void Const{{ util.class_name(cls) }}::ToProto(PbMessage* proto_{{ util.class_name(cls).lower() }}) const { __SharedPtrOrDefault__()->ToProto(dynamic_cast<{{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}*>(proto_{{ util.class_name(cls).lower() }})); } - + ::std::string Const{{ util.class_name(cls) }}::DebugString() const { return __SharedPtrOrDefault__()->DebugString(); } @@ -619,6 +641,10 @@ bool Const{{ util.class_name(cls) }}::__Empty__() const { return !data_; } +bool Const{{ util.class_name(cls) }}::freeze() const { + return freeze_; +} + int Const{{ util.class_name(cls) }}::FieldNumber4FieldName(const ::std::string& field_name) const { {% if util.has_message_type_fields(cls) %} static const ::std::map<::std::string, int> field_name2field_number{ @@ -796,7 +822,7 @@ void Const{{ util.class_name(cls) }}::BuildFromProto(const PbMessage& proto_{{ u : Const{{ util.class_name(cls) }}(data) {} {{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const {{ util.class_name(cls) }}& other) { CopyFrom(other); } // enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize -{{ util.class_name(cls) }}::{{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept = default; +{{ util.class_name(cls) }}::{{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept = default; {{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }}) { InitFromProto(proto_{{ util.class_name(cls).lower() }}); } @@ -807,7 +833,7 @@ void Const{{ util.class_name(cls) }}::BuildFromProto(const PbMessage& proto_{{ u void {{ util.class_name(cls) }}::InitFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}) { BuildFromProto(proto_{{ util.class_name(cls).lower() }}); } - + void* {{ util.class_name(cls) }}::MutableFieldPtr4FieldNumber(int field_number) { switch (field_number) { {% for field in util.message_type_fields(cls) %} @@ -830,10 +856,49 @@ void {{ util.class_name(cls) }}::Clear() { void {{ util.class_name(cls) }}::CopyFrom(const {{ util.class_name(cls) }}& other) { if (other.__Empty__()) { Clear(); + } else if (other.freeze()) { + data_ = other.data_; + freeze_ = true; } else { __SharedPtr__()->CopyFrom(*other.data_); } } + +void {{ util.class_name(cls) }}::set_freeze() { + if(!freeze_) { +{% for field in util.message_type_fields(cls) %} +{% if util.field_has_required_or_optional_label(field) and util.field_is_message_type(field) %} + if (has_{{ util.field_name(field) }}()) { + mutable_{{ util.field_name(field) }}()->set_freeze(); + } +{% elif util.field_has_repeated_label(field) and util.field_is_message_type(field) %} + mutable_{{ util.field_name(field) }}()->set_freeze(); +{% elif util.field_has_map_label(field) and util.field_map_value_type_is_message(field) %} + mutable_{{ util.field_name(field) }}()->set_freeze(); +{% endif %} +{% endfor %} +{% for oneof in util.message_type_oneofs(cls) %} + switch ({{ util.oneof_name(oneof) }}_case()) { +{% for field in util.oneof_type_fields(oneof) %} +{% if util.field_is_message_type(field) %} + case {{ util.oneof_type_field_enum_value_name(field) }}: { + mutable_{{ util.field_name(field) }}()->set_freeze(); + break; + } +{% endif %}{# message_type #} +{% endfor %}{# oneof_field #} + case {{ util.oneof_name(oneof).upper() }}_NOT_SET: { + break; + } + default: { + break; + } + } +{% endfor %}{# oneof #} + freeze_ = true; + } +} + {{ util.class_name(cls) }}& {{ util.class_name(cls) }}::operator=(const {{ util.class_name(cls) }}& other) { CopyFrom(other); return *this; @@ -847,6 +912,7 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() { } {% if util.field_is_message_type(field) %} {{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(); } // used by pybind11 only @@ -855,9 +921,11 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() { } {% else %} void {{ util.class_name(cls) }}::set_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) { + assert(!freeze_); return __SharedPtr__()->set_{{ util.field_name(field) }}(value); } {{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(); } {% endif %} @@ -867,9 +935,11 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() { return __SharedPtr__()->clear_{{ util.field_name(field) }}(); } {{ util.field_repeated_container_name(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(); } {{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}(::std::size_t index) { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(index); } {% if util.field_is_message_type(field) %} @@ -898,6 +968,7 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() { } {% if util.field_is_message_type(field) %} {{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(); } // used by pybind11 only @@ -906,9 +977,11 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() { } {% else %} void {{ util.class_name(cls) }}::set_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) { + assert(!freeze_); return __SharedPtr__()->set_{{ util.field_name(field) }}(value); } {{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(); } {% endif %}{# field message type #} @@ -919,11 +992,12 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() { return __SharedPtr__()->clear_{{ util.field_name(field) }}(); } -const {{ util.field_map_container_name(field) }} & {{ util.class_name(cls) }}::{{ util.field_name(field) }}() { - return __SharedPtr__()->{{ util.field_name(field) }}(); -} +// const {{ util.field_map_container_name(field) }} & {{ util.class_name(cls) }}::{{ util.field_name(field) }}() { +// return __SharedPtr__()->{{ util.field_name(field) }}(); +// } {{ util.field_map_container_name(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() { + assert(!freeze_); return __SharedPtr__()->mutable_{{ util.field_name(field) }}(); } diff --git a/tools/cfg/template/template.cfg.h b/tools/cfg/template/template.cfg.h index f6b786456e7a5451d05b20b2760352b0681ecbe3..be23f3592e4fb6dad3f4d213f5e5f5927f06aa92 100644 --- a/tools/cfg/template/template.cfg.h +++ b/tools/cfg/template/template.cfg.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "oneflow/cfg/repeated_field.h" #include "oneflow/cfg/map_field.h" @@ -93,7 +94,7 @@ class Const{{ util.field_repeated_container_name(field) }}; {% endif %}{# repeated #} {# map begin #} {% if util.field_has_map_label(field) and util.add_declared_map_field_type_name(field) %} -class {{ util.field_map_container_name(field) }}; +class {{ util.field_map_container_name(field) }}; class Const{{ util.field_map_container_name(field) }}; {% endif %}{# map end #} {% endfor %}{# field #} @@ -157,12 +158,14 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message { const {{ util.field_repeated_container_name(field) }}& {{ util.field_name(field) }}() const; const {{ util.field_type_name_with_cfg_namespace(field) }}& {{ util.field_name(field) }}(::std::size_t index) const; void clear_{{ util.field_name(field) }}(); - {{ util.field_repeated_container_name(field) }}* mutable_{{ util.field_name(field) }}(); - {{ util.field_type_name_with_cfg_namespace(field) }}* mutable_{{ util.field_name(field) }}(::std::size_t index); {% if util.field_is_message_type(field) %} {{ util.field_type_name_with_cfg_namespace(field) }}* add_{{ util.field_name(field) }}(); + {{ util.field_repeated_container_name(field) }}* mutable_{{ util.field_name(field) }}(); + {{ util.field_type_name_with_cfg_namespace(field) }}* mutable_{{ util.field_name(field) }}(::std::size_t index); {% else %} void add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value); + {{ util.field_repeated_container_name(field) }}* mutable_{{ util.field_name(field) }}(); + {{ util.field_type_name_with_cfg_namespace(field) }}* mutable_{{ util.field_name(field) }}(::std::size_t index); {% endif %}{# field message type #} protected: ::std::shared_ptr<{{ util.field_repeated_container_name(field) }}> {{ util.field_name(field) }}_; @@ -183,21 +186,21 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message { ::std::size_t {{ util.field_name(field) }}_size() const; const {{ util.field_map_container_name(field) }}& {{ util.field_name(field) }}() const; - {{ util.field_map_container_name(field) }} * mutable_{{ util.field_name(field) }}(); - const {{ util.field_map_value_type_name_with_cfg_namespace(field) }}& {{ util.field_name(field) }}({{ util.field_map_key_type_name(field) }} key) const; void clear_{{ util.field_name(field) }}(); - {% if util.field_is_message_type(field) %} + {% if util.field_map_value_type_is_message(field) %} + {{ util.field_map_container_name(field) }} * mutable_{{ util.field_name(field) }}(); {% else %} - void add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value); + {{ util.field_map_container_name(field) }} * mutable_{{ util.field_name(field) }}(); + //void add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value); {% endif %}{# field message type #} protected: ::std::shared_ptr<{{ util.field_map_container_name(field) }}> {{ util.field_name(field) }}_; {% endif %}{# label #} {% endfor %}{# field #} {% for oneof in util.message_type_oneofs(cls) %} - + public: // oneof {{ util.oneof_name(oneof) }} {{ util.oneof_enum_name(oneof) }} {{ util.oneof_name(oneof) }}_case() const; @@ -220,7 +223,7 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message { } {{ util.oneof_name(oneof) }}_; {{ util.oneof_enum_name(oneof) }} {{ util.oneof_name(oneof) }}_case_ = {{ util.oneof_name(oneof).upper() }}_NOT_SET; {% endfor %}{# message_oneof #} - + public: int compare(const _{{ util.class_name(cls) }}_& other); @@ -238,9 +241,11 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message { using PbMessage = ::google::protobuf::Message; void ToProto(PbMessage* proto_{{ util.class_name(cls).lower() }}) const override; - + ::std::string DebugString() const; + bool freeze() const; + bool __Empty__() const; int FieldNumber4FieldName(const ::std::string& field_name) const override; @@ -299,7 +304,7 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message { bool has_{{ util.oneof_name(oneof) }}() const; {% endfor %}{# oneofs #} - + public: ::std::shared_ptr __SharedConst__() const; int64_t __Id__() const; public: @@ -311,15 +316,16 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message { const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& __SharedPtr__(); // use a protected member method to avoid someone change member variable(data_) by Const{{ util.class_name(cls) }} void BuildFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}); - + ::std::shared_ptr<_{{ util.class_name(cls) }}_> data_; + bool freeze_ = false; }; class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }} { public: {{ util.class_name(cls) }}(const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& data); {{ util.class_name(cls) }}(const {{ util.class_name(cls) }}& other); - // enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize + // enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize {{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept; {{ util.class_name(cls) }}(); {{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }}); @@ -327,13 +333,14 @@ class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }} ~{{ util.class_name(cls) }}() override; void InitFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}) override; - + void* MutableFieldPtr4FieldNumber(int field_number) override; bool operator==(const {{ util.class_name(cls) }}& other) const; bool operator<(const {{ util.class_name(cls) }}& other) const; void Clear(); + void set_freeze(); void CopyFrom(const {{ util.class_name(cls) }}& other); {{ util.class_name(cls) }}& operator=(const {{ util.class_name(cls) }}& other); @@ -382,7 +389,7 @@ class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }} public: void clear_{{ util.field_name(field) }}(); - const {{ util.field_map_container_name(field) }} & {{ util.field_name(field) }}(); + //const {{ util.field_map_container_name(field) }} & {{ util.field_name(field) }}(); {{ util.field_map_container_name(field) }}* mutable_{{ util.field_name(field) }}(); @@ -405,7 +412,7 @@ class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }} {# no duplicated class defined for each repeated field type #} {% if util.field_has_repeated_label(field) and util.add_defined_repeated_field_type_name(field) %} -// inheritance is helpful for avoiding container iterator boilerplate +// inheritance is helpful for avoiding container iterator boilerplate class Const{{ util.field_repeated_container_name(field) }} : public ::oneflow::cfg::_RepeatedField_<{{ util.field_type_name_with_cfg_namespace(field) }}> { public: Const{{ util.field_repeated_container_name(field) }}(const ::std::shared_ptr<::std::vector<{{ util.field_type_name_with_cfg_namespace(field) }}>>& data); @@ -440,7 +447,7 @@ class {{ util.field_repeated_container_name(field) }} final : public Const{{ uti {# map begin #} {% if util.field_has_map_label(field) and util.add_defined_map_field_type_name(field) %} -// inheritance is helpful for avoid container iterator boilerplate +// inheritance is helpful for avoid container iterator boilerplate class Const{{ util.field_map_container_name(field) }} : public ::oneflow::cfg::_MapField_<{{ util.field_map_pair_type_name_with_cfg_namespace(field) }}> { public: Const{{ util.field_map_container_name(field) }}(const ::std::shared_ptr<::std::map<{{ util.field_map_pair_type_name_with_cfg_namespace(field) }}>>& data);