Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
684c07d7
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
684c07d7
编写于
5月 17, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(api_cache): fix serialization for conv_desc
GitOrigin-RevId: 95dbc9c685cced46dd910997bd585363c392ccbd
上级
780663c9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
77 addition
and
81 deletion
+77
-81
dnn/src/common/api_cache.h
dnn/src/common/api_cache.h
+37
-27
dnn/src/cuda/api_cache.h
dnn/src/cuda/api_cache.h
+34
-49
dnn/src/cuda/handle.cpp
dnn/src/cuda/handle.cpp
+6
-5
未找到文件。
dnn/src/common/api_cache.h
浏览文件 @
684c07d7
...
...
@@ -131,12 +131,18 @@ public:
T
read_plain
()
{
static_assert
(
std
::
is_trivially_copyable
<
T
>::
value
,
"invalid type"
);
T
ret
;
memcpy
(
&
ret
,
m_buffer
.
data
()
+
m_cursor
,
sizeof
(
T
));
std
::
memcpy
(
&
ret
,
m_buffer
.
data
()
+
m_cursor
,
sizeof
(
T
));
m_cursor
+=
sizeof
(
T
);
return
ret
;
}
template
<
typename
T
>
void
write_plain
(
T
value
)
{
void
read_plain
(
T
*
dest
)
{
static_assert
(
std
::
is_trivially_copyable
<
T
>::
value
,
"invalid type"
);
std
::
memcpy
(
dest
,
m_buffer
.
data
()
+
m_cursor
,
sizeof
(
T
));
m_cursor
+=
sizeof
(
T
);
}
template
<
typename
T
>
void
write_plain
(
const
T
&
value
)
{
static_assert
(
std
::
is_trivially_copyable
<
T
>::
value
,
"type should be trivially copyable"
);
m_buffer
.
append
(
reinterpret_cast
<
const
char
*>
(
&
value
),
sizeof
(
T
));
...
...
@@ -144,7 +150,7 @@ public:
std
::
string
take
()
{
return
std
::
move
(
m_buffer
);
}
void
reset
(
std
::
string
new_buf
)
{
m_cursor
=
0
;
m_buffer
=
new_buf
;
m_buffer
=
std
::
move
(
new_buf
)
;
}
};
...
...
@@ -153,7 +159,7 @@ struct Empty {};
// in: seq[1, 2, ..., m]
// out: seq[N+1, N+2, ... N+m]
template
<
std
::
size_t
N
,
std
::
size_t
...
Seq
>
static
std
::
index_sequence
<
N
+
Seq
...
>
inc_index_sequence
(
inline
std
::
index_sequence
<
N
+
Seq
...
>
inc_index_sequence
(
std
::
index_sequence
<
Seq
...
>
)
{
return
{};
}
...
...
@@ -172,7 +178,7 @@ private:
// deconstruct tuple and call functor
template
<
typename
TFunctor
,
size_t
...
Indices
>
auto
call_helper
(
TFunctor
functor
,
std
::
index_sequence
<
Indices
...
>
)
{
auto
call_helper
(
TFunctor
&&
functor
,
std
::
index_sequence
<
Indices
...
>
)
{
return
functor
(
std
::
get
<
Indices
>
(
m_storage
).
value
...);
}
...
...
@@ -203,7 +209,7 @@ private:
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TArg
,
typename
...
TArgs
>
void
set_values_helper
(
std
::
index_sequence
<
Index
,
Indices
...
>
,
TArg
&&
arg
,
TArgs
&&
...
args
)
{
std
::
get
<
Index
>
(
m_storage
).
value
=
arg
;
std
::
get
<
Index
>
(
m_storage
).
value
=
std
::
forward
<
TArg
>
(
arg
)
;
set_values_helper
(
std
::
index_sequence
<
Indices
...
>
(),
std
::
forward
<
TArgs
>
(
args
)...);
}
...
...
@@ -253,7 +259,7 @@ public:
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
value
=
ser
.
read_plain
<
T
>
(
);
ser
.
read_plain
(
&
value
);
return
Empty
{};
}
};
...
...
@@ -285,7 +291,8 @@ private:
template
<
size_t
...
Indices
>
static
auto
declbundle_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
ParamBundle
<
decltype
(
std
::
get
<
Indices
>
(
declargs
()))...
>
{
->
ParamBundle
<
std
::
remove_reference_t
<
decltype
(
std
::
get
<
Indices
>
(
declargs
()))
>
...
>
{
return
{};
}
...
...
@@ -312,9 +319,11 @@ public:
// declare new input
template
<
typename
TNewInput
>
auto
input
()
{
using
TNewInputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
make_tuple
(
std
::
declval
<
TNewInput
>
())));
static_assert
(
std
::
tuple_size
<
TOutputs
>::
value
==
0
,
"input arg cannot be declared after output"
);
using
TNewInputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
declval
<
std
::
tuple
<
TNewInput
>>
()));
return
FunctionCacheBuilder
<
TRet
,
TNewInputs
,
TOutputs
>
{};
}
// declare new output
...
...
@@ -322,31 +331,29 @@ public:
auto
output
()
{
using
TNewOutputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TOutputs
>
(),
std
::
make_tuple
(
std
::
declval
<
TNewOutput
>
()
)));
std
::
declval
<
std
::
tuple
<
TNewOutput
>>
(
)));
return
FunctionCacheBuilder
<
TRet
,
TInputs
,
TNewOutputs
>
{};
}
// summary
template
<
typename
TFunctor
>
function_t
build
(
TFunctor
func
)
{
function_t
build
(
TFunctor
&&
func
)
{
constexpr
size_t
n_inputs
=
std
::
tuple_size
<
TInputs
>::
value
;
constexpr
size_t
n_outputs
=
std
::
tuple_size
<
TOutputs
>::
value
;
auto
cache
=
std
::
make_shared
<
FunctionCache
<
std
::
string
(
bundle_t
)
>>
();
// bundle -> ser(in args)
cache
->
key_mapper
=
[](
bundle_t
bundle
)
{
StringSerializer
ser
;
bundle
.
template
serialize_params
<
0
,
std
::
tuple_size
<
TInputs
>
::
value
>
(
ser
);
bundle
.
template
serialize_params
<
0
,
n_inputs
>(
ser
);
return
ser
.
take
();
};
// bundle -> ser(out args)
cache
->
value_mapper
=
[
=
](
bundle_t
bundle
)
{
cache
->
value_mapper
=
[
func
](
bundle_t
bundle
)
{
StringSerializer
ser
;
TRet
ret
;
ret
.
value
=
bundle
.
call_by
(
func
);
ret
.
serialize
(
ser
,
Empty
{});
bundle
.
template
serialize_params
<
std
::
tuple_size
<
TInputs
>
::
value
,
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
(
ser
);
bundle
.
template
serialize_params
<
n_inputs
,
n_inputs
+
n_outputs
>(
ser
);
return
ser
.
take
();
};
return
[
=
](
auto
&&
...
args
)
mutable
{
...
...
@@ -361,8 +368,6 @@ public:
std
::
forward
<
decltype
(
args
)
>
(
args
)...);
ser
.
reset
((
*
cache
)(
bundle
));
ret
.
deserialize
(
ser
,
Empty
{});
constexpr
size_t
n_inputs
=
std
::
tuple_size
<
TInputs
>::
value
;
constexpr
size_t
n_outputs
=
std
::
tuple_size
<
TOutputs
>::
value
;
bundle
.
template
deserialize_params
<
n_inputs
,
n_inputs
+
n_outputs
>(
ser
);
return
ret
.
value
;
...
...
@@ -394,7 +399,8 @@ public:
return
*
value
;
}
T
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
return
*
value
=
ser
.
read_plain
<
T
>
();
ser
.
read_plain
(
value
);
return
*
value
;
}
};
...
...
@@ -402,16 +408,20 @@ public:
template
<
typename
TSize
,
typename
TItem
>
class
ArrayParam
{
public:
TItem
*
value
;
decltype
(
std
::
declval
<
TItem
>
().
value
)
*
value
;
Empty
serialize
(
StringSerializer
&
ser
,
TSize
size
)
{
TItem
param
;
for
(
TSize
i
=
0
;
i
<
size
;
++
i
)
{
ser
.
write_plain
(
value
[
i
]);
param
.
value
=
value
[
i
];
param
.
serialize
(
ser
,
Empty
{});
}
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
TSize
size
)
{
TItem
param
;
for
(
TSize
i
=
0
;
i
<
size
;
++
i
)
{
value
[
i
]
=
ser
.
read_plain
<
TItem
>
();
param
.
deserialize
(
ser
,
Empty
{});
value
[
i
]
=
param
.
value
;
}
return
Empty
{};
}
...
...
dnn/src/cuda/api_cache.h
浏览文件 @
684c07d7
...
...
@@ -20,14 +20,16 @@ class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
constexpr
int
maxNbDims
=
CUDNN_DIM_MAX
-
2
;
int
nbDims
=
maxNbDims
;
int
padA
[
maxNbDims
];
int
strideA
[
maxNbDims
];
int
dilationA
[
maxNbDims
];
cudnnConvolutionMode_t
mode
;
cudnnDataType_t
computeType
;
cudnnGetConvolutionNdDescriptor
(
value
,
nbDims
,
&
nbDims
,
padA
,
strideA
,
dilationA
,
&
mode
,
&
computeType
);
cudnnGetConvolutionNdDescriptor
(
value
,
maxNbDims
,
&
nbDims
,
padA
,
strideA
,
dilationA
,
&
mode
,
&
computeType
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
padA
[
i
]);
...
...
@@ -38,23 +40,8 @@ public:
ser
.
write_plain
(
computeType
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
ndim
=
ser
.
read_plain
<
int
>
();
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
padA
[
i
]
=
ser
.
read_plain
<
int
>
();
strideA
[
i
]
=
ser
.
read_plain
<
int
>
();
dilationA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
cudnnConvolutionMode_t
mode
=
ser
.
read_plain
<
cudnnConvolutionMode_t
>
();
cudnnDataType_t
computeType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
cudnnSetConvolutionNdDescriptor
(
value
,
ndim
,
padA
,
strideA
,
dilationA
,
mode
,
computeType
);
return
Empty
{};
}
};
class
CudnnTensorDescParam
{
public:
cudnnTensorDescriptor_t
value
;
...
...
@@ -63,8 +50,8 @@ public:
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
cudnnGetTensorNdDescriptor
(
value
,
nbDims
,
&
dataType
,
&
nbDims
,
dimA
,
strideA
);
cudnnGetTensorNdDescriptor
(
value
,
MEGDNN_MAX_NDIM
,
&
dataType
,
&
nbDims
,
dimA
,
strideA
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
dimA
[
i
]);
...
...
@@ -73,21 +60,8 @@ public:
ser
.
write_plain
(
dataType
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
nbDims
=
ser
.
read_plain
<
int
>
();
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
dimA
[
i
]
=
ser
.
read_plain
<
int
>
();
strideA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
dataType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
cudnnSetTensorNdDescriptor
(
value
,
dataType
,
nbDims
,
dimA
,
strideA
);
return
Empty
{};
}
};
class
CudnnFilterDescParam
{
public:
cudnnFilterDescriptor_t
value
;
...
...
@@ -106,18 +80,29 @@ public:
ser
.
write_plain
(
format
);
return
Empty
{};
}
};
template
<
typename
T
>
class
CudnnConvAlgoPerfParam
{
public:
T
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
ser
.
write_plain
(
value
.
algo
);
ser
.
write_plain
(
value
.
status
);
ser
.
write_plain
(
value
.
time
);
ser
.
write_plain
(
value
.
memory
);
ser
.
write_plain
(
value
.
determinism
);
ser
.
write_plain
(
value
.
mathType
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
nbDims
=
ser
.
read_plain
<
int
>
();
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
filterDimA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
dataType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
format
=
ser
.
read_plain
<
cudnnTensorFormat_t
>
();
cudnnSetFilterNdDescriptor
(
value
,
dataType
,
format
,
nbDims
,
filterDimA
);
ser
.
read_plain
(
&
value
.
algo
);
ser
.
read_plain
(
&
value
.
status
);
ser
.
read_plain
(
&
value
.
time
);
ser
.
read_plain
(
&
value
.
memory
);
ser
.
read_plain
(
&
value
.
determinism
);
ser
.
read_plain
(
&
value
.
mathType
);
return
Empty
{};
}
};
...
...
dnn/src/cuda/handle.cpp
浏览文件 @
684c07d7
...
...
@@ -165,7 +165,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
.
input
<
CudnnTensorDescParam
>
()
.
input
<
Param
<
int
>>
()
.
output
<
RefArraySizeParam
<
int
>>
()
.
output
<
ArrayParam
<
int
,
cudnnConvolutionFwdAlgoPerf_t
>>
()
.
output
<
ArrayParam
<
int
,
Param
<
cudnnConvolutionFwdAlgoPerf_t
>>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionForwardAlgorithm_v7
);
GetConvolutionForwardAlgorithmMaxCount
=
...
...
@@ -196,8 +197,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
.
input
<
CudnnTensorDescParam
>
()
.
input
<
Param
<
int
>>
()
.
output
<
RefArraySizeParam
<
int
>>
()
.
output
<
ArrayParam
<
int
,
cudnnConvolutionBwdDataAlgoPerf_t
>>
()
.
output
<
ArrayParam
<
int
,
Param
<
cudnnConvolutionBwdDataAlgoPerf_t
>
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardDataAlgorithm_v7
);
GetConvolutionBackwardDataAlgorithmMaxCount
=
...
...
@@ -228,8 +229,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
.
input
<
CudnnFilterDescParam
>
()
.
input
<
Param
<
int
>>
()
.
output
<
RefArraySizeParam
<
int
>>
()
.
output
<
ArrayParam
<
int
,
cudnnConvolutionBwdFilterAlgoPerf_t
>>
()
.
output
<
ArrayParam
<
int
,
Param
<
cudnnConvolutionBwdFilterAlgoPerf_t
>
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardFilterAlgorithm_v7
);
GetConvolutionBackwardFilterAlgorithmMaxCount
=
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录