Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2be6ceda
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2be6ceda
编写于
1月 14, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/utils): add serval utils
GitOrigin-RevId: f401663ae3641d8a6467cf4d10cba17a1d3f4553
上级
e7c2ed11
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
1291 addition
and
73 deletion
+1291
-73
imperative/src/impl/profiler/states.h
imperative/src/impl/profiler/states.h
+1
-1
imperative/src/include/megbrain/imperative/profiler.h
imperative/src/include/megbrain/imperative/profiler.h
+4
-68
imperative/src/include/megbrain/imperative/utils/allocator.h
imperative/src/include/megbrain/imperative/utils/allocator.h
+71
-0
imperative/src/include/megbrain/imperative/utils/any.h
imperative/src/include/megbrain/imperative/utils/any.h
+70
-0
imperative/src/include/megbrain/imperative/utils/box.h
imperative/src/include/megbrain/imperative/utils/box.h
+96
-0
imperative/src/include/megbrain/imperative/utils/helper.h
imperative/src/include/megbrain/imperative/utils/helper.h
+40
-0
imperative/src/include/megbrain/imperative/utils/intrusive_list.h
...ve/src/include/megbrain/imperative/utils/intrusive_list.h
+245
-0
imperative/src/include/megbrain/imperative/utils/local_ptr.h
imperative/src/include/megbrain/imperative/utils/local_ptr.h
+285
-0
imperative/src/include/megbrain/imperative/utils/map.h
imperative/src/include/megbrain/imperative/utils/map.h
+157
-0
imperative/src/include/megbrain/imperative/utils/mempool.h
imperative/src/include/megbrain/imperative/utils/mempool.h
+70
-0
imperative/src/include/megbrain/imperative/utils/span.h
imperative/src/include/megbrain/imperative/utils/span.h
+69
-0
imperative/src/include/megbrain/imperative/utils/to_string.h
imperative/src/include/megbrain/imperative/utils/to_string.h
+49
-0
imperative/src/include/megbrain/imperative/utils/value_shape.h
...ative/src/include/megbrain/imperative/utils/value_shape.h
+104
-0
imperative/src/include/megbrain/imperative/utils/visit.h
imperative/src/include/megbrain/imperative/utils/visit.h
+26
-0
imperative/src/test/profiler.cpp
imperative/src/test/profiler.cpp
+4
-4
未找到文件。
imperative/src/impl/profiler/states.h
浏览文件 @
2be6ceda
...
@@ -160,7 +160,7 @@ private:
...
@@ -160,7 +160,7 @@ private:
template
<
typename
TItem
>
template
<
typename
TItem
>
void
register_converter
()
{
void
register_converter
()
{
m_table
[
typeid
(
TItem
)]
=
[](
const
any_t
&
input
)
{
m_table
[
typeid
(
TItem
)]
=
[](
const
any_t
&
input
)
{
return
variant_t
(
*
input
.
as
<
TItem
>
());
return
variant_t
(
input
.
cast
<
TItem
>
());
};
};
}
}
...
...
imperative/src/include/megbrain/imperative/profiler.h
浏览文件 @
2be6ceda
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
#pragma once
#pragma once
#include <any>
#include <bitset>
#include <bitset>
#include <chrono>
#include <chrono>
#include <deque>
#include <deque>
...
@@ -28,6 +27,7 @@
...
@@ -28,6 +27,7 @@
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/utils/any.h"
namespace
mgb
{
namespace
mgb
{
namespace
imperative
{
namespace
imperative
{
...
@@ -51,48 +51,6 @@ public:
...
@@ -51,48 +51,6 @@ public:
static
std
::
shared_ptr
<
CompNode
::
Event
>
record_device
(
CompNode
device
);
static
std
::
shared_ptr
<
CompNode
::
Event
>
record_device
(
CompNode
device
);
};
};
class
AnyPtr
{
public:
struct
Deleter
{
void
*
object
;
void
(
*
method
)(
void
*
,
void
*
);
void
operator
()(
void
*
ptr
)
{
method
(
object
,
ptr
);
}
};
private:
using
holder_t
=
std
::
unique_ptr
<
void
,
Deleter
>
;
const
std
::
type_info
*
m_type
=
nullptr
;
holder_t
m_holder
=
nullptr
;
public:
AnyPtr
()
=
default
;
template
<
typename
T
,
typename
=
std
::
enable_if_t
<!
std
::
is_same_v
<
std
::
decay_t
<
T
>,
AnyPtr
>>>
explicit
AnyPtr
(
T
*
value
,
Deleter
deleter
)
{
m_type
=
&
typeid
(
T
);
m_holder
=
{
value
,
deleter
};
}
template
<
typename
T
>
T
*
as
()
{
mgb_assert
(
is_exactly
<
T
>
(),
"type mismatch"
);
return
reinterpret_cast
<
T
*>
(
m_holder
.
get
());
}
template
<
typename
T
>
const
T
*
as
()
const
{
mgb_assert
(
is_exactly
<
T
>
(),
"type mismatch"
);
return
reinterpret_cast
<
const
T
*>
(
m_holder
.
get
());
}
template
<
typename
T
>
bool
is_exactly
()
const
{
return
std
::
type_index
{
typeid
(
T
)}
==
std
::
type_index
{
*
m_type
};
}
const
std
::
type_info
&
type
()
const
{
return
*
m_type
;
}
bool
operator
==
(
std
::
nullptr_t
nptr
)
const
{
return
m_holder
==
nullptr
;
}
operator
bool
()
const
{
return
m_holder
!=
nullptr
;
}
};
class
Profiler
{
class
Profiler
{
public:
public:
struct
Record
{
struct
Record
{
...
@@ -128,7 +86,6 @@ private:
...
@@ -128,7 +86,6 @@ private:
std
::
thread
::
id
m_thread_id
;
std
::
thread
::
id
m_thread_id
;
std
::
vector
<
Record
>
m_records
;
std
::
vector
<
Record
>
m_records
;
std
::
atomic
<
Status
>
m_status
=
Running
;
std
::
atomic
<
Status
>
m_status
=
Running
;
std
::
unordered_map
<
std
::
type_index
,
AnyPtr
>
m_mem_pools
;
static
std
::
vector
<
entry_t
>
sm_records
;
static
std
::
vector
<
entry_t
>
sm_records
;
static
options_t
sm_profile_options
;
static
options_t
sm_profile_options
;
...
@@ -161,42 +118,21 @@ public:
...
@@ -161,42 +118,21 @@ public:
return
*
tm_profiler
;
return
*
tm_profiler
;
}
}
template
<
typename
T
>
static
MemPool
<
T
>&
get_mem_pool
()
{
thread_local
MemPool
<
T
>*
t_pool
=
nullptr
;
if
(
t_pool
==
nullptr
)
{
auto
&
pool
=
get_instance
().
m_mem_pools
[
typeid
(
MemPool
<
T
>
)];
if
(
pool
==
nullptr
)
{
pool
=
AnyPtr
(
new
MemPool
<
T
>
(),
{
nullptr
,
[](
void
*
,
void
*
ptr
)
{
delete
reinterpret_cast
<
MemPool
<
T
>*>
(
ptr
);
}});
}
t_pool
=
pool
.
as
<
MemPool
<
T
>>
();
}
return
*
t_pool
;
}
static
uint64_t
next_id
()
{
return
sm_last_id
++
;
}
static
uint64_t
next_id
()
{
return
sm_last_id
++
;
}
template
<
typename
T
,
typename
...
TArgs
>
template
<
typename
T
,
typename
...
TArgs
>
static
uint64_t
record
(
TArgs
&&
...
args
)
{
static
uint64_t
record
(
TArgs
&&
...
args
)
{
auto
&
profiler
=
get_instance
();
auto
&
profiler
=
get_instance
();
auto
&
mem_pool
=
get_mem_pool
<
T
>
();
//
auto& mem_pool = get_mem_pool<T>();
if
constexpr
(
sm_debug
)
{
if
constexpr
(
sm_debug
)
{
Status
expected
=
Running
;
Status
expected
=
Running
;
mgb_assert
(
profiler
.
m_status
.
compare_exchange_strong
(
expected
,
Recording
));
mgb_assert
(
profiler
.
m_status
.
compare_exchange_strong
(
expected
,
Recording
));
}
}
uint64_t
id
=
next_id
();
uint64_t
id
=
next_id
();
profiler
::
Time
time
=
sm_timer
.
record_host
();
profiler
::
Time
time
=
sm_timer
.
record_host
();
auto
deleter
=
[](
void
*
obj
,
void
*
ptr
)
{
reinterpret_cast
<
MemPool
<
T
>*>
(
obj
)
->
free
(
reinterpret_cast
<
T
*>
(
ptr
));
};
profiler
.
m_records
.
emplace_back
(
profiler
.
m_records
.
emplace_back
(
id
,
profiler
.
m_thread_id
,
time
,
id
,
profiler
.
m_thread_id
,
time
,
AnyPtr
{
mem_pool
.
alloc
(
T
{
std
::
forward
<
TArgs
>
(
args
)...}),
AnyPtr
::
make
<
T
>
(
T
{
std
::
forward
<
TArgs
&&>
(
args
)...}));
{
&
mem_pool
,
deleter
}});
if
constexpr
(
sm_debug
)
{
if
constexpr
(
sm_debug
)
{
Status
expected
=
Recording
;
Status
expected
=
Recording
;
mgb_assert
(
profiler
.
m_status
.
compare_exchange_strong
(
expected
,
Running
));
mgb_assert
(
profiler
.
m_status
.
compare_exchange_strong
(
expected
,
Running
));
...
@@ -241,7 +177,7 @@ public:
...
@@ -241,7 +177,7 @@ public:
bundle
.
options
=
get_options
();
bundle
.
options
=
get_options
();
bundle
.
start_at
=
sm_start_at
;
bundle
.
start_at
=
sm_start_at
;
bundle
.
thread_dict
=
get_thread_dict
();
bundle
.
thread_dict
=
get_thread_dict
();
return
std
::
move
(
bundle
)
;
return
bundle
;
}
}
static
option_t
get_option
(
std
::
string
key
,
option_t
default_val
)
{
static
option_t
get_option
(
std
::
string
key
,
option_t
default_val
)
{
...
...
imperative/src/include/megbrain/imperative/utils/allocator.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/allocator.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <typeindex>
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
namespace
mgb
::
imperative
{
template
<
typename
T
>
class
Allocator
{
public:
using
pointer
=
T
*
;
using
const_pointer
=
const
T
*
;
using
void_pointer
=
void
*
;
using
const_void_pointer
=
const
void
*
;
using
value_type
=
T
;
using
size_type
=
std
::
size_t
;
using
diffenence_type
=
std
::
ptrdiff_t
;
using
pool_type
=
MemPoolStorage
;
private:
pool_type
*
m_pool
=
nullptr
;
public:
Allocator
(
pool_type
*
pool
)
:
m_pool
(
pool
)
{}
T
*
allocate
(
size_type
n
)
{
mgb_assert
(
n
==
1
);
return
m_pool
->
alloc
(
sizeof
(
T
));
}
void
deallocate
(
pointer
*
p
,
size_type
n
)
{
mgb_assert
(
n
==
1
);
m_pool
->
free
(
p
);
}
bool
operator
==
(
const
Allocator
&
rhs
)
const
{
return
m_pool
==
rhs
.
m_pool
;
}
bool
operator
!=
(
const
Allocator
&
rhs
)
const
{
return
m_pool
!=
rhs
.
m_pool
;
}
};
template
<
typename
T
>
class
ThreadLocalAllocatorAdapter
{
public:
using
value_type
=
T
;
using
size_type
=
std
::
size_t
;
using
pointer
=
T
*
;
public:
T
*
allocate
(
size_type
n
)
{
mgb_assert
(
false
);
}
void
deallocate
(
pointer
*
p
,
size_type
n
)
{
mgb_assert
(
false
);
}
bool
operator
==
(
const
ThreadLocalAllocatorAdapter
&
rhs
)
const
{
return
true
;
}
bool
operator
!=
(
const
ThreadLocalAllocatorAdapter
&
rhs
)
const
{
return
false
;
}
};
}
// namespace mgb::imperative
\ No newline at end of file
imperative/src/include/megbrain/imperative/utils/any.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/any.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <typeindex>
#include "megbrain/imperative/utils/local_ptr.h"
namespace
mgb
::
imperative
{
class
AnyMixinBase
{
private:
const
std
::
type_info
*
m_type
=
nullptr
;
public:
AnyMixinBase
()
=
default
;
const
std
::
type_info
&
type
()
const
{
return
*
m_type
;
}
friend
class
AnyPtr
;
};
template
<
typename
T
>
class
AnyMixin
:
public
AnyMixinBase
,
public
T
{
public:
AnyMixin
(
T
&&
val
)
:
T
(
std
::
move
(
val
))
{}
};
class
AnyPtr
{
public:
using
storage_t
=
LocalPtr
<
AnyMixinBase
>
;
private:
storage_t
m_storage
;
public:
const
std
::
type_info
&
type
()
const
{
return
m_storage
->
type
();
}
template
<
typename
T
>
const
T
&
cast
()
const
{
mgb_assert
(
is_exactly
<
T
>
(),
"type mismatch"
);
return
*
static_cast
<
const
AnyMixin
<
T
>*>
(
m_storage
.
get
());
}
template
<
typename
T
>
bool
is_exactly
()
const
{
return
std
::
type_index
{
typeid
(
T
)}
==
std
::
type_index
{
type
()};
}
bool
operator
==
(
std
::
nullptr_t
nptr
)
const
{
return
m_storage
==
nullptr
;
}
bool
operator
!=
(
std
::
nullptr_t
nptr
)
const
{
return
m_storage
!=
nullptr
;
}
operator
bool
()
const
{
return
m_storage
!=
nullptr
;
}
template
<
typename
T
,
typename
...
TArgs
>
static
AnyPtr
make
(
TArgs
&&
...
args
)
{
AnyPtr
ret
;
ret
.
m_storage
=
LocalPtr
<
AnyMixinBase
>::
make
<
AnyMixin
<
T
>>
(
std
::
forward
<
TArgs
&&>
(
args
)...);
ret
.
m_storage
->
m_type
=
&
typeid
(
T
);
return
ret
;
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/box.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/visit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <chrono>
#include <future>
#include <vector>
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/small_vector.h"
namespace
mgb
::
imperative
{
class
BoxBase
:
public
NonCopyableObj
{
public:
virtual
void
reset
()
=
0
;
virtual
void
set_exception
(
std
::
exception_ptr
exc
)
=
0
;
virtual
bool
try_set_exception
(
std
::
exception_ptr
exc
)
=
0
;
};
/**
* \brief An reusable promise
*
* \tparam T type of value
*/
template
<
typename
T
>
class
Box
final
:
public
BoxBase
{
private:
std
::
promise
<
T
>
m_promise
;
std
::
shared_future
<
T
>
m_future
;
std
::
mutex
m_mutex
;
bool
m_value_set
;
bool
m_exception_set
;
public:
Box
()
{
reset
();
}
const
T
&
get_value
()
{
return
m_future
.
get
();
}
T
take_value
()
{
T
value
=
m_future
.
get
();
reset
();
return
value
;
}
void
set_value
(
T
value
)
{
MGB_LOCK_GUARD
(
m_mutex
);
m_promise
.
set_value
(
std
::
move
(
value
));
m_value_set
=
true
;
}
bool
try_set_value
(
T
value
)
{
MGB_LOCK_GUARD
(
m_mutex
);
if
(
m_exception_set
)
{
return
false
;
}
m_promise
.
set_value
(
std
::
move
(
value
));
m_value_set
=
true
;
return
true
;
}
void
set_exception
(
std
::
exception_ptr
exc
)
override
{
MGB_LOCK_GUARD
(
m_mutex
);
m_promise
.
set_exception
(
exc
);
m_exception_set
=
true
;
}
bool
try_set_exception
(
std
::
exception_ptr
exc
)
override
{
MGB_LOCK_GUARD
(
m_mutex
);
if
(
m_value_set
)
{
return
false
;
}
m_promise
.
set_exception
(
exc
);
m_exception_set
=
true
;
return
true
;
}
void
reset
()
override
{
MGB_LOCK_GUARD
(
m_mutex
);
m_promise
=
{};
m_future
=
m_promise
.
get_future
();
m_value_set
=
false
;
m_exception_set
=
false
;
}
/**
* \brief make an empty box
*
* \return std::shared_ptr<Box>
*/
static
std
::
shared_ptr
<
Box
>
make
()
{
return
std
::
make_shared
<
Box
>
();
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/helper.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/span.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <iomanip>
#include <memory>
#include <sstream>
namespace
mgb
{
namespace
imperative
{
template
<
typename
T
>
class
CleanupGuard
{
private:
T
m_callback
;
public:
explicit
CleanupGuard
(
T
cb
)
:
m_callback
{
std
::
move
(
cb
)}
{}
~
CleanupGuard
()
{
m_callback
();
}
};
inline
std
::
string
quoted
(
std
::
string
str
)
{
std
::
stringstream
ss
;
ss
<<
std
::
quoted
(
str
);
return
ss
.
str
();
}
}
// namespace imperative
}
// namespace mgb
\ No newline at end of file
imperative/src/include/megbrain/imperative/utils/intrusive_list.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/intrusive_list.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/utils/metahelper.h"
namespace
mgb
::
imperative
::
utils
::
intrusive_list
{
// copy policy
struct
after_t
{};
struct
before_t
{};
struct
disable_t
{};
template
<
typename
T
>
struct
Tail
;
// invariant: next->prev == this
template
<
typename
T
>
struct
Head
{
Tail
<
T
>*
next
;
Head
(
Tail
<
T
>*
node
=
nullptr
)
:
next
(
node
)
{}
Head
(
const
Head
<
T
>&
)
=
delete
;
Head
<
T
>&
operator
=
(
const
Head
<
T
>&
)
=
delete
;
Head
(
Head
<
T
>&&
rhs
)
:
next
(
rhs
.
next
)
{
rhs
.
next
=
nullptr
;
if
(
next
)
{
next
->
prev
=
this
;
}
}
Head
<
T
>&
operator
=
(
Head
<
T
>&&
rhs
)
{
mgb_assert
(
!
next
);
next
=
rhs
.
next
;
rhs
.
next
=
nullptr
;
if
(
next
)
{
next
->
prev
=
this
;
}
return
*
this
;
}
~
Head
()
{
if
(
next
)
{
next
->
prev
=
nullptr
;
}
}
};
// invariant: prev->next == this
template
<
typename
T
>
struct
Tail
{
Head
<
T
>*
prev
;
Tail
(
Head
<
T
>*
node
=
nullptr
)
:
prev
(
node
)
{}
Tail
(
const
Tail
<
T
>&
)
=
delete
;
Tail
<
T
>&
operator
=
(
const
Tail
<
T
>&
)
=
delete
;
Tail
(
Tail
<
T
>&&
rhs
)
:
prev
(
rhs
.
prev
)
{
rhs
.
prev
=
nullptr
;
if
(
prev
)
{
prev
->
next
=
this
;
}
}
Tail
<
T
>&
operator
=
(
Tail
<
T
>&&
rhs
)
{
mgb_assert
(
!
prev
);
prev
=
rhs
.
prev
;
rhs
.
prev
=
nullptr
;
if
(
prev
)
{
prev
->
next
=
this
;
}
return
*
this
;
}
~
Tail
()
{
if
(
prev
)
{
prev
->
next
=
nullptr
;
}
}
};
template
<
typename
T
,
typename
policy
>
struct
Node
;
template
<
typename
T
>
class
Iterator
{
T
*
ptr
;
void
inc
()
{
ptr
=
static_cast
<
T
*>
(
ptr
->
Head
<
T
>::
next
);
}
void
dec
()
{
ptr
=
static_cast
<
T
*>
(
ptr
->
Head
<
T
>::
prev
);
}
public:
Iterator
(
Head
<
T
>&
head
)
:
ptr
(
static_cast
<
T
*>
(
head
.
next
))
{}
Iterator
(
Tail
<
T
>&
tail
)
:
ptr
(
static_cast
<
T
*>
(
tail
.
prev
))
{}
template
<
typename
policy
>
Iterator
(
Node
<
T
,
policy
>&
node
)
:
ptr
(
static_cast
<
T
*>
(
&
node
))
{}
T
&
operator
*
()
{
return
*
static_cast
<
T
*>
(
ptr
);
}
T
*
operator
->
()
{
return
static_cast
<
T
*>
(
ptr
);
}
operator
bool
()
{
return
ptr
;
}
bool
operator
==
(
const
Iterator
<
T
>&
rhs
)
{
return
ptr
==
rhs
.
ptr
;
}
Iterator
&
operator
++
()
{
inc
();
return
*
this
;
}
Iterator
&
operator
--
()
{
dec
();
return
*
this
;
}
Iterator
operator
++
(
int
)
{
auto
ret
=
*
this
;
inc
();
return
ret
;
}
Iterator
operator
--
(
int
)
{
auto
ret
=
*
this
;
dec
();
return
ret
;
}
};
// Node in a doubly linked list. Unlike std::list, nodes are not owned by a container.
// Instead, nodes may join or leave a list freely.
// NOTE: Derived classes have to explicitly declare copy / assignment as default,
// otherwise the compiler generated version would use the const T& signature,
// which is deleted.
template
<
typename
T
=
void
,
typename
policy
=
disable_t
>
struct
Node
:
Tail
<
std
::
conditional_t
<
std
::
is_same_v
<
T
,
void
>
,
Node
<
T
,
policy
>
,
T
>>
,
Head
<
std
::
conditional_t
<
std
::
is_same_v
<
T
,
void
>
,
Node
<
T
,
policy
>
,
T
>>
{
private:
using
this_t
=
Node
<
T
,
policy
>
;
using
U
=
std
::
conditional_t
<
std
::
is_same_v
<
T
,
void
>
,
this_t
,
T
>
;
public:
using
head_t
=
Head
<
U
>
;
using
tail_t
=
Tail
<
U
>
;
using
head_t
::
next
;
using
tail_t
::
prev
;
Node
()
=
default
;
Node
(
const
this_t
&
)
=
delete
;
this_t
&
operator
=
(
const
this_t
&
)
=
delete
;
//! constructed node is inserted after the input node
Node
(
after_t
,
head_t
&
node
)
:
tail_t
(
&
node
),
head_t
(
node
.
next
)
{
node
.
next
=
this
;
if
(
next
)
{
next
->
prev
=
this
;
}
}
//! constructed node is inserted before the input node
Node
(
before_t
,
tail_t
&
node
)
:
head_t
(
&
node
),
tail_t
(
node
.
prev
)
{
node
.
prev
=
this
;
if
(
prev
)
{
prev
->
next
=
this
;
}
}
Node
(
this_t
&&
rhs
)
:
tail_t
(
rhs
.
prev
),
head_t
(
rhs
.
next
)
{
rhs
.
prev
=
nullptr
;
rhs
.
next
=
nullptr
;
if
(
prev
)
{
prev
->
next
=
this
;
}
if
(
next
)
{
next
->
prev
=
this
;
}
}
Node
&
operator
=
(
this_t
&&
rhs
)
{
unlink
();
prev
=
rhs
.
prev
;
next
=
rhs
.
next
;
rhs
.
prev
=
nullptr
;
rhs
.
next
=
nullptr
;
if
(
prev
)
{
prev
->
next
=
this
;
}
if
(
next
)
{
next
->
prev
=
this
;
}
return
*
this
;
}
template
<
typename
p
=
policy
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
p
,
before_t
>
||
std
::
is_same_v
<
p
,
after_t
>
,
void
>>
Node
(
this_t
&
rhs
)
:
Node
(
policy
{},
rhs
)
{}
template
<
typename
p
=
policy
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
p
,
before_t
>
||
std
::
is_same_v
<
p
,
after_t
>
,
void
>>
this_t
&
operator
=
(
this_t
&
rhs
)
{
insert
(
policy
{},
rhs
);
return
*
this
;
}
void
unlink
()
{
if
(
prev
)
{
prev
->
next
=
next
;
}
if
(
next
)
{
next
->
prev
=
prev
;
}
prev
=
nullptr
;
next
=
nullptr
;
}
//! this node is unlinked from its list and inserted after the input node
void
insert
(
after_t
,
head_t
&
node
)
{
unlink
();
prev
=
&
node
;
next
=
node
.
next
;
node
.
next
=
this
;
if
(
next
)
{
next
->
prev
=
this
;
}
}
//! this node is unlinked from its list and inserted before the input node
void
insert
(
before_t
,
tail_t
&
node
)
{
unlink
();
next
=
&
node
;
prev
=
node
.
prev
;
node
.
prev
=
this
;
if
(
prev
)
{
prev
->
next
=
this
;
}
}
void
insert_before
(
tail_t
&
node
)
{
insert
(
before_t
{},
node
);
}
void
insert_after
(
head_t
&
node
)
{
insert
(
after_t
{},
node
);
}
~
Node
()
{
unlink
();
}
};
}
// namespace mgb::imperative::utils::intrusive_list
imperative/src/include/megbrain/imperative/utils/local_ptr.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/local_ptr.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <optional>
#include "megbrain/imperative/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
namespace
mgb
::
imperative
{
template
<
typename
T
>
class
LocalPtrStorage
:
public
NonCopyableObj
{
private:
size_t
m_ref_count
=
0
;
size_t
m_weak_count
=
0
;
T
*
m_pointer
=
nullptr
;
void
(
*
reset
)(
LocalPtrStorage
*
)
=
nullptr
;
void
(
*
free
)(
LocalPtrStorage
*
)
=
nullptr
;
void
inc_ref
()
{
m_ref_count
++
;
}
void
dec_ref
()
{
m_ref_count
--
;
if
(
m_ref_count
==
0
)
{
reset
(
this
);
m_pointer
=
nullptr
;
reset
=
nullptr
;
if
(
m_weak_count
==
0
)
{
free
(
this
);
// dead
}
}
}
void
inc_weak_ref
()
{
m_weak_count
++
;
}
void
dec_weak_ref
()
{
m_weak_count
--
;
if
((
m_weak_count
+
m_ref_count
)
==
0
)
{
free
(
this
);
// dead
}
}
template
<
typename
U
>
friend
class
LocalPtr
;
template
<
typename
U
>
friend
class
LocalWeakPtr
;
public:
};
template
<
typename
T
,
typename
TDerived
>
class
LocalPtrStorgeImpl
:
public
LocalPtrStorage
<
T
>
{
private:
std
::
optional
<
TDerived
>
m_value
;
void
*
m_pool
=
nullptr
;
template
<
typename
U
>
friend
class
LocalPtr
;
template
<
typename
U
>
friend
class
LocalWeakPtr
;
};
template
<
typename
T
>
class
LocalWeakPtr
;
/**
* \brief thread-unsafe smart pointer
*
* \tparam T type of value
*/
template
<
typename
T
>
class
LocalPtr
{
public:
using
storage_t
=
LocalPtrStorage
<
T
>
;
using
pool_t
=
MemPool
<
storage_t
>
;
using
weak_type
=
LocalWeakPtr
<
T
>
;
private:
storage_t
*
m_storage
=
nullptr
;
void
emplace
(
storage_t
*
ptr
)
{
if
(
ptr
)
{
ptr
->
inc_ref
();
m_storage
=
ptr
;
}
}
LocalPtr
(
storage_t
*
ptr
)
{
emplace
(
ptr
);
}
public:
LocalPtr
()
=
default
;
LocalPtr
(
const
LocalPtr
&
rhs
)
{
(
*
this
)
=
rhs
;
}
LocalPtr
(
LocalPtr
&&
rhs
)
{
(
*
this
)
=
std
::
move
(
rhs
);
}
LocalPtr
&
operator
=
(
const
LocalPtr
&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
}
auto
storage
=
rhs
.
m_storage
;
if
(
storage
)
{
storage
->
inc_ref
();
}
if
(
m_storage
)
{
m_storage
->
dec_ref
();
// rhs.m_storage may be invalid here
}
m_storage
=
storage
;
return
*
this
;
}
LocalPtr
&
operator
=
(
LocalPtr
&&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
}
std
::
swap
(
m_storage
,
rhs
.
m_storage
);
rhs
.
reset
();
return
*
this
;
}
bool
operator
==
(
const
LocalPtr
&
rhs
)
const
{
return
m_storage
==
rhs
.
m_storage
;
}
bool
operator
!=
(
const
LocalPtr
&
rhs
)
const
{
return
m_storage
!=
rhs
.
m_storage
;
}
size_t
hash
()
const
{
return
reinterpret_cast
<
uintptr_t
>
(
m_storage
);
}
~
LocalPtr
()
{
reset
();
}
/**
* \brief Construct an instance of TDerived and return an LocalPtr
*
* There is an memory pool for each (T, TDerived) pair
*
* \tparam TDerived type of concrete instance, should be subclass of T
* \tparam TArgs
* \param args constructor arguments
* \return LocalPtr points to the instance
*/
template
<
typename
TDerived
=
T
,
typename
...
TArgs
>
static
LocalPtr
make
(
TArgs
&&
...
args
)
{
static_assert
(
std
::
is_base_of_v
<
T
,
TDerived
>
);
using
storage_impl_t
=
LocalPtrStorgeImpl
<
T
,
TDerived
>
;
constexpr
auto
normalize_size
=
[](
size_t
size
)
{
size_t
normalized_size
=
64
;
while
(
normalized_size
<
size
)
{
normalized_size
*=
2
;
}
return
normalized_size
;
};
using
raw_storage_t
=
std
::
aligned_storage_t
<
normalize_size
(
sizeof
(
storage_impl_t
))
>
;
static_assert
(
alignof
(
raw_storage_t
)
%
alignof
(
storage_impl_t
)
==
0
);
static_assert
(
sizeof
(
raw_storage_t
)
>=
sizeof
(
storage_impl_t
));
using
pool_t
=
MemPool
<
raw_storage_t
>
;
pool_t
&
pool
=
MemPoolUtils
<
raw_storage_t
>::
get_thread_local
();
auto
*
raw_storage
=
pool
.
alloc_raw
();
auto
*
storage
=
reinterpret_cast
<
storage_impl_t
*>
(
raw_storage
);
new
(
storage
)
storage_impl_t
();
storage
->
m_value
.
emplace
(
std
::
forward
<
TArgs
&&>
(
args
)...);
storage
->
m_pointer
=
&*
storage
->
m_value
;
storage
->
reset
=
[](
storage_t
*
storage
)
{
auto
*
storage_impl
=
static_cast
<
storage_impl_t
*>
(
storage
);
storage_impl
->
m_value
.
reset
();
storage_impl
->
m_pointer
=
nullptr
;
};
storage
->
free
=
[](
storage_t
*
storage_base
)
{
auto
*
storage
=
static_cast
<
storage_impl_t
*>
(
storage_base
);
auto
*
pool
=
reinterpret_cast
<
pool_t
*>
(
storage
->
m_pool
);
storage
->
m_pool
=
nullptr
;
storage
->~
storage_impl_t
();
auto
*
raw_storage
=
reinterpret_cast
<
raw_storage_t
*>
(
storage
);
pool
->
free_raw
(
raw_storage
);
};
storage
->
m_pool
=
&
pool
;
return
{(
storage_t
*
)
storage
};
}
T
&
operator
*
()
const
{
return
*
get
();
}
T
*
get
()
const
{
if
((
!
m_storage
)
||
!
m_storage
->
m_pointer
)
{
return
nullptr
;
}
return
m_storage
->
m_pointer
;
}
T
*
operator
->
()
const
{
return
get
();
}
size_t
ref_count
()
const
{
return
m_storage
->
m_ref_count
;
}
bool
unique
()
const
{
return
ref_count
()
==
1
;
}
void
reset
()
{
if
(
m_storage
)
{
m_storage
->
dec_ref
();
m_storage
=
nullptr
;
}
}
operator
bool
()
const
{
return
bool
(
m_storage
);
}
bool
operator
==
(
std
::
nullptr_t
nptr
)
const
{
return
m_storage
==
nullptr
;
}
bool
operator
!=
(
std
::
nullptr_t
nptr
)
const
{
return
m_storage
!=
nullptr
;
}
template
<
typename
U
>
friend
class
LocalWeakPtr
;
};
template
<
typename
T
>
class
LocalWeakPtr
{
public:
using
storage_t
=
LocalPtrStorage
<
T
>
;
private:
storage_t
*
m_storage
=
nullptr
;
void
emplace
(
storage_t
*
ptr
)
{
if
(
ptr
)
{
ptr
->
inc_weak_ref
();
m_storage
=
ptr
;
}
}
public:
LocalWeakPtr
()
=
default
;
LocalWeakPtr
(
const
LocalPtr
<
T
>&
rhs
)
{
emplace
(
rhs
.
m_storage
);
}
LocalWeakPtr
(
const
LocalWeakPtr
&
rhs
)
{
(
*
this
)
=
rhs
;
}
LocalWeakPtr
(
LocalWeakPtr
&&
rhs
)
{
(
*
this
)
=
std
::
move
(
rhs
);
}
LocalWeakPtr
&
operator
=
(
const
LocalWeakPtr
&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
}
reset
();
emplace
(
rhs
.
m_storage
);
return
*
this
;
}
LocalWeakPtr
&
operator
=
(
LocalWeakPtr
&&
rhs
)
{
if
(
this
==
&
rhs
)
{
return
*
this
;
}
std
::
swap
(
m_storage
,
rhs
.
m_storage
);
rhs
.
reset
();
return
*
this
;
}
~
LocalWeakPtr
()
{
reset
();
}
void
reset
()
{
if
(
m_storage
)
{
m_storage
->
dec_weak_ref
();
m_storage
=
nullptr
;
}
}
LocalPtr
<
T
>
lock
()
const
{
if
(
m_storage
&&
m_storage
->
m_ref_count
)
{
return
{
m_storage
};
}
return
{};
}
bool
operator
==
(
const
LocalWeakPtr
&
rhs
)
const
{
return
m_storage
==
rhs
.
m_storage
;
}
bool
operator
!=
(
const
LocalWeakPtr
&
rhs
)
const
{
return
m_storage
!=
rhs
.
m_storage
;
}
size_t
hash
()
const
{
return
reinterpret_cast
<
uintptr_t
>
(
m_storage
);
}
};
template
<
typename
T
,
typename
TDerived
,
typename
...
TArgs
>
LocalPtr
<
T
>
make_local
(
TArgs
&&
...
args
)
{
return
LocalPtr
<
T
>::
template
make
<
TDerived
>(
std
::
forward
<
TArgs
&&>
(
args
)...);
}
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/map.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/map.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <optional>
#include "megbrain/utils/metahelper.h"
namespace
mgb
::
imperative
{
/**
* \brief an hash map optimized for weak pointer as key
*
* Keys were scanned automatically, so values referenced by invalid keys whould be
* released soon
*
* \tparam TKey key type, requires(bool(key.lock()))
* \tparam TValue value type
*/
template
<
typename
TKey
,
typename
TValue
>
class
WeakKeyMap
:
public
NonCopyableObj
{
public:
using
storage_t
=
std
::
unordered_map
<
TKey
,
TValue
>
;
private:
storage_t
m_storage
;
typename
storage_t
::
iterator
m_cursor
=
m_storage
.
begin
();
/**
* \brief select a key and verify that whether it is invalid. If yes, erase it
*
*/
void
_step
()
{
if
(
m_cursor
==
m_storage
.
end
())
{
m_cursor
=
m_storage
.
begin
();
return
;
}
auto
key
=
m_cursor
->
first
;
if
(
!
key
.
lock
())
{
m_cursor
=
m_storage
.
erase
(
m_cursor
);
}
else
{
++
m_cursor
;
}
}
public:
size_t
count
(
TKey
key
)
{
_step
();
_step
();
return
m_storage
.
count
(
key
);
}
TValue
&
at
(
TKey
key
)
const
{
return
m_storage
.
at
(
key
);
}
TValue
&
at
(
TKey
key
)
{
_step
();
_step
();
return
m_storage
.
at
(
key
);
}
TValue
&
operator
[](
TKey
key
)
{
_step
();
_step
();
if
(
m_storage
.
count
(
key
))
{
return
m_storage
.
at
(
key
);
}
else
{
size_t
bucket_count
=
m_storage
.
bucket_count
();
TValue
&
result
=
m_storage
[
key
];
if
(
bucket_count
!=
m_storage
.
bucket_count
())
{
m_cursor
=
m_storage
.
begin
();
}
return
result
;
}
}
std
::
optional
<
TValue
>
try_get
(
TKey
key
)
const
{
auto
iter
=
m_storage
.
find
(
key
);
if
(
iter
==
m_storage
.
end
())
{
return
{};
}
return
{
iter
->
second
};
}
std
::
optional
<
TValue
>
try_get
(
TKey
key
)
{
_step
();
_step
();
return
((
const
WeakKeyMap
*
)
this
)
->
try_get
(
std
::
move
(
key
));
}
};
template
<
typename
TKey
,
typename
TValue
>
class
WeakValueMap
:
public
NonCopyableObj
{
public:
using
storage_t
=
std
::
unordered_map
<
TKey
,
TValue
>
;
private:
storage_t
m_storage
;
typename
storage_t
::
iterator
m_cursor
=
m_storage
.
begin
();
/**
* \brief select a key and verify that whether it is invalid. If yes, erase it
*
*/
void
_step
()
{
if
(
m_cursor
==
m_storage
.
end
())
{
m_cursor
=
m_storage
.
begin
();
return
;
}
auto
value
=
m_cursor
->
second
;
if
(
!
value
.
lock
())
{
m_cursor
=
m_storage
.
erase
(
m_cursor
);
}
else
{
++
m_cursor
;
}
}
public:
size_t
count
(
TKey
key
)
{
_step
();
_step
();
return
m_storage
.
count
(
key
);
}
TValue
&
at
(
TKey
key
)
const
{
return
m_storage
.
at
(
key
);
}
TValue
&
at
(
TKey
key
)
{
_step
();
_step
();
return
m_storage
.
at
(
key
);
}
TValue
&
operator
[](
TKey
key
)
{
_step
();
_step
();
if
(
m_storage
.
count
(
key
))
{
return
m_storage
.
at
(
key
);
}
else
{
size_t
bucket_count
=
m_storage
.
bucket_count
();
TValue
&
result
=
m_storage
[
key
];
if
(
bucket_count
!=
m_storage
.
bucket_count
())
{
m_cursor
=
m_storage
.
begin
();
}
return
result
;
}
}
};
}
// namespace mgb::imperative
\ No newline at end of file
imperative/src/include/megbrain/imperative/utils/mempool.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/mempool.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include <thread>
#include <unordered_map>
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
namespace
mgb
::
imperative
{
template
<
typename
T
>
class
MemPoolUtils
{
private:
static
std
::
mutex
sm_mutex
;
static
std
::
unordered_map
<
std
::
thread
::
id
,
std
::
unique_ptr
<
MemPool
<
T
>>>
sm_instances
;
static
thread_local
MemPool
<
T
>*
tm_instance
;
static
MemPool
<
T
>*
sm_instance
;
public:
static
MemPool
<
T
>&
get_thread_local
()
{
if
(
!
tm_instance
)
{
MGB_LOCK_GUARD
(
sm_mutex
);
auto
&
instance
=
sm_instances
[
std
::
this_thread
::
get_id
()];
if
(
!
instance
)
{
// thread id may be duplicated
instance
=
std
::
make_unique
<
MemPool
<
T
>>
();
}
tm_instance
=
instance
.
get
();
}
return
*
tm_instance
;
}
static
MemPool
<
T
>&
get_static
()
{
if
(
!
sm_instance
)
{
MGB_LOCK_GUARD
(
sm_mutex
);
auto
&
instance
=
sm_instances
[{}];
if
(
!
instance
)
{
// double check
instance
=
std
::
make_unique
<
MemPool
<
T
>>
();
sm_instance
=
instance
.
get
();
}
mgb_assert
(
sm_instance
);
}
}
};
template
<
typename
T
>
std
::
mutex
MemPoolUtils
<
T
>::
sm_mutex
;
template
<
typename
T
>
std
::
unordered_map
<
std
::
thread
::
id
,
std
::
unique_ptr
<
MemPool
<
T
>>>
MemPoolUtils
<
T
>::
sm_instances
;
template
<
typename
T
>
thread_local
MemPool
<
T
>*
MemPoolUtils
<
T
>::
tm_instance
;
template
<
typename
T
>
MemPool
<
T
>*
MemPoolUtils
<
T
>::
sm_instance
;
}
// namespace mgb::imperative
\ No newline at end of file
imperative/src/include/megbrain/imperative/utils/span.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/span.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <array>
#include <vector>
#include "megbrain/utils/small_vector.h"
namespace
mgb
::
imperative
{
/**
* \brief wrapper for c-style array
*
* \tparam T value type
*/
template
<
typename
T
>
class
Span
{
private:
const
T
*
m_begin
=
nullptr
;
const
T
*
m_end
=
nullptr
;
public:
Span
()
{}
Span
(
const
T
*
begin
,
const
T
*
end
)
:
m_begin
{
begin
},
m_end
{
end
}
{}
Span
(
const
T
*
begin
,
size_t
size
)
:
Span
(
begin
,
begin
+
size
)
{}
template
<
typename
TContainer
>
Span
(
TContainer
&
container
)
:
Span
(
container
.
data
(),
container
.
size
())
{}
const
T
*
begin
()
const
{
return
m_begin
;
}
const
T
*
end
()
const
{
return
m_end
;
}
const
T
*
data
()
const
{
return
m_begin
;
}
size_t
size
()
const
{
return
m_end
-
m_begin
;
}
template
<
typename
TContainer
>
TContainer
copy_into
()
{
return
TContainer
(
m_begin
,
m_end
);
}
const
T
&
operator
[](
size_t
idx
)
const
{
return
m_begin
[
idx
];
}
const
T
&
at
(
size_t
idx
)
const
{
return
m_begin
[
idx
];
}
const
T
&
item
()
const
{
mgb_assert
(
m_end
-
m_begin
==
1
,
"size mismatch: %zu vs %zu"
,
(
m_end
-
m_begin
),
(
size_t
)
1
);
return
m_begin
[
0
];
}
template
<
size_t
N
>
const
std
::
array
<
T
,
N
>&
as_array
()
{
mgb_assert
(
m_end
-
m_begin
==
N
,
"size mismatch: %zu vs %zu"
,
(
m_end
-
m_begin
),
N
);
return
*
reinterpret_cast
<
const
std
::
array
<
T
,
N
>*>
(
m_begin
);
}
Span
sub
(
size_t
begin
,
size_t
length
)
{
mgb_assert
(
begin
+
length
<=
m_end
-
m_begin
);
return
{
m_begin
+
begin
,
length
};
}
};
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/to_string.h
浏览文件 @
2be6ceda
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <tuple>
#include <tuple>
#include <type_traits>
#include <type_traits>
#include "megbrain/imperative/utils/span.h"
#include "megbrain/tensor.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/small_vector.h"
#include "megbrain/utils/small_vector.h"
...
@@ -59,6 +60,22 @@ struct ToStringTrait<SmallVector<T, N>> {
...
@@ -59,6 +60,22 @@ struct ToStringTrait<SmallVector<T, N>> {
}
}
};
};
template
<
typename
T
>
struct
ToStringTrait
<
std
::
vector
<
T
>>
{
std
::
string
operator
()(
const
std
::
vector
<
T
>&
v
)
const
{
if
(
v
.
empty
())
{
return
"[]"
;
}
std
::
string
result
=
"["
;
result
+=
to_string
(
v
[
0
]);
for
(
size_t
i
=
1
;
i
<
v
.
size
();
++
i
)
{
result
+=
", "
;
result
+=
to_string
(
v
[
i
]);
}
return
result
+
"]"
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
ToStringTrait
<
std
::
shared_ptr
<
T
>>
{
struct
ToStringTrait
<
std
::
shared_ptr
<
T
>>
{
std
::
string
operator
()(
const
std
::
shared_ptr
<
T
>&
sp
)
const
{
std
::
string
operator
()(
const
std
::
shared_ptr
<
T
>&
sp
)
const
{
...
@@ -115,4 +132,36 @@ struct ToStringTrait<CompNode> {
...
@@ -115,4 +132,36 @@ struct ToStringTrait<CompNode> {
std
::
string
operator
()(
CompNode
device
)
const
{
return
device
.
to_string
();
}
std
::
string
operator
()(
CompNode
device
)
const
{
return
device
.
to_string
();
}
};
};
inline
std
::
string
string_join
(
Span
<
std
::
string
>
span
,
char
delimiter
=
','
)
{
std
::
string
buffer
=
"["
;
for
(
size_t
i
=
1
;
i
<
span
.
size
();
++
i
)
{
if
(
i
)
{
buffer
.
push_back
(
delimiter
);
}
buffer
.
append
(
span
[
0
]);
}
return
buffer
+
"]"
;
}
template
<
typename
T
>
struct
ToStringTrait
<
Span
<
T
>>
{
std
::
string
operator
()(
Span
<
T
>
span
)
const
{
if
(
span
.
size
()
==
0
)
{
return
"[]"
;
}
std
::
string
result
=
"["
;
result
+=
to_string
(
span
[
0
]);
for
(
size_t
i
=
1
;
i
<
span
.
size
();
++
i
)
{
result
+=
", "
;
result
+=
to_string
(
span
[
i
]);
}
return
result
+
"]"
;
}
};
template
<
>
struct
ToStringTrait
<
std
::
type_info
>
{
std
::
string
operator
()(
const
std
::
type_info
&
info
)
const
{
return
info
.
name
();
}
};
}
// namespace mgb::imperative
}
// namespace mgb::imperative
imperative/src/include/megbrain/imperative/utils/value_shape.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/visit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>
#include "megbrain/imperative/utils/span.h"
#include "megbrain/tensor.h"
namespace
mgb
::
imperative
{
/**
* \brief like TensorShape, but allow real scalar shape.
*
*/
struct
ValueShape
{
size_t
shape
[
TensorShape
::
MAX_NDIM
];
int
ndim
=
0
;
ValueShape
()
=
default
;
ValueShape
(
std
::
initializer_list
<
size_t
>
dims
)
{
for
(
auto
&&
dim
:
dims
)
{
shape
[
ndim
++
]
=
dim
;
}
}
ValueShape
(
Span
<
size_t
>
dims
)
{
for
(
auto
&&
dim
:
dims
)
{
shape
[
ndim
++
]
=
dim
;
}
}
size_t
&
operator
[](
int
axis
)
{
return
shape
[
axis
];
}
size_t
operator
[](
int
axis
)
const
{
return
shape
[
axis
];
}
size_t
at
(
int
axis
)
const
{
mgb_assert
(
axis
<
ndim
);
return
shape
[
axis
];
}
size_t
total_nr_elems
()
const
{
size_t
prod
=
1
;
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
prod
*=
shape
[
i
];
}
return
prod
;
}
bool
is_scalar
()
const
{
return
ndim
==
0
;
}
std
::
string
to_string
()
const
{
std
::
string
buffer
=
"{"
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
if
(
i
)
{
buffer
.
append
(
","
);
}
buffer
.
append
(
std
::
to_string
(
shape
[
i
]));
}
buffer
.
append
(
"}"
);
return
buffer
;
}
static
ValueShape
from
(
TensorShape
tensor_shape
)
{
mgb_assert
(
tensor_shape
.
ndim
);
return
Span
<
size_t
>
{
tensor_shape
.
shape
,
tensor_shape
.
ndim
};
}
TensorShape
as_tensor_shape
()
const
{
mgb_assert
(
ndim
!=
0
);
TensorShape
ret
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
ret
.
shape
[
i
]
=
shape
[
i
];
}
ret
.
ndim
=
ndim
;
return
ret
;
}
bool
operator
==
(
const
ValueShape
&
rhs
)
const
{
if
(
ndim
!=
rhs
.
ndim
)
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
if
(
shape
[
i
]
!=
rhs
.
shape
[
i
])
{
return
false
;
}
}
return
true
;
}
};
static_assert
(
sizeof
(
size_t
)
>=
sizeof
(
int
));
static_assert
(
TensorShape
::
MAX_NDIM
==
7
);
static_assert
(
sizeof
(
ValueShape
)
<=
sizeof
(
size_t
)
*
8
);
}
// namespace mgb::imperative
\ No newline at end of file
imperative/src/include/megbrain/imperative/utils/visit.h
0 → 100644
浏览文件 @
2be6ceda
/**
* \file imperative/src/include/megbrain/imperative/utils/visit.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <vector>
#include "megbrain/utils/small_vector.h"
namespace
mgb
::
imperative
{
template
<
typename
...
TVisitors
>
class
Visitor
:
public
TVisitors
...
{
public:
using
TVisitors
::
operator
()...;
};
}
// namespace mgb::imperative
imperative/src/test/profiler.cpp
浏览文件 @
2be6ceda
...
@@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) {
...
@@ -28,10 +28,10 @@ TEST(TestProfiler, ImperativeLogProfile) {
auto
results
=
imperative
::
Profiler
::
collect
();
auto
results
=
imperative
::
Profiler
::
collect
();
imperative
::
Profiler
::
stop_profile
();
imperative
::
Profiler
::
stop_profile
();
mgb_assert
(
results
.
entries
.
size
()
==
2
);
mgb_assert
(
results
.
entries
.
size
()
==
2
);
auto
*
event_start
=
results
.
entries
[
0
].
data
.
as
<
profiler
::
CustomEvent
>
();
auto
&
event_start
=
results
.
entries
[
0
].
data
.
cast
<
profiler
::
CustomEvent
>
();
auto
*
event_finish
=
results
.
entries
[
1
].
data
.
as
<
profiler
::
CustomFinishEvent
>
();
auto
&
event_finish
=
results
.
entries
[
1
].
data
.
cast
<
profiler
::
CustomFinishEvent
>
();
mgb_assert
(
event_start
&&
event_start
->
title
==
"XXX"
);
mgb_assert
(
event_start
.
title
==
"XXX"
);
mgb_assert
(
event_finish
&&
event_finish
->
title
==
"XXX"
);
mgb_assert
(
event_finish
.
title
==
"XXX"
);
mgb_assert
(
results
.
entries
[
0
].
time
<
results
.
entries
[
1
].
time
);
mgb_assert
(
results
.
entries
[
0
].
time
<
results
.
entries
[
1
].
time
);
mgb_assert
(
results
.
entries
[
0
].
id
<
results
.
entries
[
1
].
id
);
mgb_assert
(
results
.
entries
[
0
].
id
<
results
.
entries
[
1
].
id
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录