Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
780663c9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
780663c9
编写于
5月 12, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(api_cache): lock api cache for thread safety
GitOrigin-RevId: 8a244677c3d2d8b2a7eaafb51c3ac13e2dfc55d6
上级
d4615f91
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
147 addition
and
39 deletion
+147
-39
dnn/src/common/api_cache.h
dnn/src/common/api_cache.h
+142
-34
dnn/src/cuda/api_cache.h
dnn/src/cuda/api_cache.h
+5
-5
未找到文件。
dnn/src/common/api_cache.h
浏览文件 @
780663c9
...
...
@@ -12,19 +12,79 @@
#pragma once
#include <atomic>
#include <cstring>
#include <memory>
#include <mutex>
#include <tuple>
#include <unordered_map>
#include "megdnn/thin/function.h"
#include "./utils.h"
namespace
megdnn
{
template
<
typename
...
TArgs
>
class
FunctionCache
{
// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/
class
RWSpin
{
public:
class
Lock
{
private:
RWSpin
*
m_spin
;
void
(
RWSpin
::*
m_lock
)(
void
);
void
(
RWSpin
::*
m_unlock
)(
void
);
public:
Lock
(
RWSpin
*
spin
,
decltype
(
m_lock
)
lock
,
decltype
(
m_unlock
)
unlock
)
:
m_spin
{
spin
},
m_lock
{
lock
},
m_unlock
{
unlock
}
{}
void
lock
()
{
(
m_spin
->*
m_lock
)();
}
void
unlock
()
{
(
m_spin
->*
m_unlock
)();
}
};
private:
std
::
atomic
<
uint32_t
>
m_atomic
{
0
};
static
constexpr
uint32_t
sm_reader_mask
=
0x7FFFFFFF
;
static
constexpr
uint32_t
sm_writer_mask
=
0x80000000
;
void
_reader_lock
()
{
uint32_t
expected
=
m_atomic
;
do
{
expected
&=
sm_reader_mask
;
}
while
(
!
m_atomic
.
compare_exchange_strong
(
expected
,
expected
+
1
));
}
void
_reader_unlock
()
{
m_atomic
--
;
}
void
_writer_lock
()
{
uint32_t
expected
=
m_atomic
;
do
{
expected
&=
sm_reader_mask
;
}
while
(
!
m_atomic
.
compare_exchange_strong
(
expected
,
expected
|
sm_writer_mask
));
while
(
m_atomic
.
load
()
!=
sm_writer_mask
)
;
}
void
_writer_unlock
()
{
// assert m_atomic == sm_writer_mask
m_atomic
=
0
;
}
public:
Lock
reader
()
{
return
{
this
,
&
RWSpin
::
_reader_lock
,
&
RWSpin
::
_reader_unlock
};
}
Lock
writer
()
{
return
{
this
,
&
RWSpin
::
_writer_lock
,
&
RWSpin
::
_writer_unlock
};
}
};
template
<
typename
TSignature
>
class
FunctionCache
;
template
<
typename
TRet
,
typename
...
TArgs
>
class
FunctionCache
<
TRet
(
TArgs
...)
>
{
public:
using
key_t
=
std
::
string
;
using
value_t
=
std
::
string
;
using
value_t
=
TRet
;
using
key_mapper_t
=
thin_function
<
key_t
(
TArgs
...)
>
;
using
value_mapper_t
=
thin_function
<
value_t
(
TArgs
...)
>
;
using
storage_t
=
std
::
unordered_map
<
key_t
,
value_t
>
;
...
...
@@ -33,13 +93,31 @@ public:
key_mapper_t
key_mapper
;
value_mapper_t
value_mapper
;
value_t
operator
()(
TArgs
...
args
)
{
RWSpin
spin
;
public:
TRet
operator
()(
TArgs
...
args
)
{
key_t
key
=
key_mapper
(
args
...);
if
(
storage
.
count
(
key
)
==
0
)
{
storage
[
key
]
=
value_mapper
(
std
::
forward
<
TArgs
>
(
args
)...);
}
auto
reader_lock
=
spin
.
reader
();
auto
writer_lock
=
spin
.
writer
();
{
MEGDNN_LOCK_GUARD
(
reader_lock
);
auto
iter
=
storage
.
find
(
key
);
if
(
iter
!=
storage
.
end
())
{
return
iter
->
second
;
}
}
// RWSpin doesn't support upgrade
{
MEGDNN_LOCK_GUARD
(
writer_lock
);
if
(
storage
.
count
(
key
)
!=
0
)
{
return
storage
[
key
];
}
value_t
ret
=
value_mapper
(
std
::
forward
<
TArgs
>
(
args
)...);
storage
[
key
]
=
ret
;
return
ret
;
}
}
};
// FIFO
...
...
@@ -63,10 +141,8 @@ public:
"type should be trivially copyable"
);
m_buffer
.
append
(
reinterpret_cast
<
const
char
*>
(
&
value
),
sizeof
(
T
));
}
std
::
string
take
()
{
return
std
::
move
(
m_buffer
);
}
void
set
(
std
::
string
new_buf
)
{
std
::
string
take
()
{
return
std
::
move
(
m_buffer
);
}
void
reset
(
std
::
string
new_buf
)
{
m_cursor
=
0
;
m_buffer
=
new_buf
;
}
...
...
@@ -74,26 +150,32 @@ public:
struct
Empty
{};
template
<
typename
...
TParams
>
class
ParamBundle
{
private:
template
<
std
::
size_t
N
,
std
::
size_t
...
Seq
>
static
std
::
index_sequence
<
N
+
Seq
...
>
add_all
(
// in: seq[1, 2, ..., m]
// out: seq[N+1, N+2, ... N+m]
template
<
std
::
size_t
N
,
std
::
size_t
...
Seq
>
static
std
::
index_sequence
<
N
+
Seq
...
>
inc_index_sequence
(
std
::
index_sequence
<
Seq
...
>
)
{
return
{};
}
}
template
<
typename
...
TParams
>
class
ParamBundle
{
private:
// out: Min, Min+1, ..., Max
template
<
std
::
size_t
Min
,
std
::
size_t
Max
>
using
make_index_range
=
decltype
(
add_all
<
Min
>
(
std
::
make_index_sequence
<
Max
-
Min
>
()));
using
make_index_range
=
decltype
(
inc_index_sequence
<
Min
>
(
std
::
make_index_sequence
<
Max
-
Min
>
()));
// store params in a tuple
using
storage_t
=
std
::
tuple
<
typename
std
::
remove_reference_t
<
TParams
>
...
>
;
storage_t
m_storage
;
// deconstruct tuple and call functor
template
<
typename
TFunctor
,
size_t
...
Indices
>
auto
call_helper
(
TFunctor
functor
,
std
::
index_sequence
<
Indices
...
>
)
{
return
functor
(
std
::
get
<
Indices
>
(
m_storage
).
value
...);
}
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TPrev
>
auto
serialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<
Index
,
Indices
...
>
)
{
...
...
@@ -101,9 +183,11 @@ private:
std
::
get
<
Index
>
(
m_storage
).
serialize
(
ser
,
prev
),
std
::
index_sequence
<
Indices
...
>
());
}
template
<
typename
TPrev
>
auto
serialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<>
)
{}
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TPrev
>
auto
deserialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<
Index
,
Indices
...
>
)
{
...
...
@@ -111,9 +195,11 @@ private:
ser
,
std
::
get
<
Index
>
(
m_storage
).
deserialize
(
ser
,
prev
),
std
::
index_sequence
<
Indices
...
>
());
}
template
<
typename
TPrev
>
auto
deserialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<>
)
{}
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TArg
,
typename
...
TArgs
>
void
set_values_helper
(
std
::
index_sequence
<
Index
,
Indices
...
>
,
TArg
&&
arg
,
TArgs
&&
...
args
)
{
...
...
@@ -121,6 +207,7 @@ private:
set_values_helper
(
std
::
index_sequence
<
Indices
...
>
(),
std
::
forward
<
TArgs
>
(
args
)...);
}
template
<
size_t
...
Indices
>
void
set_values_helper
(
std
::
index_sequence
<
Indices
...
>
)
{
static_assert
(
sizeof
...(
Indices
)
==
0
,
"redundant indices"
);
...
...
@@ -132,24 +219,25 @@ public:
return
call_helper
(
std
::
forward
<
TFunctor
>
(
functor
),
std
::
make_index_sequence
<
sizeof
...(
TParams
)
>
());
}
// recursively store params into ser
template
<
size_t
NBegin
,
size_t
NEnd
>
void
serialize_params
(
StringSerializer
&
ser
)
{
static_assert
(
NEnd
>=
NBegin
,
"invalid range"
);
serialize_helper
(
ser
,
Empty
{},
add_all
<
NBegin
>
(
std
::
make_index_sequence
<
NEnd
-
NBegin
>
()));
serialize_helper
(
ser
,
Empty
{},
make_index_range
<
NBegin
,
NEnd
>
());
}
// recursively load params from ser
template
<
size_t
NBegin
,
size_t
NEnd
>
void
deserialize_params
(
StringSerializer
&
ser
)
{
static_assert
(
NEnd
>=
NBegin
,
"invalid range"
);
deserialize_helper
(
ser
,
Empty
{},
add_all
<
NBegin
>
(
std
::
make_index_sequence
<
NEnd
-
NBegin
>
()));
deserialize_helper
(
ser
,
Empty
{},
make_index_range
<
NBegin
,
NEnd
>
());
}
// recursively set params into m_storage
template
<
size_t
NBegin
,
size_t
NEnd
,
typename
...
TArgs
>
void
set_values
(
TArgs
&&
...
args
)
{
set_values_helper
(
add_all
<
NBegin
>
(
std
::
make_index_sequence
<
NEnd
-
NBegin
>
()),
set_values_helper
(
make_index_range
<
NBegin
,
NEnd
>
(),
std
::
forward
<
TArgs
>
(
args
)...);
}
};
...
...
@@ -158,10 +246,12 @@ template <typename T>
class
Param
{
public:
T
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
ser
.
write_plain
(
value
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
value
=
ser
.
read_plain
<
T
>
();
return
Empty
{};
...
...
@@ -172,42 +262,54 @@ template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
typename
TOutputs
=
std
::
tuple
<>>
class
FunctionCacheBuilder
{
private:
// decl value with type of tuple-of-args
static
auto
declargs
()
->
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
declval
<
TOutputs
>
()))
{
return
{};
}
template
<
size_t
...
Indices
>
static
auto
declfunction_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
thin_function
<
decltype
(
std
::
declval
<
TRet
>
().
value
)(
decltype
(
std
::
get
<
Indices
>
(
declargs
()).
value
)...)
>
{
return
{};
}
// decl value with type of original function
static
auto
declfunction
()
{
return
declfunction_helper
(
std
::
make_index_sequence
<
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
());
}
template
<
size_t
...
Indices
>
static
auto
declbundle_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
ParamBundle
<
decltype
(
std
::
get
<
Indices
>
(
declargs
()))...
>
{
return
{};
}
// decl value with type of bundle-of-args
static
auto
declbundle
()
{
return
declbundle_helper
(
std
::
make_index_sequence
<
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
());
}
// type of original function
using
function_t
=
decltype
(
declfunction
());
// type of bundle-of-args
using
bundle_t
=
decltype
(
declbundle
());
public:
// declare new return type, cannot be override
template
<
typename
TNewRet
>
auto
ret
()
{
static_assert
(
std
::
is_same
<
TRet
,
Param
<
Empty
>>::
value
,
"return value redefinition"
);
return
FunctionCacheBuilder
<
TNewRet
,
TInputs
,
TOutputs
>
{};
}
// declare new input
template
<
typename
TNewInput
>
auto
input
()
{
using
TNewInputs
=
decltype
(
...
...
@@ -215,6 +317,7 @@ public:
std
::
make_tuple
(
std
::
declval
<
TNewInput
>
())));
return
FunctionCacheBuilder
<
TRet
,
TNewInputs
,
TOutputs
>
{};
}
// declare new output
template
<
typename
TNewOutput
>
auto
output
()
{
using
TNewOutputs
=
decltype
(
...
...
@@ -222,17 +325,20 @@ public:
std
::
make_tuple
(
std
::
declval
<
TNewOutput
>
())));
return
FunctionCacheBuilder
<
TRet
,
TInputs
,
TNewOutputs
>
{};
}
// summary
template
<
typename
TFunctor
>
function_t
build
(
TFunctor
func
)
{
FunctionCache
<
bundle_t
>
cache
;
cache
.
key_mapper
=
[](
bundle_t
bundle
)
{
auto
cache
=
std
::
make_shared
<
FunctionCache
<
std
::
string
(
bundle_t
)
>>
();
// bundle -> ser(in args)
cache
->
key_mapper
=
[](
bundle_t
bundle
)
{
StringSerializer
ser
;
bundle
.
template
serialize_params
<
0
,
std
::
tuple_size
<
TInputs
>
::
value
>
(
ser
);
return
ser
.
take
();
};
cache
.
value_mapper
=
[
=
](
bundle_t
bundle
)
{
// bundle -> ser(out args)
cache
->
value_mapper
=
[
=
](
bundle_t
bundle
)
{
StringSerializer
ser
;
TRet
ret
;
ret
.
value
=
bundle
.
call_by
(
func
);
...
...
@@ -253,7 +359,7 @@ public:
"args count mismatch"
);
bundle
.
template
set_values
<
0
,
sizeof
...(
args
)>(
std
::
forward
<
decltype
(
args
)
>
(
args
)...);
ser
.
set
(
cache
(
bundle
));
ser
.
reset
((
*
cache
)
(
bundle
));
ret
.
deserialize
(
ser
,
Empty
{});
constexpr
size_t
n_inputs
=
std
::
tuple_size
<
TInputs
>::
value
;
constexpr
size_t
n_outputs
=
std
::
tuple_size
<
TOutputs
>::
value
;
...
...
@@ -278,6 +384,7 @@ public:
}
};
// like RefParam but return *value while ser and deser. Working with ArrayParam
template
<
typename
T
>
class
RefArraySizeParam
{
public:
...
...
@@ -291,6 +398,7 @@ public:
}
};
// accept array length from previous param. Working with RefArraySizeParam
template
<
typename
TSize
,
typename
TItem
>
class
ArrayParam
{
public:
...
...
dnn/src/cuda/api_cache.h
浏览文件 @
780663c9
...
...
@@ -20,7 +20,7 @@ class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
...
...
@@ -59,7 +59,7 @@ class CudnnTensorDescParam {
public:
cudnnTensorDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
...
...
@@ -74,7 +74,7 @@ public:
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
...
...
@@ -92,7 +92,7 @@ class CudnnFilterDescParam {
public:
cudnnFilterDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
...
...
@@ -107,7 +107,7 @@ public:
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录