提交 cf253718 编写于 作者: H hsj0429

format files and fix bug


Former-commit-id: 18154dcd186f64d3d51f80203815a0a6d773d402
上级 efc5e233
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/cfg_reflection_test.cfg.h"
#include "oneflow/core/common/cfg_reflection_test.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/cfg.h"
namespace oneflow {
namespace test {
TEST(CfgReflection, FieldDefined_repeated_field_not_set) {
{
ReflectionTestFoo foo;
cfg::ReflectionTestFoo cfg_foo(foo);
const oneflow::cfg::_RepeatedField_<int>& repeated_int32 = oneflow::cfg::_RepeatedField_<int>();
static_assert(std::is_same<decltype(repeated_int32), decltype(cfg_foo.repeated_int32())>::value,
"Not a oneflow::cfg::_RepeatedField_<int> type!");
}
{
ReflectionTestBar bar;
cfg::ReflectionTestBar cfg_bar(bar);
const oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo>& repeated_foo =
oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo>();
static_assert(std::is_same<decltype(repeated_foo), decltype(cfg_bar.repeated_foo())>::value,
"Not a oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo> type!");
}
}
TEST(CfgReflection, FieldDefined_repeated_field_set) {
{
ReflectionTestFoo foo;
foo.add_repeated_int32(0);
cfg::ReflectionTestFoo cfg_foo(foo);
const oneflow::cfg::_RepeatedField_<int>& repeated_int32 = oneflow::cfg::_RepeatedField_<int>();
static_assert(std::is_same<decltype(repeated_int32), decltype(cfg_foo.repeated_int32())>::value,
"Not a oneflow::cfg::_RepeatedField_<int> type!");
}
{
ReflectionTestBar bar;
bar.add_repeated_foo();
cfg::ReflectionTestBar cfg_bar(bar);
const oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo>& repeated_foo =
oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo>();
static_assert(std::is_same<decltype(repeated_foo), decltype(cfg_bar.repeated_foo())>::value,
"Not a oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo> type!");
}
}
} // namespace test
} // namespace oneflow
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow_api.oneflow.core.common.cfg_reflection_test as cfg
import unittest
@flow.unittest.skip_unless_1n1d()
class TestRepeatedField(flow.unittest.TestCase):
def test_repeated_field(test_case):
foo = cfg.ReflectionTestFoo()
bar = cfg.ReflectionTestBar()
foo.add_repeated_int32(11)
foo.add_repeated_int32(22)
foo.add_repeated_string("oneflow")
foo.add_repeated_string("pytorch")
bar.mutable_repeated_foo().Add().CopyFrom(foo)
test_case.assertEqual(
str(type(foo.repeated_int32())),
"<class 'oneflow_api.oneflow.core.common.cfg_reflection_test._ConstRepeatedField_<int32_t>'>",
)
test_case.assertEqual(
str(type(foo.repeated_string())),
"<class 'oneflow_api.oneflow.core.common.cfg_reflection_test._ConstRepeatedField_<::std::string>'>",
)
test_case.assertEqual(
str(type(bar.repeated_foo())),
"<class 'oneflow_api.oneflow.core.common.cfg_reflection_test._ConstRepeatedField_<::oneflow::cfg::ReflectionTestFoo>'>",
)
if __name__ == "__main__":
unittest.main()
......@@ -26,14 +26,14 @@ class _MapField_ {
using reverse_iterator = typename std::map<Key, T>::reverse_iterator;
using const_reverse_iterator = typename std::map<Key, T>::const_reverse_iterator;
_MapField_(): data_(std::make_shared<std::map<Key, T>>()) {}
_MapField_(const std::shared_ptr<std::map<Key, T>>& data): data_(data) {}
_MapField_(const _MapField_& other): data_(std::make_shared<std::map<Key, T>>()) {
_MapField_() : data_(std::make_shared<std::map<Key, T>>()) {}
_MapField_(const std::shared_ptr<std::map<Key, T>>& data) : data_(data) {}
_MapField_(const _MapField_& other) : data_(std::make_shared<std::map<Key, T>>()) {
CopyFrom(other);
}
_MapField_(_MapField_&&) = default;
template<typename InputIt>
_MapField_(InputIt begin, InputIt end): data_(std::make_shared<std::map<Key, T>>(begin, end)) {}
_MapField_(InputIt begin, InputIt end) : data_(std::make_shared<std::map<Key, T>>(begin, end)) {}
~_MapField_() = default;
iterator begin() noexcept { return data_->begin(); }
......@@ -69,12 +69,12 @@ class _MapField_ {
std::pair<iterator, bool> insert(const value_type& value) { return data_->insert(value); }
template<class InputIt>
void insert(InputIt first, InputIt last) { return data_->insert(first, last); }
void insert(InputIt first, InputIt last) {
return data_->insert(first, last);
}
void Clear() { data_->clear(); }
void CopyFrom(const _MapField_& other) {
*data_ = *other.data_;
}
void CopyFrom(const _MapField_& other) { *data_ = *other.data_; }
_MapField_& operator=(const _MapField_& other) {
CopyFrom(other);
......@@ -89,7 +89,7 @@ class _MapField_ {
std::shared_ptr<std::map<Key, T>> data_;
};
}
}
} // namespace cfg
} // namespace oneflow
#endif // ONEFLOW_CFG_MAP_FIELD_H_
......@@ -9,14 +9,12 @@ namespace google {
namespace protobuf {
class Message;
}
}
} // namespace google
namespace oneflow {
namespace cfg {
class Message {
public:
Message() = default;
......@@ -68,18 +66,17 @@ class Message {
}
virtual int FieldNumber4FieldName(const std::string& field_name) const = 0;
virtual bool FieldDefined4FieldNumber(int field_number) const = 0;
virtual bool FieldDefined4FieldNumber(int field_number) const = 0;
virtual const std::set<std::type_index>& ValidTypeIndices4FieldNumber(int field_number) const = 0;
virtual const void* FieldPtr4FieldNumber(int field_number) const = 0;
virtual void* MutableFieldPtr4FieldNumber(int field_number) { return nullptr; }
using PbMessage = ::google::protobuf::Message;
virtual void ToProto(PbMessage*) const = 0;
virtual void InitFromProto(const PbMessage&) {};
virtual void InitFromProto(const PbMessage&){};
};
}
}
} // namespace cfg
} // namespace oneflow
#endif // ONEFLOW_CFG_CFG_MESSAGE_H_
......@@ -40,26 +40,26 @@ class Pybind11ModuleRegistry {
const std::function<void(pybind11::module&)>& BuildModule);
};
} // namespace cfg
} // namespace cfg
} // namespace oneflow
} // namespace oneflow
#define ONEFLOW_CFG_PYBIND11_MODULE(module_path, m) \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, \
::oneflow::cfg::Pybind11Context* ctx); \
namespace { \
void OneflowCfgPythonModule(pybind11::module& m) { \
::oneflow::cfg::Pybind11Context ctx; \
OneflowCfgPythonModule##__LINE__(m, &ctx); \
} \
struct CfgRegistryInit { \
CfgRegistryInit() { \
::oneflow::cfg::Pybind11ModuleRegistry().Register(module_path, &OneflowCfgPythonModule); \
} \
}; \
CfgRegistryInit cfg_registry_init; \
} \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, \
::oneflow::cfg::Pybind11Context* ctx)
#define ONEFLOW_CFG_PYBIND11_MODULE(module_path, m) \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, ::oneflow::cfg::Pybind11Context* ctx); \
namespace { \
void OneflowCfgPythonModule(pybind11::module& m) { \
::oneflow::cfg::Pybind11Context ctx; \
OneflowCfgPythonModule##__LINE__(m, &ctx); \
} \
struct CfgRegistryInit { \
CfgRegistryInit() { \
::oneflow::cfg::Pybind11ModuleRegistry() \
.Register(module_path, &OneflowCfgPythonModule); \
} \
}; \
CfgRegistryInit cfg_registry_init; \
} \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, ::oneflow::cfg::Pybind11Context* ctx)
#endif // CFG_PYBIND_REGISTRY_H_
#endif // CFG_PYBIND_REGISTRY_H_
......@@ -11,8 +11,7 @@ namespace cfg {
template<typename T>
class _ConstRepeatedField_ {
public:
static_assert(std::is_nothrow_move_constructible<T>::value, "");
public:
using value_type = typename std::vector<T>::value_type;
using size_type = typename std::vector<T>::size_type;
using difference_type = typename std::vector<T>::difference_type;
......@@ -21,10 +20,11 @@ class _ConstRepeatedField_ {
using const_iterator = typename std::vector<T>::const_iterator;
using const_reverse_iterator = typename std::vector<T>::const_reverse_iterator;
_ConstRepeatedField_(): data_(std::make_shared<std::vector<T>>()) {}
_ConstRepeatedField_(const std::shared_ptr<std::vector<T>>& data): data_(data) {}
_ConstRepeatedField_() : data_(std::make_shared<std::vector<T>>()) {}
_ConstRepeatedField_(const std::shared_ptr<std::vector<T>>& data) : data_(data) {}
template<typename InputIt>
_ConstRepeatedField_(InputIt begin, InputIt end): data_(std::make_shared<std::vector<T>>(begin, end)) {}
_ConstRepeatedField_(InputIt begin, InputIt end)
: data_(std::make_shared<std::vector<T>>(begin, end)) {}
virtual ~_ConstRepeatedField_() = default;
const_iterator begin() const noexcept { return data_->begin(); }
......@@ -51,18 +51,20 @@ class _ConstRepeatedField_ {
return std::make_shared<_ConstRepeatedField_>(__SharedPtr__());
}
bool operator==(const _ConstRepeatedField_& other) const {return *__SharedPtr__() == *other.__SharedPtr__();}
bool operator<(const _ConstRepeatedField_& other) const {return *__SharedPtr__() < *other.__SharedPtr__();}
bool operator==(const _ConstRepeatedField_& other) const {
return *__SharedPtr__() == *other.__SharedPtr__();
}
bool operator<(const _ConstRepeatedField_& other) const {
return *__SharedPtr__() < *other.__SharedPtr__();
}
protected:
protected:
std::shared_ptr<std::vector<T>> data_;
};
template<typename T>
class _RepeatedField_: public _ConstRepeatedField_<T>{
class _RepeatedField_ : public _ConstRepeatedField_<T> {
public:
static_assert(std::is_nothrow_move_constructible<T>::value, "");
using reference = typename std::vector<T>::reference;
using pointer = typename std::vector<T>::pointer;
using iterator = typename std::vector<T>::iterator;
......@@ -81,18 +83,14 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
using _ConstRepeatedField_<T>::data_;
using _ConstRepeatedField_<T>::__SharedPtr__;
_RepeatedField_(): _ConstRepeatedField_<T>::_ConstRepeatedField_() {}
_RepeatedField_(const std::shared_ptr<std::vector<T>>& data): _ConstRepeatedField_<T>(data) {}
_RepeatedField_(const _RepeatedField_& other) {
CopyFrom(other);
}
_RepeatedField_(const _ConstRepeatedField_<T>& other) {
CopyFrom(other);
}
_RepeatedField_() : _ConstRepeatedField_<T>::_ConstRepeatedField_() {}
_RepeatedField_(const std::shared_ptr<std::vector<T>>& data) : _ConstRepeatedField_<T>(data) {}
_RepeatedField_(const _RepeatedField_& other) { CopyFrom(other); }
_RepeatedField_(const _ConstRepeatedField_<T>& other) { CopyFrom(other); }
_RepeatedField_(_RepeatedField_&&) = default;
template<typename InputIt>
_RepeatedField_(InputIt begin, InputIt end): _ConstRepeatedField_<T>(begin, end) {}
_RepeatedField_(InputIt begin, InputIt end) : _ConstRepeatedField_<T>(begin, end) {}
~_RepeatedField_() = default;
iterator begin() noexcept { return data_->begin(); }
......@@ -114,9 +112,7 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
return Mutable(index)->__SharedMutable__();
}
std::shared_ptr<T> __SharedAdd__() {
return Add()->__SharedMutable__();
}
std::shared_ptr<T> __SharedAdd__() { return Add()->__SharedMutable__(); }
pointer Mutable(size_type pos) { return &data_->at(pos); }
......@@ -137,22 +133,16 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
}
}
void CopyFrom(const _ConstRepeatedField_<T>& other) {
CopyFrom(other);
}
void CopyFrom(const _ConstRepeatedField_<T>& other) { CopyFrom(other); }
_RepeatedField_& operator=(const _RepeatedField_& other) {
CopyFrom(other);
return *this;
}
void Set(size_type pos, const T& elem) {
data_->at(pos) = elem;
}
void Set(size_type pos, const T& elem) { data_->at(pos) = elem; }
void Add(const T& elem) {
data_->push_back(std::move(elem));
}
void Add(const T& elem) { data_->push_back(std::move(elem)); }
pointer Add() {
data_->push_back(T());
......@@ -160,7 +150,7 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
}
};
}
}
} // namespace cfg
} // namespace oneflow
#endif // ONEFLOW_CFG_REPEATED_FIELD_H_
......@@ -19,14 +19,11 @@ class _SharedPairIterator_ {
using pointer = std::unique_ptr<value_type>;
using reference = value_type;
_SharedPairIterator_(DataIter data_iter)
: data_iter_(data_iter) {}
_SharedPairIterator_(DataIter data_iter) : data_iter_(data_iter) {}
// const methods
bool operator==(const _SharedPairIterator_& rhs) const {
return data_iter_ == rhs.data_iter_;
}
bool operator==(const _SharedPairIterator_& rhs) const { return data_iter_ == rhs.data_iter_; }
bool operator!=(const _SharedPairIterator_& rhs) const { return !(*this == rhs); }
......
......@@ -15,7 +15,7 @@ SubModuleMap* GetSubModuleMap() {
return &sub_module_map;
}
}
} // namespace
std::set<std::type_index>* Pybind11Context::GetRegisteredTypeIndices() {
static std::set<std::type_index> registered_type_indices;
......@@ -23,7 +23,7 @@ std::set<std::type_index>* Pybind11Context::GetRegisteredTypeIndices() {
}
void Pybind11ModuleRegistry::Register(std::string module_path,
std::function<void(pybind11::module&)> BuildModule) {
std::function<void(pybind11::module&)> BuildModule) {
(*GetSubModuleMap())[module_path].emplace_back(BuildModule);
}
......@@ -51,6 +51,6 @@ void Pybind11ModuleRegistry::BuildSubModule(
}
}
} // namespace cfg
} // namespace cfg
} // namespace oneflow
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册