Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
cf253718
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
cf253718
编写于
12月 02, 2020
作者:
H
hsj0429
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
format files and fix bug
Former-commit-id: 18154dcd186f64d3d51f80203815a0a6d773d402
上级
efc5e233
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
179 addition
and
82 deletion
+179
-82
oneflow/core/common/cfg_repeated_field_test.cpp
oneflow/core/common/cfg_repeated_field_test.cpp
+63
-0
oneflow/python/test/test_repeated_field.py
oneflow/python/test/test_repeated_field.py
+50
-0
tools/cfg/include/oneflow/cfg/map_field.h
tools/cfg/include/oneflow/cfg/map_field.h
+10
-10
tools/cfg/include/oneflow/cfg/message.h
tools/cfg/include/oneflow/cfg/message.h
+6
-9
tools/cfg/include/oneflow/cfg/pybind_module_registry.h
tools/cfg/include/oneflow/cfg/pybind_module_registry.h
+20
-20
tools/cfg/include/oneflow/cfg/repeated_field.h
tools/cfg/include/oneflow/cfg/repeated_field.h
+24
-34
tools/cfg/include/oneflow/cfg/shared_pair_iterator.h
tools/cfg/include/oneflow/cfg/shared_pair_iterator.h
+2
-5
tools/cfg/pybind_module_registry.cpp
tools/cfg/pybind_module_registry.cpp
+4
-4
未找到文件。
oneflow/core/common/cfg_repeated_field_test.cpp
0 → 100644
浏览文件 @
cf253718
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/cfg_reflection_test.cfg.h"
#include "oneflow/core/common/cfg_reflection_test.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/cfg.h"
namespace
oneflow
{
namespace
test
{
TEST
(
CfgReflection
,
FieldDefined_repeated_field_not_set
)
{
{
ReflectionTestFoo
foo
;
cfg
::
ReflectionTestFoo
cfg_foo
(
foo
);
const
oneflow
::
cfg
::
_RepeatedField_
<
int
>&
repeated_int32
=
oneflow
::
cfg
::
_RepeatedField_
<
int
>
();
static_assert
(
std
::
is_same
<
decltype
(
repeated_int32
),
decltype
(
cfg_foo
.
repeated_int32
())
>::
value
,
"Not a oneflow::cfg::_RepeatedField_<int> type!"
);
}
{
ReflectionTestBar
bar
;
cfg
::
ReflectionTestBar
cfg_bar
(
bar
);
const
oneflow
::
cfg
::
_RepeatedField_
<
cfg
::
ReflectionTestFoo
>&
repeated_foo
=
oneflow
::
cfg
::
_RepeatedField_
<
cfg
::
ReflectionTestFoo
>
();
static_assert
(
std
::
is_same
<
decltype
(
repeated_foo
),
decltype
(
cfg_bar
.
repeated_foo
())
>::
value
,
"Not a oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo> type!"
);
}
}
TEST
(
CfgReflection
,
FieldDefined_repeated_field_set
)
{
{
ReflectionTestFoo
foo
;
foo
.
add_repeated_int32
(
0
);
cfg
::
ReflectionTestFoo
cfg_foo
(
foo
);
const
oneflow
::
cfg
::
_RepeatedField_
<
int
>&
repeated_int32
=
oneflow
::
cfg
::
_RepeatedField_
<
int
>
();
static_assert
(
std
::
is_same
<
decltype
(
repeated_int32
),
decltype
(
cfg_foo
.
repeated_int32
())
>::
value
,
"Not a oneflow::cfg::_RepeatedField_<int> type!"
);
}
{
ReflectionTestBar
bar
;
bar
.
add_repeated_foo
();
cfg
::
ReflectionTestBar
cfg_bar
(
bar
);
const
oneflow
::
cfg
::
_RepeatedField_
<
cfg
::
ReflectionTestFoo
>&
repeated_foo
=
oneflow
::
cfg
::
_RepeatedField_
<
cfg
::
ReflectionTestFoo
>
();
static_assert
(
std
::
is_same
<
decltype
(
repeated_foo
),
decltype
(
cfg_bar
.
repeated_foo
())
>::
value
,
"Not a oneflow::cfg::_RepeatedField_<cfg::ReflectionTestFoo> type!"
);
}
}
}
// namespace test
}
// namespace oneflow
oneflow/python/test/test_repeated_field.py
0 → 100644
浏览文件 @
cf253718
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
oneflow
as
flow
import
oneflow_api.oneflow.core.common.cfg_reflection_test
as
cfg
import
unittest
@
flow
.
unittest
.
skip_unless_1n1d
()
class
TestRepeatedField
(
flow
.
unittest
.
TestCase
):
def
test_repeated_field
(
test_case
):
foo
=
cfg
.
ReflectionTestFoo
()
bar
=
cfg
.
ReflectionTestBar
()
foo
.
add_repeated_int32
(
11
)
foo
.
add_repeated_int32
(
22
)
foo
.
add_repeated_string
(
"oneflow"
)
foo
.
add_repeated_string
(
"pytorch"
)
bar
.
mutable_repeated_foo
().
Add
().
CopyFrom
(
foo
)
test_case
.
assertEqual
(
str
(
type
(
foo
.
repeated_int32
())),
"<class 'oneflow_api.oneflow.core.common.cfg_reflection_test._ConstRepeatedField_<int32_t>'>"
,
)
test_case
.
assertEqual
(
str
(
type
(
foo
.
repeated_string
())),
"<class 'oneflow_api.oneflow.core.common.cfg_reflection_test._ConstRepeatedField_<::std::string>'>"
,
)
test_case
.
assertEqual
(
str
(
type
(
bar
.
repeated_foo
())),
"<class 'oneflow_api.oneflow.core.common.cfg_reflection_test._ConstRepeatedField_<::oneflow::cfg::ReflectionTestFoo>'>"
,
)
if
__name__
==
"__main__"
:
unittest
.
main
()
tools/cfg/include/oneflow/cfg/map_field.h
浏览文件 @
cf253718
...
...
@@ -26,14 +26,14 @@ class _MapField_ {
using
reverse_iterator
=
typename
std
::
map
<
Key
,
T
>::
reverse_iterator
;
using
const_reverse_iterator
=
typename
std
::
map
<
Key
,
T
>::
const_reverse_iterator
;
_MapField_
()
:
data_
(
std
::
make_shared
<
std
::
map
<
Key
,
T
>>
())
{}
_MapField_
(
const
std
::
shared_ptr
<
std
::
map
<
Key
,
T
>>&
data
)
:
data_
(
data
)
{}
_MapField_
(
const
_MapField_
&
other
)
:
data_
(
std
::
make_shared
<
std
::
map
<
Key
,
T
>>
())
{
_MapField_
()
:
data_
(
std
::
make_shared
<
std
::
map
<
Key
,
T
>>
())
{}
_MapField_
(
const
std
::
shared_ptr
<
std
::
map
<
Key
,
T
>>&
data
)
:
data_
(
data
)
{}
_MapField_
(
const
_MapField_
&
other
)
:
data_
(
std
::
make_shared
<
std
::
map
<
Key
,
T
>>
())
{
CopyFrom
(
other
);
}
_MapField_
(
_MapField_
&&
)
=
default
;
template
<
typename
InputIt
>
_MapField_
(
InputIt
begin
,
InputIt
end
)
:
data_
(
std
::
make_shared
<
std
::
map
<
Key
,
T
>>
(
begin
,
end
))
{}
_MapField_
(
InputIt
begin
,
InputIt
end
)
:
data_
(
std
::
make_shared
<
std
::
map
<
Key
,
T
>>
(
begin
,
end
))
{}
~
_MapField_
()
=
default
;
iterator
begin
()
noexcept
{
return
data_
->
begin
();
}
...
...
@@ -69,12 +69,12 @@ class _MapField_ {
std
::
pair
<
iterator
,
bool
>
insert
(
const
value_type
&
value
)
{
return
data_
->
insert
(
value
);
}
template
<
class
InputIt
>
void
insert
(
InputIt
first
,
InputIt
last
)
{
return
data_
->
insert
(
first
,
last
);
}
void
insert
(
InputIt
first
,
InputIt
last
)
{
return
data_
->
insert
(
first
,
last
);
}
void
Clear
()
{
data_
->
clear
();
}
void
CopyFrom
(
const
_MapField_
&
other
)
{
*
data_
=
*
other
.
data_
;
}
void
CopyFrom
(
const
_MapField_
&
other
)
{
*
data_
=
*
other
.
data_
;
}
_MapField_
&
operator
=
(
const
_MapField_
&
other
)
{
CopyFrom
(
other
);
...
...
@@ -89,7 +89,7 @@ class _MapField_ {
std
::
shared_ptr
<
std
::
map
<
Key
,
T
>>
data_
;
};
}
}
}
// namespace cfg
}
// namespace oneflow
#endif // ONEFLOW_CFG_MAP_FIELD_H_
tools/cfg/include/oneflow/cfg/message.h
浏览文件 @
cf253718
...
...
@@ -9,14 +9,12 @@ namespace google {
namespace
protobuf
{
class
Message
;
}
}
}
// namespace google
namespace
oneflow
{
namespace
cfg
{
class
Message
{
public:
Message
()
=
default
;
...
...
@@ -68,18 +66,17 @@ class Message {
}
virtual
int
FieldNumber4FieldName
(
const
std
::
string
&
field_name
)
const
=
0
;
virtual
bool
FieldDefined4FieldNumber
(
int
field_number
)
const
=
0
;
virtual
bool
FieldDefined4FieldNumber
(
int
field_number
)
const
=
0
;
virtual
const
std
::
set
<
std
::
type_index
>&
ValidTypeIndices4FieldNumber
(
int
field_number
)
const
=
0
;
virtual
const
void
*
FieldPtr4FieldNumber
(
int
field_number
)
const
=
0
;
virtual
void
*
MutableFieldPtr4FieldNumber
(
int
field_number
)
{
return
nullptr
;
}
using
PbMessage
=
::
google
::
protobuf
::
Message
;
virtual
void
ToProto
(
PbMessage
*
)
const
=
0
;
virtual
void
InitFromProto
(
const
PbMessage
&
)
{};
virtual
void
InitFromProto
(
const
PbMessage
&
){};
};
}
}
}
// namespace cfg
}
// namespace oneflow
#endif // ONEFLOW_CFG_CFG_MESSAGE_H_
tools/cfg/include/oneflow/cfg/pybind_module_registry.h
浏览文件 @
cf253718
...
...
@@ -40,26 +40,26 @@ class Pybind11ModuleRegistry {
const
std
::
function
<
void
(
pybind11
::
module
&
)
>&
BuildModule
);
};
}
// namespace cfg
}
// namespace cfg
}
// namespace oneflow
}
// namespace oneflow
#define ONEFLOW_CFG_PYBIND11_MODULE(module_path, m) \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, \
::oneflow::cfg::Pybind11Context* ctx); \
namespace { \
void OneflowCfgPythonModule(pybind11::module& m) { \
::oneflow::cfg::Pybind11Context ctx; \
OneflowCfgPythonModule##__LINE__(m, &ctx); \
} \
struct CfgRegistryInit { \
CfgRegistryInit() { \
::oneflow::cfg::Pybind11ModuleRegistry().Register(module_path, &OneflowCfgPythonModule); \
} \
}; \
CfgRegistryInit cfg_registry_init; \
} \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, \
::oneflow::cfg::Pybind11Context* ctx)
#define ONEFLOW_CFG_PYBIND11_MODULE(module_path, m) \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, ::oneflow::cfg::Pybind11Context* ctx); \
namespace { \
void OneflowCfgPythonModule(pybind11::module& m) { \
::oneflow::cfg::Pybind11Context ctx; \
OneflowCfgPythonModule##__LINE__(m, &ctx); \
} \
struct CfgRegistryInit { \
CfgRegistryInit() { \
::oneflow::cfg::Pybind11ModuleRegistry() \
.Register(module_path, &OneflowCfgPythonModule); \
} \
}; \
CfgRegistryInit cfg_registry_init; \
} \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m, ::oneflow::cfg::Pybind11Context* ctx)
#endif // CFG_PYBIND_REGISTRY_H_
#endif // CFG_PYBIND_REGISTRY_H_
tools/cfg/include/oneflow/cfg/repeated_field.h
浏览文件 @
cf253718
...
...
@@ -11,8 +11,7 @@ namespace cfg {
template
<
typename
T
>
class
_ConstRepeatedField_
{
public:
static_assert
(
std
::
is_nothrow_move_constructible
<
T
>::
value
,
""
);
public:
using
value_type
=
typename
std
::
vector
<
T
>::
value_type
;
using
size_type
=
typename
std
::
vector
<
T
>::
size_type
;
using
difference_type
=
typename
std
::
vector
<
T
>::
difference_type
;
...
...
@@ -21,10 +20,11 @@ class _ConstRepeatedField_ {
using
const_iterator
=
typename
std
::
vector
<
T
>::
const_iterator
;
using
const_reverse_iterator
=
typename
std
::
vector
<
T
>::
const_reverse_iterator
;
_ConstRepeatedField_
()
:
data_
(
std
::
make_shared
<
std
::
vector
<
T
>>
())
{}
_ConstRepeatedField_
(
const
std
::
shared_ptr
<
std
::
vector
<
T
>>&
data
)
:
data_
(
data
)
{}
_ConstRepeatedField_
()
:
data_
(
std
::
make_shared
<
std
::
vector
<
T
>>
())
{}
_ConstRepeatedField_
(
const
std
::
shared_ptr
<
std
::
vector
<
T
>>&
data
)
:
data_
(
data
)
{}
template
<
typename
InputIt
>
_ConstRepeatedField_
(
InputIt
begin
,
InputIt
end
)
:
data_
(
std
::
make_shared
<
std
::
vector
<
T
>>
(
begin
,
end
))
{}
_ConstRepeatedField_
(
InputIt
begin
,
InputIt
end
)
:
data_
(
std
::
make_shared
<
std
::
vector
<
T
>>
(
begin
,
end
))
{}
virtual
~
_ConstRepeatedField_
()
=
default
;
const_iterator
begin
()
const
noexcept
{
return
data_
->
begin
();
}
...
...
@@ -51,18 +51,20 @@ class _ConstRepeatedField_ {
return
std
::
make_shared
<
_ConstRepeatedField_
>
(
__SharedPtr__
());
}
bool
operator
==
(
const
_ConstRepeatedField_
&
other
)
const
{
return
*
__SharedPtr__
()
==
*
other
.
__SharedPtr__
();}
bool
operator
<
(
const
_ConstRepeatedField_
&
other
)
const
{
return
*
__SharedPtr__
()
<
*
other
.
__SharedPtr__
();}
bool
operator
==
(
const
_ConstRepeatedField_
&
other
)
const
{
return
*
__SharedPtr__
()
==
*
other
.
__SharedPtr__
();
}
bool
operator
<
(
const
_ConstRepeatedField_
&
other
)
const
{
return
*
__SharedPtr__
()
<
*
other
.
__SharedPtr__
();
}
protected:
protected:
std
::
shared_ptr
<
std
::
vector
<
T
>>
data_
;
};
template
<
typename
T
>
class
_RepeatedField_
:
public
_ConstRepeatedField_
<
T
>
{
class
_RepeatedField_
:
public
_ConstRepeatedField_
<
T
>
{
public:
static_assert
(
std
::
is_nothrow_move_constructible
<
T
>::
value
,
""
);
using
reference
=
typename
std
::
vector
<
T
>::
reference
;
using
pointer
=
typename
std
::
vector
<
T
>::
pointer
;
using
iterator
=
typename
std
::
vector
<
T
>::
iterator
;
...
...
@@ -81,18 +83,14 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
using
_ConstRepeatedField_
<
T
>::
data_
;
using
_ConstRepeatedField_
<
T
>::
__SharedPtr__
;
_RepeatedField_
()
:
_ConstRepeatedField_
<
T
>::
_ConstRepeatedField_
()
{}
_RepeatedField_
(
const
std
::
shared_ptr
<
std
::
vector
<
T
>>&
data
)
:
_ConstRepeatedField_
<
T
>
(
data
)
{}
_RepeatedField_
(
const
_RepeatedField_
&
other
)
{
CopyFrom
(
other
);
}
_RepeatedField_
(
const
_ConstRepeatedField_
<
T
>&
other
)
{
CopyFrom
(
other
);
}
_RepeatedField_
()
:
_ConstRepeatedField_
<
T
>::
_ConstRepeatedField_
()
{}
_RepeatedField_
(
const
std
::
shared_ptr
<
std
::
vector
<
T
>>&
data
)
:
_ConstRepeatedField_
<
T
>
(
data
)
{}
_RepeatedField_
(
const
_RepeatedField_
&
other
)
{
CopyFrom
(
other
);
}
_RepeatedField_
(
const
_ConstRepeatedField_
<
T
>&
other
)
{
CopyFrom
(
other
);
}
_RepeatedField_
(
_RepeatedField_
&&
)
=
default
;
template
<
typename
InputIt
>
_RepeatedField_
(
InputIt
begin
,
InputIt
end
)
:
_ConstRepeatedField_
<
T
>
(
begin
,
end
)
{}
_RepeatedField_
(
InputIt
begin
,
InputIt
end
)
:
_ConstRepeatedField_
<
T
>
(
begin
,
end
)
{}
~
_RepeatedField_
()
=
default
;
iterator
begin
()
noexcept
{
return
data_
->
begin
();
}
...
...
@@ -114,9 +112,7 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
return
Mutable
(
index
)
->
__SharedMutable__
();
}
std
::
shared_ptr
<
T
>
__SharedAdd__
()
{
return
Add
()
->
__SharedMutable__
();
}
std
::
shared_ptr
<
T
>
__SharedAdd__
()
{
return
Add
()
->
__SharedMutable__
();
}
pointer
Mutable
(
size_type
pos
)
{
return
&
data_
->
at
(
pos
);
}
...
...
@@ -137,22 +133,16 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
}
}
void
CopyFrom
(
const
_ConstRepeatedField_
<
T
>&
other
)
{
CopyFrom
(
other
);
}
void
CopyFrom
(
const
_ConstRepeatedField_
<
T
>&
other
)
{
CopyFrom
(
other
);
}
_RepeatedField_
&
operator
=
(
const
_RepeatedField_
&
other
)
{
CopyFrom
(
other
);
return
*
this
;
}
void
Set
(
size_type
pos
,
const
T
&
elem
)
{
data_
->
at
(
pos
)
=
elem
;
}
void
Set
(
size_type
pos
,
const
T
&
elem
)
{
data_
->
at
(
pos
)
=
elem
;
}
void
Add
(
const
T
&
elem
)
{
data_
->
push_back
(
std
::
move
(
elem
));
}
void
Add
(
const
T
&
elem
)
{
data_
->
push_back
(
std
::
move
(
elem
));
}
pointer
Add
()
{
data_
->
push_back
(
T
());
...
...
@@ -160,7 +150,7 @@ class _RepeatedField_: public _ConstRepeatedField_<T>{
}
};
}
}
}
// namespace cfg
}
// namespace oneflow
#endif // ONEFLOW_CFG_REPEATED_FIELD_H_
tools/cfg/include/oneflow/cfg/shared_pair_iterator.h
浏览文件 @
cf253718
...
...
@@ -19,14 +19,11 @@ class _SharedPairIterator_ {
using
pointer
=
std
::
unique_ptr
<
value_type
>
;
using
reference
=
value_type
;
_SharedPairIterator_
(
DataIter
data_iter
)
:
data_iter_
(
data_iter
)
{}
_SharedPairIterator_
(
DataIter
data_iter
)
:
data_iter_
(
data_iter
)
{}
// const methods
bool
operator
==
(
const
_SharedPairIterator_
&
rhs
)
const
{
return
data_iter_
==
rhs
.
data_iter_
;
}
bool
operator
==
(
const
_SharedPairIterator_
&
rhs
)
const
{
return
data_iter_
==
rhs
.
data_iter_
;
}
bool
operator
!=
(
const
_SharedPairIterator_
&
rhs
)
const
{
return
!
(
*
this
==
rhs
);
}
...
...
tools/cfg/pybind_module_registry.cpp
浏览文件 @
cf253718
...
...
@@ -15,7 +15,7 @@ SubModuleMap* GetSubModuleMap() {
return
&
sub_module_map
;
}
}
}
// namespace
std
::
set
<
std
::
type_index
>*
Pybind11Context
::
GetRegisteredTypeIndices
()
{
static
std
::
set
<
std
::
type_index
>
registered_type_indices
;
...
...
@@ -23,7 +23,7 @@ std::set<std::type_index>* Pybind11Context::GetRegisteredTypeIndices() {
}
void
Pybind11ModuleRegistry
::
Register
(
std
::
string
module_path
,
std
::
function
<
void
(
pybind11
::
module
&
)
>
BuildModule
)
{
std
::
function
<
void
(
pybind11
::
module
&
)
>
BuildModule
)
{
(
*
GetSubModuleMap
())[
module_path
].
emplace_back
(
BuildModule
);
}
...
...
@@ -51,6 +51,6 @@ void Pybind11ModuleRegistry::BuildSubModule(
}
}
}
// namespace cfg
}
// namespace cfg
}
// namespace oneflow
}
// namespace oneflow
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录