提交 d7081059 编写于 作者: H hsj0429

copy on write


Former-commit-id: f33c408801dec8c0b749d64e5a2d180187651459
上级 98473907
......@@ -60,6 +60,7 @@ function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
oneflow/core/common/cfg_reflection_test.proto
oneflow/core/common/data_type.proto
oneflow/core/common/device_type.proto
oneflow/core/common/demo.proto
)
set(of_cfg_proto_python_dir "${PROJECT_BINARY_DIR}/of_cfg_proto_python")
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/demo.cfg.h"
#include "oneflow/core/common/demo.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/cfg.h"
#include <iostream>
using namespace std;
namespace oneflow {
namespace test {
void CfgPassByValue(oneflow::cfg::Foo foo) {
ASSERT_EQ(foo.freeze(), true);
ASSERT_EQ(foo.bar().freeze(), true);
ASSERT_EQ(foo.bars().freeze(), true);
ASSERT_EQ(foo.bar().doo().freeze(), true);
ASSERT_EQ(foo.bar().doo().noo().freeze(), true);
ASSERT_EQ(foo.bar().doo().noos().freeze(), true);
ASSERT_EQ(foo.bar().moo().freeze(), true);
ASSERT_EQ(foo.map_bar().freeze(), true);
ASSERT_EQ(foo.bar().doo().map_noo().freeze(), true);
ASSERT_EQ(foo.bar().doo().of_noo().freeze(), true);
ASSERT_EQ(foo.bar().moo().of_noo2().freeze(), true);
}
TEST(Demo, freeze_root_node) {
cfg::Foo foo;
cfg::Bar bar;
foo.set_name("foo");
foo.mutable_bar()->set_name("bar");
foo.mutable_bar()->mutable_doo()->set_name("doo");
foo.mutable_bar()->mutable_doo()->mutable_of_noo()->set_name("of_noo");
foo.mutable_bar()->mutable_moo()->set_name("moo");
foo.mutable_bar()->mutable_doo()->mutable_noo()->set_name("noo");
foo.mutable_bar()->mutable_moo()->mutable_of_noo2()->set_name("of_noo");
foo.set_freeze();
ASSERT_EQ(foo.freeze(), true);
ASSERT_EQ(foo.bar().freeze(), true);
ASSERT_EQ(foo.bars().freeze(), true);
ASSERT_EQ(foo.bar().doo().freeze(), true);
ASSERT_EQ(foo.bar().doo().noo().freeze(), true);
ASSERT_EQ(foo.bar().doo().noos().freeze(), true);
ASSERT_EQ(foo.bar().moo().freeze(), true);
ASSERT_EQ(foo.map_bar().freeze(), true);
ASSERT_EQ(foo.bar().doo().map_noo().freeze(), true);
ASSERT_EQ(foo.bar().doo().of_noo().freeze(), true);
ASSERT_EQ(foo.bar().moo().of_noo2().freeze(), true);
}
TEST(Demo, pass_by_value) {
cfg::Foo foo;
cfg::Bar bar;
foo.set_name("foo");
foo.mutable_bar()->set_name("bar");
foo.mutable_bar()->mutable_doo()->set_name("doo");
foo.mutable_bar()->mutable_doo()->mutable_of_noo()->set_name("of_noo");
foo.mutable_bar()->mutable_moo()->set_name("moo");
foo.mutable_bar()->mutable_doo()->mutable_noo()->set_name("noo");
foo.mutable_bar()->mutable_moo()->mutable_of_noo2()->set_name("of_noo");
foo.set_freeze();
CfgPassByValue(foo);
}
TEST(Demo, freeze_mid_node) {
cfg::Foo foo;
cfg::Bar bar;
foo.set_name("foo");
foo.mutable_bar()->set_name("bar");
foo.mutable_bar()->mutable_doo()->set_name("doo");
foo.mutable_bar()->mutable_doo()->mutable_of_noo()->set_name("of_noo");
foo.mutable_bar()->mutable_moo()->set_name("moo");
foo.mutable_bar()->mutable_doo()->mutable_noo()->set_name("noo");
foo.mutable_bar()->mutable_moo()->mutable_of_noo2()->set_name("of_noo");
foo.mutable_bar()->mutable_doo()->set_freeze();
ASSERT_EQ(foo.freeze(), false);
ASSERT_EQ(foo.bar().freeze(), false);
ASSERT_EQ(foo.bars().freeze(), false);
ASSERT_EQ(foo.bar().doo().freeze(), true);
ASSERT_EQ(foo.bar().doo().noo().freeze(), true);
ASSERT_EQ(foo.bar().doo().noos().freeze(), true);
ASSERT_EQ(foo.bar().moo().freeze(), false);
ASSERT_EQ(foo.map_bar().freeze(), false);
ASSERT_EQ(foo.bar().doo().map_noo().freeze(), true);
ASSERT_EQ(foo.bar().doo().of_noo().freeze(), true);
ASSERT_EQ(foo.bar().moo().of_noo2().freeze(), false);
}
} // namespace test
} // namespace oneflow
syntax="proto2";
package oneflow;
enum Enum {
kInvalidEnum = 0;
kEnum0 = 2;
}
message Foo {
enum Type {
H2D = 0;
D2H = 1;
}
enum Data {
D = 3;
H = 57;
}
required Type type = 38;
required Data data = 39;
optional string name = 21 [default="unnamed"];
required int64 int_value = 10;
optional int64 opt_int_value = 11;
required string string_value = 50;
optional string opt_string_value = 51;
repeated int64 int_values = 12;
repeated Enum enum_values = 13;
repeated string string_values = 14;
repeated Bar bars = 15;
required Bar bar = 100;
optional Bar optional_bar = 101;
required Enum enum_value = 103;
optional Enum opt_enum_value = 104;
oneof oneof_type {
Bar of_bar = 2;
string of_string_value = 4;
bytes of_bytes_value = 5;
int64 of_int_value = 6;
Enum of_enum_value = 7;
}
oneof oneof_expermental_type {
Bar of_expermental_bar = 8;
string of_expermental_string_value = 18;
bytes of_expermental_bytes_value = 19;
}
map<int64, int64> map_int_int = 106;
map<int64, int64> map_temp = 107;
map<int64, Bar> map_bar = 108;
map<int64, Enum> map_enum = 109;
}
message Bar {
required string name = 1 [default = "undefined-name"];
optional string nickname = 2 [default = "undefined-nickname"];
required Doo doo = 3;
required Moo moo = 6;
oneof of_bar {
int64 of_int_value = 4;
float of_float_value = 5;
}
}
message Doo {
required string name = 1 [default = "undefined-name"];
optional string nickname = 2 [default = "undefined-nickname"];
required Noo noo = 3;
repeated Noo noos = 6;
oneof oneof_doo {
Noo of_noo = 4;
int32 of_int = 5;
}
map<int64, Noo> map_noo = 108;
}
message Moo {
required string name = 1 [default = "undefined-name"];
optional string nickname = 2 [default = "undefined-nickname"];
oneof oneof_noo {
Noo of_noo2 = 4;
int32 of_int = 5;
}
}
message Noo {
required string name = 1 [default = "undefined-name"];
optional string nickname = 2 [default = "undefined-nickname"];
}
......@@ -73,7 +73,11 @@ class _MapField_ {
void Clear() { data_->clear(); }
void CopyFrom(const _MapField_& other) {
*data_ = *other.data_;
if (other.freeze()) {
data_ = other.data_;
} else {
*data_ = *other.data_;
}
}
_MapField_& operator=(const _MapField_& other) {
......@@ -85,8 +89,19 @@ class _MapField_ {
const std::shared_ptr<std::map<Key, T>>& __SharedPtr__() { return data_; }
bool freeze() const {
return freeze_;
}
void set_freeze() {
if (!(std::is_scalar<T>::value || std::is_same<std::string, T>::value)) {
freeze_ = true;
}
}
private:
std::shared_ptr<std::map<Key, T>> data_;
bool freeze_ = false;
};
}
......
......@@ -65,8 +65,12 @@ class _RepeatedField_ {
if (std::is_scalar<T>::value || std::is_same<std::string, T>::value) {
*data_ = *other.data_;
} else {
data_->clear();
for (const auto& elem : other) { *Add() = elem; }
if (other.freeze()) {
data_ = other.data_;
} else {
data_->clear();
for (const auto& elem : other) { *Add() = elem; }
}
}
}
......@@ -90,8 +94,17 @@ class _RepeatedField_ {
const std::shared_ptr<std::vector<T>>& __SharedPtr__() { return data_; }
bool freeze() const {
return freeze_;
}
void set_freeze() {
freeze_ = true;
}
private:
std::shared_ptr<std::vector<T>> data_;
bool freeze_ = false;
};
}
......
......@@ -26,7 +26,7 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::InitFromProt
// required_or_optional field: {{ util.field_name(field) }}
if (proto_{{ util.class_name(cls).lower() }}.has_{{ util.field_name(field) }}()) {
{%if util.field_is_message_type(field)%}
*mutable_{{ util.field_name(field) }}() = {{ util.field_message_type_name_with_cfg_namespace(field) }}(proto_{{ util.class_name(cls).lower() }}.{{ util.field_name(field) }}());
*mutable_{{ util.field_name(field) }}() = {{ util.field_message_type_name_with_cfg_namespace(field) }}(proto_{{ util.class_name(cls).lower() }}.{{ util.field_name(field) }}());
{% elif util.field_is_enum_type(field) %}
set_{{ util.field_name(field) }}(static_cast<std::remove_reference<std::remove_const<decltype({{ util.field_name(field) }}())>::type>::type>(proto_{{ util.class_name(cls).lower() }}.{{ util.field_name(field) }}()));
{% else %}
......@@ -86,7 +86,7 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::InitFromProt
break;
}
}
{% endfor %}{# oneofs #}
{% endfor %}{# oneofs #}
}
void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::ToProto({{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}* proto_{{ util.class_name(cls).lower() }}) const {
......@@ -233,6 +233,7 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::clear_{{ uti
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_type_name_with_cfg_namespace(field) }}>();
}
assert(!{{ util.field_name(field) }}_->freeze());
has_{{ util.field_name(field) }}_ = true;
return {{ util.field_name(field) }}_.get();
}
......@@ -294,31 +295,45 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::clear_{{ uti
}
return {{ util.field_name(field) }}_->Clear();
}
{% if util.field_is_message_type(field) %}
{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}() {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>();
}
return {{ util.field_name(field) }}_->Add();
}
{{ util.field_repeated_container_name(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>();
}
assert(!{{ util.field_name(field) }}_->freeze());
return {{ util.field_name(field) }}_.get();
}
{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}(::std::size_t index) {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>();
}
assert(!{{ util.field_name(field) }}_->freeze());
return {{ util.field_name(field) }}_->Mutable(index);
}
{% if util.field_is_message_type(field) %}
{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}() {
{% else %}
void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>();
}
return {{ util.field_name(field) }}_->Add();
return {{ util.field_name(field) }}_->Add(value);
}
{% else %}
void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) {
{{ util.field_repeated_container_name(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>();
}
return {{ util.field_name(field) }}_->Add(value);
return {{ util.field_name(field) }}_.get();
}
{{ util.field_type_name_with_cfg_namespace(field) }}* Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}(::std::size_t index) {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_repeated_container_name(field) }}>();
}
return {{ util.field_name(field) }}_->Mutable(index);
}
{% endif %}{# field message type #}
{% elif util.field_has_oneof_label(field) %}
......@@ -376,6 +391,7 @@ const {{ util.field_type_name_with_cfg_namespace(field) }}& Const{{ util.class_n
}
{{ util.field_oneof_name(field) }}_case_ = {{ util.oneof_type_field_enum_value_name(field) }};
::std::shared_ptr<{{ util.field_type_name_with_cfg_namespace(field) }}>* __attribute__((__may_alias__)) ptr = reinterpret_cast<::std::shared_ptr<{{ util.field_type_name_with_cfg_namespace(field) }}>*>(&({{ util.field_oneof_name(field) }}_.{{ util.field_name(field) }}_));
assert(!(*ptr)->freeze());
return (*ptr).get();
}
{% else %}
......@@ -428,13 +444,6 @@ const {{ util.field_map_container_name(field) }}& Const{{ util.class_name(cls) }
return *({{ util.field_name(field) }}_.get());
}
{{ util.field_map_container_name(field) }} * Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>();
}
return {{ util.field_name(field) }}_.get();
}
const {{ util.field_map_value_type_name_with_cfg_namespace(field) }}& Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::{{ util.field_name(field) }}({{ util.field_map_key_type_name(field) }} key) const {
if (!{{ util.field_name(field) }}_) {
static const ::std::shared_ptr<{{ util.field_map_container_name(field) }}> default_static_value =
......@@ -451,14 +460,27 @@ void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::clear_{{ uti
return {{ util.field_name(field) }}_->Clear();
}
{% if util.field_is_message_type(field) %}
{% if util.field_map_value_type_is_message(field) %}
{{ util.field_map_container_name(field) }} * Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>();
}
assert(!{{ util.field_name(field) }}_->freeze());
return {{ util.field_name(field) }}_.get();
}
{% else %}
void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) {
{{ util.field_map_container_name(field) }} * Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::mutable_{{ util.field_name(field) }}() {
if (!{{ util.field_name(field) }}_) {
{{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>();
}
return {{ util.field_name(field) }}_->Add(value);
return {{ util.field_name(field) }}_.get();
}
// void Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) {
// if (!{{ util.field_name(field) }}_) {
// {{ util.field_name(field) }}_ = ::std::make_shared<{{ util.field_map_container_name(field) }}>();
// }
// return {{ util.field_name(field) }}_->Add(value);
// }
{% endif %}{# field message type #}
{% endif %}{# label #}
{% endfor %}{# field #}
......@@ -556,7 +578,7 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator==(c
return true
{% for field in util.message_type_fields(cls) %}
{% if util.field_has_required_or_optional_label(field) %}
&& has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}()
&& has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}()
&& {{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()
{% elif util.field_has_repeated_label(field) or util.field_has_map_label(field) %}
&& {{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()
......@@ -565,7 +587,7 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator==(c
{% for oneof in util.message_type_oneofs(cls) %}
&& {{ util.oneof_name(oneof) }}_case() == other.{{ util.oneof_name(oneof) }}_case()
{% for field in util.oneof_type_fields(oneof) %}
&& {{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }} ?
&& {{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }} ?
{{ util.field_name(field) }}() == other.{{ util.field_name(field) }}() : true
{% endfor %}{# oneof_field #}
{% endfor %}{# oneofs #}
......@@ -576,21 +598,21 @@ bool Const{{ util.class_name(cls) }}::_{{ util.class_name(cls) }}_::operator<(co
return false
{% for field in util.message_type_fields(cls) %}
{% if util.field_has_required_or_optional_label(field) %}
|| !(has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}()) ?
|| !(has_{{ util.field_name(field) }}() == other.has_{{ util.field_name(field) }}()) ?
has_{{ util.field_name(field) }}() < other.has_{{ util.field_name(field) }}() : false
|| !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ?
|| !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ?
{{ util.field_name(field) }}() < other.{{ util.field_name(field) }}() : false
{% elif util.field_has_repeated_label(field) or util.field_has_map_label(field) %}
|| !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ?
|| !({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}()) ?
{{ util.field_name(field) }}() < other.{{ util.field_name(field) }}() : false
{% endif %}{# field_label #}
{% endfor %}{# fields #}
{% for oneof in util.message_type_oneofs(cls) %}
|| !({{ util.oneof_name(oneof) }}_case() == other.{{ util.oneof_name(oneof) }}_case()) ?
|| !({{ util.oneof_name(oneof) }}_case() == other.{{ util.oneof_name(oneof) }}_case()) ?
{{ util.oneof_name(oneof) }}_case() < other.{{ util.oneof_name(oneof) }}_case() : false
{% for field in util.oneof_type_fields(oneof) %}
|| (({{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }}) &&
!({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}())) ?
|| (({{ util.oneof_name(oneof) }}_case() == {{ util.oneof_type_field_enum_value_name(field) }}) &&
!({{ util.field_name(field) }}() == other.{{ util.field_name(field) }}())) ?
{{ util.field_name(field) }}() < other.{{ util.field_name(field) }}() : false
{% endfor %}{# oneof_field #}
{% endfor %}{# oneofs #}
......@@ -610,7 +632,7 @@ Const{{ util.class_name(cls) }}::~Const{{ util.class_name(cls) }}() = default;
void Const{{ util.class_name(cls) }}::ToProto(PbMessage* proto_{{ util.class_name(cls).lower() }}) const {
__SharedPtrOrDefault__()->ToProto(dynamic_cast<{{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}*>(proto_{{ util.class_name(cls).lower() }}));
}
::std::string Const{{ util.class_name(cls) }}::DebugString() const {
return __SharedPtrOrDefault__()->DebugString();
}
......@@ -619,6 +641,10 @@ bool Const{{ util.class_name(cls) }}::__Empty__() const {
return !data_;
}
bool Const{{ util.class_name(cls) }}::freeze() const {
return freeze_;
}
int Const{{ util.class_name(cls) }}::FieldNumber4FieldName(const ::std::string& field_name) const {
{% if util.has_message_type_fields(cls) %}
static const ::std::map<::std::string, int> field_name2field_number{
......@@ -796,7 +822,7 @@ void Const{{ util.class_name(cls) }}::BuildFromProto(const PbMessage& proto_{{ u
: Const{{ util.class_name(cls) }}(data) {}
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const {{ util.class_name(cls) }}& other) { CopyFrom(other); }
// enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept = default;
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept = default;
{{ util.class_name(cls) }}::{{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }}) {
InitFromProto(proto_{{ util.class_name(cls).lower() }});
}
......@@ -807,7 +833,7 @@ void Const{{ util.class_name(cls) }}::BuildFromProto(const PbMessage& proto_{{ u
void {{ util.class_name(cls) }}::InitFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}) {
BuildFromProto(proto_{{ util.class_name(cls).lower() }});
}
void* {{ util.class_name(cls) }}::MutableFieldPtr4FieldNumber(int field_number) {
switch (field_number) {
{% for field in util.message_type_fields(cls) %}
......@@ -830,10 +856,49 @@ void {{ util.class_name(cls) }}::Clear() {
void {{ util.class_name(cls) }}::CopyFrom(const {{ util.class_name(cls) }}& other) {
if (other.__Empty__()) {
Clear();
} else if (other.freeze()) {
data_ = other.data_;
freeze_ = true;
} else {
__SharedPtr__()->CopyFrom(*other.data_);
}
}
void {{ util.class_name(cls) }}::set_freeze() {
if(!freeze_) {
{% for field in util.message_type_fields(cls) %}
{% if util.field_has_required_or_optional_label(field) and util.field_is_message_type(field) %}
if (has_{{ util.field_name(field) }}()) {
mutable_{{ util.field_name(field) }}()->set_freeze();
}
{% elif util.field_has_repeated_label(field) and util.field_is_message_type(field) %}
mutable_{{ util.field_name(field) }}()->set_freeze();
{% elif util.field_has_map_label(field) and util.field_map_value_type_is_message(field) %}
mutable_{{ util.field_name(field) }}()->set_freeze();
{% endif %}
{% endfor %}
{% for oneof in util.message_type_oneofs(cls) %}
switch ({{ util.oneof_name(oneof) }}_case()) {
{% for field in util.oneof_type_fields(oneof) %}
{% if util.field_is_message_type(field) %}
case {{ util.oneof_type_field_enum_value_name(field) }}: {
mutable_{{ util.field_name(field) }}()->set_freeze();
break;
}
{% endif %}{# message_type #}
{% endfor %}{# oneof_field #}
case {{ util.oneof_name(oneof).upper() }}_NOT_SET: {
break;
}
default: {
break;
}
}
{% endfor %}{# oneof #}
freeze_ = true;
}
}
{{ util.class_name(cls) }}& {{ util.class_name(cls) }}::operator=(const {{ util.class_name(cls) }}& other) {
CopyFrom(other);
return *this;
......@@ -847,6 +912,7 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() {
}
{% if util.field_is_message_type(field) %}
{{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}();
}
// used by pybind11 only
......@@ -855,9 +921,11 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() {
}
{% else %}
void {{ util.class_name(cls) }}::set_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) {
assert(!freeze_);
return __SharedPtr__()->set_{{ util.field_name(field) }}(value);
}
{{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}();
}
{% endif %}
......@@ -867,9 +935,11 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() {
return __SharedPtr__()->clear_{{ util.field_name(field) }}();
}
{{ util.field_repeated_container_name(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}();
}
{{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}(::std::size_t index) {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}(index);
}
{% if util.field_is_message_type(field) %}
......@@ -898,6 +968,7 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() {
}
{% if util.field_is_message_type(field) %}
{{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}();
}
// used by pybind11 only
......@@ -906,9 +977,11 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() {
}
{% else %}
void {{ util.class_name(cls) }}::set_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value) {
assert(!freeze_);
return __SharedPtr__()->set_{{ util.field_name(field) }}(value);
}
{{ util.field_type_name_with_cfg_namespace(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}();
}
{% endif %}{# field message type #}
......@@ -919,11 +992,12 @@ void {{ util.class_name(cls) }}::clear_{{ util.field_name(field) }}() {
return __SharedPtr__()->clear_{{ util.field_name(field) }}();
}
const {{ util.field_map_container_name(field) }} & {{ util.class_name(cls) }}::{{ util.field_name(field) }}() {
return __SharedPtr__()->{{ util.field_name(field) }}();
}
// const {{ util.field_map_container_name(field) }} & {{ util.class_name(cls) }}::{{ util.field_name(field) }}() {
// return __SharedPtr__()->{{ util.field_name(field) }}();
// }
{{ util.field_map_container_name(field) }}* {{ util.class_name(cls) }}::mutable_{{ util.field_name(field) }}() {
assert(!freeze_);
return __SharedPtr__()->mutable_{{ util.field_name(field) }}();
}
......
......@@ -4,6 +4,7 @@
#include <memory>
#include <vector>
#include <map>
#include <assert.h>
#include <google/protobuf/message.h>
#include "oneflow/cfg/repeated_field.h"
#include "oneflow/cfg/map_field.h"
......@@ -93,7 +94,7 @@ class Const{{ util.field_repeated_container_name(field) }};
{% endif %}{# repeated #}
{# map begin #}
{% if util.field_has_map_label(field) and util.add_declared_map_field_type_name(field) %}
class {{ util.field_map_container_name(field) }};
class {{ util.field_map_container_name(field) }};
class Const{{ util.field_map_container_name(field) }};
{% endif %}{# map end #}
{% endfor %}{# field #}
......@@ -157,12 +158,14 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
const {{ util.field_repeated_container_name(field) }}& {{ util.field_name(field) }}() const;
const {{ util.field_type_name_with_cfg_namespace(field) }}& {{ util.field_name(field) }}(::std::size_t index) const;
void clear_{{ util.field_name(field) }}();
{{ util.field_repeated_container_name(field) }}* mutable_{{ util.field_name(field) }}();
{{ util.field_type_name_with_cfg_namespace(field) }}* mutable_{{ util.field_name(field) }}(::std::size_t index);
{% if util.field_is_message_type(field) %}
{{ util.field_type_name_with_cfg_namespace(field) }}* add_{{ util.field_name(field) }}();
{{ util.field_repeated_container_name(field) }}* mutable_{{ util.field_name(field) }}();
{{ util.field_type_name_with_cfg_namespace(field) }}* mutable_{{ util.field_name(field) }}(::std::size_t index);
{% else %}
void add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value);
{{ util.field_repeated_container_name(field) }}* mutable_{{ util.field_name(field) }}();
{{ util.field_type_name_with_cfg_namespace(field) }}* mutable_{{ util.field_name(field) }}(::std::size_t index);
{% endif %}{# field message type #}
protected:
::std::shared_ptr<{{ util.field_repeated_container_name(field) }}> {{ util.field_name(field) }}_;
......@@ -183,21 +186,21 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
::std::size_t {{ util.field_name(field) }}_size() const;
const {{ util.field_map_container_name(field) }}& {{ util.field_name(field) }}() const;
{{ util.field_map_container_name(field) }} * mutable_{{ util.field_name(field) }}();
const {{ util.field_map_value_type_name_with_cfg_namespace(field) }}& {{ util.field_name(field) }}({{ util.field_map_key_type_name(field) }} key) const;
void clear_{{ util.field_name(field) }}();
{% if util.field_is_message_type(field) %}
{% if util.field_map_value_type_is_message(field) %}
{{ util.field_map_container_name(field) }} * mutable_{{ util.field_name(field) }}();
{% else %}
void add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value);
{{ util.field_map_container_name(field) }} * mutable_{{ util.field_name(field) }}();
//void add_{{ util.field_name(field) }}(const {{ util.field_type_name_with_cfg_namespace(field) }}& value);
{% endif %}{# field message type #}
protected:
::std::shared_ptr<{{ util.field_map_container_name(field) }}> {{ util.field_name(field) }}_;
{% endif %}{# label #}
{% endfor %}{# field #}
{% for oneof in util.message_type_oneofs(cls) %}
public:
// oneof {{ util.oneof_name(oneof) }}
{{ util.oneof_enum_name(oneof) }} {{ util.oneof_name(oneof) }}_case() const;
......@@ -220,7 +223,7 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
} {{ util.oneof_name(oneof) }}_;
{{ util.oneof_enum_name(oneof) }} {{ util.oneof_name(oneof) }}_case_ = {{ util.oneof_name(oneof).upper() }}_NOT_SET;
{% endfor %}{# message_oneof #}
public:
int compare(const _{{ util.class_name(cls) }}_& other);
......@@ -238,9 +241,11 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
using PbMessage = ::google::protobuf::Message;
void ToProto(PbMessage* proto_{{ util.class_name(cls).lower() }}) const override;
::std::string DebugString() const;
bool freeze() const;
bool __Empty__() const;
int FieldNumber4FieldName(const ::std::string& field_name) const override;
......@@ -299,7 +304,7 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
bool has_{{ util.oneof_name(oneof) }}() const;
{% endfor %}{# oneofs #}
public:
::std::shared_ptr<Const{{ util.class_name(cls) }}> __SharedConst__() const;
int64_t __Id__() const;
public:
......@@ -311,15 +316,16 @@ class Const{{ util.class_name(cls) }} : public ::oneflow::cfg::Message {
const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& __SharedPtr__();
// use a protected member method to avoid someone change member variable(data_) by Const{{ util.class_name(cls) }}
void BuildFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }});
::std::shared_ptr<_{{ util.class_name(cls) }}_> data_;
bool freeze_ = false;
};
class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }} {
public:
{{ util.class_name(cls) }}(const ::std::shared_ptr<_{{ util.class_name(cls) }}_>& data);
{{ util.class_name(cls) }}(const {{ util.class_name(cls) }}& other);
// enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize
// enable nothrow for ::std::vector<{{ util.class_name(cls) }}> resize
{{ util.class_name(cls) }}({{ util.class_name(cls) }}&&) noexcept;
{{ util.class_name(cls) }}();
{{ util.class_name(cls) }}(const {{ util.module_package_namespace(module) }}::{{ util.class_name(cls) }}& proto_{{ util.class_name(cls).lower() }});
......@@ -327,13 +333,14 @@ class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }}
~{{ util.class_name(cls) }}() override;
void InitFromProto(const PbMessage& proto_{{ util.class_name(cls).lower() }}) override;
void* MutableFieldPtr4FieldNumber(int field_number) override;
bool operator==(const {{ util.class_name(cls) }}& other) const;
bool operator<(const {{ util.class_name(cls) }}& other) const;
void Clear();
void set_freeze();
void CopyFrom(const {{ util.class_name(cls) }}& other);
{{ util.class_name(cls) }}& operator=(const {{ util.class_name(cls) }}& other);
......@@ -382,7 +389,7 @@ class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }}
public:
void clear_{{ util.field_name(field) }}();
const {{ util.field_map_container_name(field) }} & {{ util.field_name(field) }}();
//const {{ util.field_map_container_name(field) }} & {{ util.field_name(field) }}();
{{ util.field_map_container_name(field) }}* mutable_{{ util.field_name(field) }}();
......@@ -405,7 +412,7 @@ class {{ util.class_name(cls) }} final : public Const{{ util.class_name(cls) }}
{# no duplicated class defined for each repeated field type #}
{% if util.field_has_repeated_label(field) and util.add_defined_repeated_field_type_name(field) %}
// inheritance is helpful for avoiding container iterator boilerplate
// inheritance is helpful for avoiding container iterator boilerplate
class Const{{ util.field_repeated_container_name(field) }} : public ::oneflow::cfg::_RepeatedField_<{{ util.field_type_name_with_cfg_namespace(field) }}> {
public:
Const{{ util.field_repeated_container_name(field) }}(const ::std::shared_ptr<::std::vector<{{ util.field_type_name_with_cfg_namespace(field) }}>>& data);
......@@ -440,7 +447,7 @@ class {{ util.field_repeated_container_name(field) }} final : public Const{{ uti
{# map begin #}
{% if util.field_has_map_label(field) and util.add_defined_map_field_type_name(field) %}
// inheritance is helpful for avoid container iterator boilerplate
// inheritance is helpful for avoid container iterator boilerplate
class Const{{ util.field_map_container_name(field) }} : public ::oneflow::cfg::_MapField_<{{ util.field_map_pair_type_name_with_cfg_namespace(field) }}> {
public:
Const{{ util.field_map_container_name(field) }}(const ::std::shared_ptr<::std::map<{{ util.field_map_pair_type_name_with_cfg_namespace(field) }}>>& data);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册