Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c90e0b54
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c90e0b54
编写于
12月 27, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(arm): optimize arm uint16 relayout with n=4
GitOrigin-RevId: 5779c6b9c15aa52447e32f8d95d1b845c6d21e18
上级
202b4071
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
81 addition
and
0 deletion
+81
-0
dnn/src/aarch64/relayout/opr_impl.cpp
dnn/src/aarch64/relayout/opr_impl.cpp
+69
-0
dnn/test/aarch64/relayout.cpp
dnn/test/aarch64/relayout.cpp
+12
-0
未找到文件。
dnn/src/aarch64/relayout/opr_impl.cpp
浏览文件 @
c90e0b54
...
...
@@ -305,6 +305,64 @@ static inline void trans_8x8_u16(
vst1q_u16
(
dst_ptr
+
7
*
dst_step
,
row_7
);
}
static
inline
void
trans_8x4_u16
(
const
void
*
src
,
void
*
dst
,
const
size_t
src_step
,
const
size_t
dst_step
)
{
uint16_t
*
src_ptr
=
(
uint16_t
*
)
src
;
uint16_t
*
dst_ptr
=
(
uint16_t
*
)
dst
;
uint16x4_t
src0
=
vld1_u16
(
src_ptr
+
0
*
src_step
);
// A0A1A2A3
uint16x4_t
src1
=
vld1_u16
(
src_ptr
+
1
*
src_step
);
// B0B1B2B3
uint16x4_t
src2
=
vld1_u16
(
src_ptr
+
2
*
src_step
);
// C0C1C2C3
uint16x4_t
src3
=
vld1_u16
(
src_ptr
+
3
*
src_step
);
// D0D1D2D3
uint16x4_t
src4
=
vld1_u16
(
src_ptr
+
4
*
src_step
);
// E0E1E2E3
uint16x4_t
src5
=
vld1_u16
(
src_ptr
+
5
*
src_step
);
// F0F1F2F3
uint16x4_t
src6
=
vld1_u16
(
src_ptr
+
6
*
src_step
);
// G0G1G2G3
uint16x4_t
src7
=
vld1_u16
(
src_ptr
+
7
*
src_step
);
// H0H1H2H3
uint16x4_t
ab_low
=
vzip1_u16
(
src0
,
src1
);
// A0B0A1B1
uint16x4_t
ab_high
=
vzip2_u16
(
src0
,
src1
);
// A2B2A3B3
uint16x4_t
cd_low
=
vzip1_u16
(
src2
,
src3
);
// C0D0C1D1
uint16x4_t
cd_high
=
vzip2_u16
(
src2
,
src3
);
// C2D2C3D3
uint16x4_t
ef_low
=
vzip1_u16
(
src4
,
src5
);
// E0F0E1F1
uint16x4_t
ef_high
=
vzip2_u16
(
src4
,
src5
);
// E2F2E3F3
uint16x4_t
gh_low
=
vzip1_u16
(
src6
,
src7
);
// G0H0G1H1
uint16x4_t
gh_high
=
vzip2_u16
(
src6
,
src7
);
// G2H2G3H3
uint16x4_t
abcd_0
=
vreinterpret_u16_u32
(
vzip1_u32
(
vreinterpret_u32_u16
(
ab_low
),
vreinterpret_u32_u16
(
cd_low
)));
// A0B0C0D0
uint16x4_t
abcd_1
=
vreinterpret_u16_u32
(
vzip2_u32
(
vreinterpret_u32_u16
(
ab_low
),
vreinterpret_u32_u16
(
cd_low
)));
// A1B1C1D1
uint16x4_t
abcd_2
=
vreinterpret_u16_u32
(
vzip1_u32
(
vreinterpret_u32_u16
(
ab_high
),
vreinterpret_u32_u16
(
cd_high
)));
// A2B2C2D2
uint16x4_t
abcd_3
=
vreinterpret_u16_u32
(
vzip2_u32
(
vreinterpret_u32_u16
(
ab_high
),
vreinterpret_u32_u16
(
cd_high
)));
// A3B3C3D3
uint16x4_t
efgh_0
=
vreinterpret_u16_u32
(
vzip1_u32
(
vreinterpret_u32_u16
(
ef_low
),
vreinterpret_u32_u16
(
gh_low
)));
// E0F0G0H0
uint16x4_t
efgh_1
=
vreinterpret_u16_u32
(
vzip2_u32
(
vreinterpret_u32_u16
(
ef_low
),
vreinterpret_u32_u16
(
gh_low
)));
// E1F1G1H1
uint16x4_t
efgh_2
=
vreinterpret_u16_u32
(
vzip1_u32
(
vreinterpret_u32_u16
(
ef_high
),
vreinterpret_u32_u16
(
gh_high
)));
// E2F2G2H2
uint16x4_t
efgh_3
=
vreinterpret_u16_u32
(
vzip2_u32
(
vreinterpret_u32_u16
(
ef_high
),
vreinterpret_u32_u16
(
gh_high
)));
// E3F3G3H3
uint16x8_t
row_0
=
vcombine_u16
(
abcd_0
,
efgh_0
);
uint16x8_t
row_1
=
vcombine_u16
(
abcd_1
,
efgh_1
);
uint16x8_t
row_2
=
vcombine_u16
(
abcd_2
,
efgh_2
);
uint16x8_t
row_3
=
vcombine_u16
(
abcd_3
,
efgh_3
);
vst1q_u16
(
dst_ptr
+
0
*
dst_step
,
row_0
);
vst1q_u16
(
dst_ptr
+
1
*
dst_step
,
row_1
);
vst1q_u16
(
dst_ptr
+
2
*
dst_step
,
row_2
);
vst1q_u16
(
dst_ptr
+
3
*
dst_step
,
row_3
);
}
}
// anonymous namespace
namespace
megdnn
{
...
...
@@ -346,6 +404,17 @@ void transpose_block<Transpose2Byte>(
trans_8x8_u16
(
src
,
dst
,
src_stride
,
dst_stride
);
}
template
<
>
void
transpose_block
<
Transpose2Byte
>
(
const
Transpose2Byte
*
src
,
Transpose2Byte
*
dst
,
const
size_t
src_stride
,
const
size_t
dst_stride
,
size_t
block_h
,
size_t
block_w
)
{
if
(
block_h
==
8
&&
block_w
==
4
)
{
trans_8x4_u16
(
src
,
dst
,
src_stride
,
dst_stride
);
}
else
{
transpose_block_fallback
(
src
,
dst
,
src_stride
,
dst_stride
,
block_h
,
block_w
);
}
}
}
// namespace transpose_fallback
}
// namespace relayout
}
// namespace megdnn
...
...
dnn/test/aarch64/relayout.cpp
浏览文件 @
c90e0b54
...
...
@@ -67,6 +67,18 @@ TEST_F(AARCH64, RelayoutBig) {
checker
.
execl
({
src
,
dst
});
}
TEST_F
(
AARCH64
,
RelayoutSplict
)
{
Checker
<
Relayout
>
checker
(
handle
());
ConsecutiveRNG
rng
;
checker
.
set_rng
(
0
,
&
rng
);
int
m
=
4
;
for
(
int
n
:
{
4
,
28
})
{
TensorLayout
src
({(
size_t
)
m
,
(
size_t
)
n
},
{
1
,
m
},
dtype
::
Uint16
());
TensorLayout
dst
({(
size_t
)
m
,
(
size_t
)
n
},
{
n
,
1
},
dtype
::
Uint16
());
checker
.
execl
({
src
,
dst
});
}
}
TEST_F
(
AARCH64
,
RelayoutRecord
)
{
TaskRecordChecker
<
Relayout
>
checker
(
0
);
std
::
vector
<::
megdnn
::
DType
>
dtype_vec
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录