Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9d5c5c07
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9d5c5c07
编写于
6月 19, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/naive): workspacebundle support 2D
GitOrigin-RevId: 4408bb9e1d2ced2f9e16cc6784882f89be4a3cf2
上级
f268e0f8
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
104 addition
and
16 deletion
+104
-16
dnn/src/common/utils.cpp
dnn/src/common/utils.cpp
+53
-9
dnn/src/common/utils.h
dnn/src/common/utils.h
+33
-7
dnn/test/common/test_basic_types.cpp
dnn/test/common/test_basic_types.cpp
+18
-0
未找到文件。
dnn/src/common/utils.cpp
浏览文件 @
9d5c5c07
...
...
@@ -156,15 +156,42 @@ void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw,
WorkspaceBundle
::
WorkspaceBundle
(
void
*
ptr
,
SmallVector
<
size_t
>
sizes_in_bytes
,
size_t
align_in_bytes
)
:
m_ptr
(
ptr
),
m_sizes
(
std
::
move
(
sizes_in_bytes
)),
m_align_in_bytes
(
align_in_bytes
)
{
m_aligned_sizes
.
reserve
(
m_sizes
.
size
());
for
(
auto
size
:
m_sizes
)
{
m_sizes
.
push_back
(
sizes_in_bytes
);
size_t
reduce_size
=
0
_z
;
m_reduce_num
.
push_back
(
0
_z
);
for
(
auto
size
:
m_sizes
[
0
])
{
auto
aligned_size
=
size
;
if
(
size
%
m_align_in_bytes
!=
0
)
{
aligned_size
+=
m_align_in_bytes
-
size
%
m_align_in_bytes
;
}
m_aligned_sizes
.
push_back
(
aligned_size
);
m_reduce_sizes
.
push_back
(
reduce_size
);
reduce_size
+=
aligned_size
;
}
}
WorkspaceBundle
::
WorkspaceBundle
(
SmallVector
<
SmallVector
<
size_t
>>
vector_sizes_in_bytes
,
void
*
ptr
,
size_t
align_in_bytes
)
:
m_ptr
(
ptr
),
m_sizes
(
vector_sizes_in_bytes
),
m_align_in_bytes
(
align_in_bytes
)
{
size_t
nr_workspace
=
0
_z
;
size_t
reduce_size
=
0
_z
;
for
(
auto
sizes_in_bytes
:
vector_sizes_in_bytes
)
{
m_reduce_num
.
push_back
(
nr_workspace
);
for
(
auto
size
:
sizes_in_bytes
)
{
auto
aligned_size
=
size
;
if
(
size
%
m_align_in_bytes
!=
0
)
{
aligned_size
+=
m_align_in_bytes
-
size
%
m_align_in_bytes
;
}
m_aligned_sizes
.
push_back
(
aligned_size
);
m_reduce_sizes
.
push_back
(
reduce_size
);
reduce_size
+=
aligned_size
;
nr_workspace
++
;
}
}
}
...
...
@@ -172,22 +199,39 @@ void* WorkspaceBundle::ptr() const {
return
m_ptr
;
}
void
*
WorkspaceBundle
::
get
(
size_t
i
)
const
{
void
*
WorkspaceBundle
::
get
(
size_t
dim1
,
size_t
dim0
)
const
{
megdnn_assert
(
dim1
<
m_sizes
.
size
(),
"dim1 is out of range"
);
megdnn_assert
(
dim0
<
m_sizes
[
dim1
].
size
(),
"dim0 is out of range"
);
auto
addr
=
reinterpret_cast
<
uintptr_t
>
(
m_ptr
);
if
(
addr
%
m_align_in_bytes
!=
0
)
addr
+=
m_align_in_bytes
-
addr
%
m_align_in_bytes
;
for
(
size_t
j
=
0
;
j
<
i
;
++
j
)
{
addr
+=
m_aligned_sizes
[
j
];
}
size_t
index
=
m_reduce_num
[
dim1
]
+
dim0
;
addr
+=
m_reduce_sizes
[
index
];
return
reinterpret_cast
<
void
*>
(
addr
);
}
void
*
WorkspaceBundle
::
get
(
size_t
dim0
)
const
{
megdnn_assert
(
dim0
<
m_aligned_sizes
.
size
(),
"dim0 is out of range"
);
auto
addr
=
reinterpret_cast
<
uintptr_t
>
(
m_ptr
);
if
(
addr
%
m_align_in_bytes
!=
0
)
addr
+=
m_align_in_bytes
-
addr
%
m_align_in_bytes
;
addr
+=
m_reduce_sizes
[
dim0
];
return
reinterpret_cast
<
void
*>
(
addr
);
}
size_t
WorkspaceBundle
::
nr_workspace
()
const
{
return
m_sizes
.
size
();
return
m_aligned_sizes
.
size
();
}
size_t
WorkspaceBundle
::
get_size
(
size_t
dim1
,
size_t
dim0
)
const
{
megdnn_assert
(
dim1
<
m_sizes
.
size
(),
"dim1 is out of range"
);
megdnn_assert
(
dim0
<
m_sizes
[
dim1
].
size
(),
"dim0 is out of range"
);
return
m_sizes
[
dim1
][
dim0
];
}
size_t
WorkspaceBundle
::
get_size
(
size_t
i
)
const
{
return
m_sizes
[
i
];
size_t
WorkspaceBundle
::
get_size
(
size_t
dim0
)
const
{
megdnn_assert
(
dim0
<
m_aligned_sizes
.
size
(),
"dim0 is out of range"
);
return
m_sizes
[
0
][
dim0
];
}
void
WorkspaceBundle
::
set
(
void
*
ptr
)
{
...
...
dnn/src/common/utils.h
浏览文件 @
9d5c5c07
...
...
@@ -194,8 +194,15 @@ std::unique_ptr<T> make_unique(Args&&... args) {
*/
class
WorkspaceBundle
{
public:
WorkspaceBundle
(
void
*
ptr
,
SmallVector
<
size_t
>
sizes_in_bytes
,
WorkspaceBundle
(
void
*
ptr
=
nullptr
,
SmallVector
<
size_t
>
sizes_in_bytes
=
{},
size_t
align_in_bytes
=
512
);
/**
* construct 2D workspace buldle
*/
WorkspaceBundle
(
SmallVector
<
SmallVector
<
size_t
>>
vector_sizes_in_bytes
,
void
*
ptr
,
size_t
align_in_bytes
=
512
);
/**
* \returns raw workspace ptr.
*
...
...
@@ -204,26 +211,45 @@ public:
*/
void
*
ptr
()
const
;
/**
* \returns the
i-th
workspace ptr (aligned)
* \returns the
2D [dim1, dim0]
workspace ptr (aligned)
*/
void
*
get
(
size_t
i
)
const
;
void
*
get
(
size_t
dim1
,
size_t
dim0
)
const
;
/**
* \returns the 1D [dim0] workspace ptr (aligned)
*/
void
*
get
(
size_t
dim0
)
const
;
/**
* \returns total size taking into account paddings to solve alignment
* issue.
*/
size_t
total_size_in_bytes
()
const
;
size_t
get_size
(
size_t
i
)
const
;
/**
* \return the 2D [dim1, dim0] workspace size
*/
size_t
get_size
(
size_t
dim1
,
size_t
dim0
)
const
;
/**
* \return the 1D [dim0] workspace size
*/
size_t
get_size
(
size_t
dim0
)
const
;
size_t
nr_workspace
()
const
;
void
set
(
void
*
ptr
);
Workspace
get_workspace
(
size_t
i
)
const
{
return
{
static_cast
<
dt_byte
*>
(
get
(
i
)),
get_size
(
i
)};
Workspace
get_workspace
(
size_t
dim1
,
size_t
dim0
)
const
{
return
{
static_cast
<
dt_byte
*>
(
get
(
dim1
,
dim0
)),
get_size
(
dim1
,
dim0
)};
}
Workspace
get_workspace
(
size_t
dim0
)
const
{
return
{
static_cast
<
dt_byte
*>
(
get
(
dim0
)),
get_size
(
dim0
)};
}
private:
void
*
m_ptr
;
SmallVector
<
size_t
>
m_sizes
;
SmallVector
<
SmallVector
<
size_t
>
>
m_sizes
;
SmallVector
<
size_t
>
m_aligned_sizes
;
//! all workspace size prefix sum
SmallVector
<
size_t
>
m_reduce_sizes
;
//! dim1 workspace number prefix sum
SmallVector
<
size_t
>
m_reduce_num
;
size_t
m_align_in_bytes
;
};
...
...
dnn/test/common/test_basic_types.cpp
浏览文件 @
9d5c5c07
...
...
@@ -11,6 +11,7 @@
#include "megdnn/basic_types.h"
#include "megdnn/tensor_format.h"
#include "src/common/utils.h"
// clang-format off
#include "test/common/utils.h"
...
...
@@ -278,4 +279,21 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
}
}
TEST
(
MISC
,
WORKSPACE_BUNDLE
)
{
WorkspaceBundle
bundle
{
{{
100
,
200
},
{
435
,
234
,
143
},
{
422
,
1325
,
728
}},
nullptr
,
64
};
bundle
.
set
(
reinterpret_cast
<
void
*>
(
82l
));
ASSERT_EQ
(
bundle
.
get
(
0
),
reinterpret_cast
<
void
*>
(
128l
));
void
*
dst
=
reinterpret_cast
<
void
*>
(
128
+
round_up
(
100
,
64
));
ASSERT_EQ
(
bundle
.
get
(
0
,
1
),
dst
);
dst
=
reinterpret_cast
<
void
*>
(
128
+
round_up
(
100
,
64
)
+
round_up
(
200
,
64
)
+
round_up
(
435
,
64
));
ASSERT_EQ
(
bundle
.
get
(
1
,
1
),
dst
);
dst
=
reinterpret_cast
<
void
*>
(
128l
+
round_up
(
100
,
64
)
+
round_up
(
200
,
64
)
+
round_up
(
435
,
64
)
+
round_up
(
234
,
64
)
+
round_up
(
143
,
64
)
+
round_up
(
422
,
64
)
+
round_up
(
1325
,
64
));
ASSERT_EQ
(
bundle
.
get
(
2
,
2
),
dst
);
}
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录