Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
b3efb72b
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b3efb72b
编写于
5月 22, 2018
作者:
B
Bin Li
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize armv7 armv8 conv1x7 and conv7x1
上级
cb70c32a
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
689 addition
and
125 deletion
+689
-125
mace/kernels/arm/conv_2d_neon.h
mace/kernels/arm/conv_2d_neon.h
+35
-0
mace/kernels/arm/conv_2d_neon_15x1.cc
mace/kernels/arm/conv_2d_neon_15x1.cc
+2
-2
mace/kernels/arm/conv_2d_neon_1x7.cc
mace/kernels/arm/conv_2d_neon_1x7.cc
+256
-0
mace/kernels/arm/conv_2d_neon_3x3.cc
mace/kernels/arm/conv_2d_neon_3x3.cc
+11
-35
mace/kernels/arm/conv_2d_neon_5x5.cc
mace/kernels/arm/conv_2d_neon_5x5.cc
+8
-32
mace/kernels/arm/conv_2d_neon_7x1.cc
mace/kernels/arm/conv_2d_neon_7x1.cc
+297
-0
mace/kernels/arm/conv_2d_neon_7x7.cc
mace/kernels/arm/conv_2d_neon_7x7.cc
+24
-48
mace/kernels/conv_2d.h
mace/kernels/conv_2d.h
+21
-1
mace/kernels/gemm.cc
mace/kernels/gemm.cc
+1
-2
mace/ops/conv_2d_benchmark.cc
mace/ops/conv_2d_benchmark.cc
+4
-0
mace/ops/conv_2d_test.cc
mace/ops/conv_2d_test.cc
+30
-5
未找到文件。
mace/kernels/arm/conv_2d_neon.h
浏览文件 @
b3efb72b
...
...
@@ -47,6 +47,18 @@ extern void Conv2dNeonK5x5S1(const float *input,
const
index_t
*
out_shape
,
float
*
output
);
extern
void
Conv2dNeonK1x7S1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
*
in_shape
,
const
index_t
*
out_shape
,
float
*
output
);
extern
void
Conv2dNeonK7x1S1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
*
in_shape
,
const
index_t
*
out_shape
,
float
*
output
);
extern
void
Conv2dNeonK7x7S1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
*
in_shape
,
...
...
@@ -77,6 +89,29 @@ extern void Conv2dNeonK15x1S1(const float *input,
const
index_t
*
out_shape
,
float
*
output
);
// calculate one output channel and one input channel
inline
void
Conv2dCPUKHxKWCalc
(
const
float
*
in_ptr
,
const
float
*
filter_ptr
,
const
index_t
in_width
,
const
index_t
filter_height
,
const
index_t
filter_width
,
const
index_t
out_height
,
const
index_t
out_width
,
float
*
out_ptr
,
const
int
stride
)
{
for
(
index_t
h
=
0
;
h
<
out_height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
++
w
)
{
for
(
int
i
=
0
;
i
<
filter_height
;
++
i
)
{
for
(
int
j
=
0
;
j
<
filter_width
;
++
j
)
{
out_ptr
[
h
*
out_width
+
w
]
+=
in_ptr
[(
h
*
stride
+
i
)
*
in_width
+
(
w
*
stride
+
j
)]
*
filter_ptr
[
i
*
filter_width
+
j
];
}
}
}
}
}
}
// namespace kernels
}
// namespace mace
...
...
mace/kernels/arm/conv_2d_neon_15x1.cc
浏览文件 @
b3efb72b
...
...
@@ -76,7 +76,7 @@ void Conv2dNeonK15x1S1(const float *input,
input
+
b
*
in_batch_size
+
c
*
in_image_size
;
const
float
*
filter_ptr
=
filter
+
m
*
in_channels
*
15
+
c
*
15
;
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
/* load filter (1 outch x
1 height x 4
width) */
/* load filter (1 outch x
4 height x 1
width) */
float32x4_t
vf0
,
vf1
,
vf2
,
vf3
;
vf0
=
vld1q_f32
(
filter_ptr
);
vf1
=
vld1q_f32
(
filter_ptr
+
4
);
...
...
@@ -87,7 +87,7 @@ void Conv2dNeonK15x1S1(const float *input,
for
(
index_t
wt
=
0
;
wt
<
tile_width
&&
w
+
wt
<
out_width
;
++
wt
)
{
// load output
index_t
out_offset
=
h
*
out_width
+
w
+
wt
;
// output (1 outch x
1 height x 4
width): vo_outch_height
// output (1 outch x
4 height x 1
width): vo_outch_height
float32x4_t
vo
=
{
out_ptr_base
[
out_offset
],
out_ptr_base
[
out_offset
+
out_width
],
out_ptr_base
[
out_offset
+
2
*
out_width
],
...
...
mace/kernels/arm/conv_2d_neon_1x7.cc
0 → 100644
浏览文件 @
b3efb72b
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/arm/conv_2d_neon.h"
namespace
mace
{
namespace
kernels
{
// Ho = 1, Wo = 4, Co = 4
void
Conv2dNeonK1x7S1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
*
in_shape
,
const
index_t
*
out_shape
,
float
*
output
)
{
const
index_t
in_image_size
=
in_shape
[
2
]
*
in_shape
[
3
];
const
index_t
out_image_size
=
out_shape
[
2
]
*
out_shape
[
3
];
const
index_t
in_batch_size
=
in_shape
[
1
]
*
in_image_size
;
const
index_t
out_batch_size
=
out_shape
[
1
]
*
out_image_size
;
#pragma omp parallel for collapse(2)
for
(
index_t
b
=
0
;
b
<
out_shape
[
0
];
++
b
)
{
for
(
index_t
m
=
0
;
m
<
out_shape
[
1
];
m
+=
4
)
{
const
index_t
out_channels
=
out_shape
[
1
];
const
index_t
out_height
=
out_shape
[
2
];
const
index_t
out_width
=
out_shape
[
3
];
const
index_t
in_channels
=
in_shape
[
1
];
const
index_t
in_width
=
in_shape
[
3
];
if
(
m
+
3
<
out_channels
)
{
float
*
out_ptr0_base
=
output
+
b
*
out_batch_size
+
m
*
out_image_size
;
#if defined(MACE_ENABLE_NEON)
float
*
out_ptr1_base
=
output
+
b
*
out_batch_size
+
(
m
+
1
)
*
out_image_size
;
float
*
out_ptr2_base
=
output
+
b
*
out_batch_size
+
(
m
+
2
)
*
out_image_size
;
float
*
out_ptr3_base
=
output
+
b
*
out_batch_size
+
(
m
+
3
)
*
out_image_size
;
#endif
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
const
float
*
in_ptr_base
=
input
+
b
*
in_batch_size
+
c
*
in_image_size
;
const
float
*
filter_ptr0
=
filter
+
m
*
in_channels
*
7
+
c
*
7
;
#if defined(MACE_ENABLE_NEON)
const
float
*
filter_ptr1
=
filter
+
(
m
+
1
)
*
in_channels
*
7
+
c
*
7
;
const
float
*
filter_ptr2
=
filter
+
(
m
+
2
)
*
in_channels
*
7
+
c
*
7
;
const
float
*
filter_ptr3
=
filter
+
(
m
+
3
)
*
in_channels
*
7
+
c
*
7
;
/* load filter (4 outch x 1 height x 4 width) */
float32x4_t
vf00
,
vf01
;
float32x4_t
vf10
,
vf11
;
float32x4_t
vf20
,
vf21
;
float32x4_t
vf30
,
vf31
;
vf00
=
vld1q_f32
(
filter_ptr0
);
vf01
=
vld1q_f32
(
filter_ptr0
+
3
);
vf10
=
vld1q_f32
(
filter_ptr1
);
vf11
=
vld1q_f32
(
filter_ptr1
+
3
);
vf20
=
vld1q_f32
(
filter_ptr2
);
vf21
=
vld1q_f32
(
filter_ptr2
+
3
);
vf30
=
vld1q_f32
(
filter_ptr3
);
vf31
=
vld1q_f32
(
filter_ptr3
+
3
);
for
(
index_t
h
=
0
;
h
<
out_height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
+
3
<
out_width
;
w
+=
4
)
{
// output (4 outch x 1 height x 4 width): vo_outch_height
float32x4_t
vo0
,
vo1
,
vo2
,
vo3
;
// load output
index_t
out_offset
=
h
*
out_width
+
w
;
vo0
=
vld1q_f32
(
out_ptr0_base
+
out_offset
);
vo1
=
vld1q_f32
(
out_ptr1_base
+
out_offset
);
vo2
=
vld1q_f32
(
out_ptr2_base
+
out_offset
);
vo3
=
vld1q_f32
(
out_ptr3_base
+
out_offset
);
// input (3 slide)
float32x4_t
vi0
,
vi1
,
vi2
,
vi3
,
vi4
,
vi5
,
vi6
,
vi8
;
// input offset
index_t
in_offset
=
h
*
in_width
+
w
;
// load input
vi0
=
vld1q_f32
(
in_ptr_base
+
in_offset
);
vi4
=
vld1q_f32
(
in_ptr_base
+
in_offset
+
4
);
vi8
=
vld1q_f32
(
in_ptr_base
+
in_offset
+
8
);
vi1
=
vextq_f32
(
vi0
,
vi4
,
1
);
vi2
=
vextq_f32
(
vi0
,
vi4
,
2
);
vi3
=
vextq_f32
(
vi0
,
vi4
,
3
);
vi5
=
vextq_f32
(
vi4
,
vi8
,
1
);
vi6
=
vextq_f32
(
vi4
,
vi8
,
2
);
#if defined(__aarch64__)
/* outch 0 */
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi0
,
vf00
,
0
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi1
,
vf00
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi2
,
vf00
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi3
,
vf00
,
3
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi4
,
vf01
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi5
,
vf01
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi6
,
vf01
,
3
);
/* outch 1 */
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi0
,
vf10
,
0
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi1
,
vf10
,
1
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi2
,
vf10
,
2
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi3
,
vf10
,
3
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi4
,
vf11
,
1
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi5
,
vf11
,
2
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi6
,
vf11
,
3
);
/* outch 2 */
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi0
,
vf20
,
0
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi1
,
vf20
,
1
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi2
,
vf20
,
2
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi3
,
vf20
,
3
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi4
,
vf21
,
1
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi5
,
vf21
,
2
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi6
,
vf21
,
3
);
/* outch 3 */
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi0
,
vf30
,
0
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi1
,
vf30
,
1
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi2
,
vf30
,
2
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi3
,
vf30
,
3
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi4
,
vf31
,
1
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi5
,
vf31
,
2
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi6
,
vf31
,
3
);
#else
/* outch 0 */
vo0
=
vmlaq_lane_f32
(
vo0
,
vi0
,
vget_low_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi1
,
vget_low_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi2
,
vget_high_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi3
,
vget_high_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi4
,
vget_low_f32
(
vf01
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi5
,
vget_high_f32
(
vf01
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi6
,
vget_high_f32
(
vf01
),
1
);
/* outch 1 */
vo1
=
vmlaq_lane_f32
(
vo1
,
vi0
,
vget_low_f32
(
vf10
),
0
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi1
,
vget_low_f32
(
vf10
),
1
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi2
,
vget_high_f32
(
vf10
),
0
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi3
,
vget_high_f32
(
vf10
),
1
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi4
,
vget_low_f32
(
vf11
),
1
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi5
,
vget_high_f32
(
vf11
),
0
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi6
,
vget_high_f32
(
vf11
),
1
);
/* outch 2 */
vo2
=
vmlaq_lane_f32
(
vo2
,
vi0
,
vget_low_f32
(
vf20
),
0
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi1
,
vget_low_f32
(
vf20
),
1
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi2
,
vget_high_f32
(
vf20
),
0
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi3
,
vget_high_f32
(
vf20
),
1
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi4
,
vget_low_f32
(
vf21
),
1
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi5
,
vget_high_f32
(
vf21
),
0
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi6
,
vget_high_f32
(
vf21
),
1
);
/* outch 3 */
vo3
=
vmlaq_lane_f32
(
vo3
,
vi0
,
vget_low_f32
(
vf30
),
0
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi1
,
vget_low_f32
(
vf30
),
1
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi2
,
vget_high_f32
(
vf30
),
0
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi3
,
vget_high_f32
(
vf30
),
1
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi4
,
vget_low_f32
(
vf31
),
1
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi5
,
vget_high_f32
(
vf31
),
0
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi6
,
vget_high_f32
(
vf31
),
1
);
#endif
vst1q_f32
(
out_ptr0_base
+
out_offset
,
vo0
);
vst1q_f32
(
out_ptr1_base
+
out_offset
,
vo1
);
vst1q_f32
(
out_ptr2_base
+
out_offset
,
vo2
);
vst1q_f32
(
out_ptr3_base
+
out_offset
,
vo3
);
}
// w
}
// h
#else
for
(
index_t
oc
=
0
;
oc
<
4
;
++
oc
)
{
Conv2dCPUKHxKWCalc
(
in_ptr_base
,
filter_ptr0
+
oc
*
in_channels
*
7
,
in_width
,
1
,
7
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
1
);
}
#endif
}
// c
}
else
{
for
(
index_t
mm
=
m
;
mm
<
out_channels
;
++
mm
)
{
float
*
out_ptr0_base
=
output
+
b
*
out_batch_size
+
mm
*
out_image_size
;
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
const
float
*
in_ptr_base
=
input
+
b
*
in_batch_size
+
c
*
in_image_size
;
const
float
*
filter_ptr0
=
filter
+
mm
*
in_channels
*
7
+
c
*
7
;
#if defined(MACE_ENABLE_NEON)
/* load filter (1 outch x 1 height x 4 width) */
float32x4_t
vf00
,
vf01
;
vf00
=
vld1q_f32
(
filter_ptr0
);
vf01
=
vld1q_f32
(
filter_ptr0
+
3
);
for
(
index_t
h
=
0
;
h
<
out_height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
+
3
<
out_width
;
w
+=
4
)
{
// output (1 outch x 1 height x 4 width): vo_outch_height
float32x4_t
vo0
;
// load output
index_t
out_offset
=
h
*
out_width
+
w
;
vo0
=
vld1q_f32
(
out_ptr0_base
+
out_offset
);
// input (3 slide)
float32x4_t
vi0
,
vi1
,
vi2
,
vi3
,
vi4
,
vi5
,
vi6
,
vi8
;
// input offset
index_t
in_offset
=
h
*
in_width
+
w
;
// load input
vi0
=
vld1q_f32
(
in_ptr_base
+
in_offset
);
vi4
=
vld1q_f32
(
in_ptr_base
+
in_offset
+
4
);
vi8
=
vld1q_f32
(
in_ptr_base
+
in_offset
+
8
);
vi1
=
vextq_f32
(
vi0
,
vi4
,
1
);
vi2
=
vextq_f32
(
vi0
,
vi4
,
2
);
vi3
=
vextq_f32
(
vi0
,
vi4
,
3
);
vi5
=
vextq_f32
(
vi4
,
vi8
,
1
);
vi6
=
vextq_f32
(
vi4
,
vi8
,
2
);
#if defined(__aarch64__)
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi0
,
vf00
,
0
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi1
,
vf00
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi2
,
vf00
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi3
,
vf00
,
3
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi4
,
vf01
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi5
,
vf01
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi6
,
vf01
,
3
);
#else
vo0
=
vmlaq_lane_f32
(
vo0
,
vi0
,
vget_low_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi1
,
vget_low_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi2
,
vget_high_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi3
,
vget_high_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi4
,
vget_low_f32
(
vf01
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi5
,
vget_high_f32
(
vf01
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi6
,
vget_high_f32
(
vf01
),
1
);
#endif
vst1q_f32
(
out_ptr0_base
+
out_offset
,
vo0
);
}
// w
}
// h
#else
Conv2dCPUKHxKWCalc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
1
,
7
,
out_height
,
out_width
,
out_ptr0_base
,
1
);
#endif
}
// c
}
}
// if
}
// m
}
// b
}
}
// namespace kernels
}
// namespace mace
mace/kernels/arm/conv_2d_neon_3x3.cc
浏览文件 @
b3efb72b
...
...
@@ -300,19 +300,11 @@ void Conv2dNeonK3x3S1(const float *input,
out_ptr1
+=
out_width
;
}
// h
#else
for
(
index_t
io
=
0
;
io
<
2
;
++
io
)
{
for
(
index_t
ih
=
0
;
ih
<
out_height
;
++
ih
)
{
for
(
index_t
iw
=
0
;
iw
<
out_width
;
++
iw
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
out_ptr0
[
io
*
out_image_size
+
ih
*
out_width
+
iw
]
+=
in_ptr0
[(
ih
+
i
)
*
in_width
+
(
iw
+
j
)]
*
filter_ptr0
[
io
*
in_channels
*
9
+
i
*
3
+
j
];
}
}
}
}
}
// for
for
(
index_t
oc
=
0
;
oc
<
2
;
++
oc
)
{
Conv2dCPUKHxKWCalc
(
in_ptr0
,
filter_ptr0
+
oc
*
in_channels
*
9
,
in_width
,
3
,
3
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
1
);
}
#endif
}
// c
}
else
{
...
...
@@ -501,17 +493,9 @@ void Conv2dNeonK3x3S1(const float *input,
out_ptr0
+=
out_width
;
}
// h
#else
for
(
index_t
ih
=
0
;
ih
<
out_height
;
++
ih
)
{
for
(
index_t
iw
=
0
;
iw
<
out_width
;
++
iw
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
out_ptr0
[
ih
*
out_width
+
iw
]
+=
in_ptr0
[(
ih
+
i
)
*
in_width
+
(
iw
+
j
)]
*
filter_ptr0
[
i
*
3
+
j
];
}
}
}
}
Conv2dCPUKHxKWCalc
(
in_ptr0
,
filter_ptr0
,
in_width
,
3
,
3
,
out_height
,
out_width
,
out_ptr0_base
,
1
);
#endif
}
// c
}
// mm
...
...
@@ -666,17 +650,9 @@ void Conv2dNeonK3x3S2(const float *input,
}
// w
}
// h
#else
for
(
index_t
ih
=
0
;
ih
<
out_height
;
++
ih
)
{
for
(
index_t
iw
=
0
;
iw
<
out_width
;
++
iw
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
j
=
0
;
j
<
3
;
++
j
)
{
out_base
[
ih
*
out_width
+
iw
]
+=
in_base
[(
ih
*
2
+
i
)
*
in_width
+
(
iw
*
2
+
j
)]
*
filter_ptr
[
i
*
3
+
j
];
}
}
}
}
Conv2dCPUKHxKWCalc
(
in_base
,
filter_ptr
,
in_width
,
3
,
3
,
out_height
,
out_width
,
out_base
,
2
);
#endif
}
// c
}
// m
...
...
mace/kernels/arm/conv_2d_neon_5x5.cc
浏览文件 @
b3efb72b
...
...
@@ -76,30 +76,6 @@ namespace kernels {
vo0 = vmlaq_lane_f32(vo0, vi3, vget_high_f32(vf00), 1); \
vo0 = vmlaq_lane_f32(vo0, vi4, vf01, 1);
inline
void
Conv2dCPUK5x5Calc
(
const
float
*
in_ptr_base
,
const
float
*
filter_ptr0
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_height
,
const
index_t
out_width
,
const
index_t
out_image_size
,
float
*
out_ptr0_base
,
const
index_t
io
,
const
int
stride
)
{
for
(
index_t
ih
=
0
;
ih
<
out_height
;
++
ih
)
{
for
(
index_t
iw
=
0
;
iw
<
out_width
;
++
iw
)
{
for
(
int
i
=
0
;
i
<
5
;
++
i
)
{
for
(
int
j
=
0
;
j
<
5
;
++
j
)
{
out_ptr0_base
[
io
*
out_image_size
+
ih
*
out_width
+
iw
]
+=
in_ptr_base
[(
ih
*
stride
+
i
)
*
in_width
+
(
iw
*
stride
+
j
)]
*
filter_ptr0
[
io
*
in_channels
*
25
+
i
*
5
+
j
];
}
}
}
}
}
// Ho = 1, Wo = 4, Co = 4
void
Conv2dNeonK5x5S1
(
const
float
*
input
,
const
float
*
filter
,
...
...
@@ -183,11 +159,11 @@ void Conv2dNeonK5x5S1(const float *input,
}
// w
}
// h
#else
for
(
index_t
io
=
0
;
io
<
4
;
++
io
)
{
Conv2dCPUK
5x5Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
io
,
1
);
}
// for
for
(
index_t
oc
=
0
;
oc
<
4
;
++
oc
)
{
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
+
oc
*
in_channels
*
25
,
in_width
,
5
,
5
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
1
);
}
#endif
}
// c
}
else
{
...
...
@@ -229,9 +205,9 @@ void Conv2dNeonK5x5S1(const float *input,
}
// w
}
// h
#else
Conv2dCPUK
5x5Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
0
,
1
);
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
5
,
5
,
out_height
,
out_width
,
out_ptr0_base
,
1
);
#endif
}
// c
}
// mm
...
...
mace/kernels/arm/conv_2d_neon_7x1.cc
0 → 100644
浏览文件 @
b3efb72b
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/kernels/arm/conv_2d_neon.h"
namespace
mace
{
namespace
kernels
{
// Ho = 4, Wo = 1, Co = 4
void
Conv2dNeonK7x1S1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
*
in_shape
,
const
index_t
*
out_shape
,
float
*
output
)
{
const
index_t
in_image_size
=
in_shape
[
2
]
*
in_shape
[
3
];
const
index_t
out_image_size
=
out_shape
[
2
]
*
out_shape
[
3
];
const
index_t
in_batch_size
=
in_shape
[
1
]
*
in_image_size
;
const
index_t
out_batch_size
=
out_shape
[
1
]
*
out_image_size
;
#pragma omp parallel for collapse(2)
for
(
index_t
b
=
0
;
b
<
out_shape
[
0
];
++
b
)
{
for
(
index_t
m
=
0
;
m
<
out_shape
[
1
];
m
+=
4
)
{
const
index_t
out_channels
=
out_shape
[
1
];
const
index_t
out_height
=
out_shape
[
2
];
const
index_t
out_width
=
out_shape
[
3
];
const
index_t
in_channels
=
in_shape
[
1
];
const
index_t
in_width
=
in_shape
[
3
];
if
(
m
+
3
<
out_channels
)
{
float
*
out_ptr0_base
=
output
+
b
*
out_batch_size
+
m
*
out_image_size
;
#if defined(MACE_ENABLE_NEON)
float
*
out_ptr1_base
=
output
+
b
*
out_batch_size
+
(
m
+
1
)
*
out_image_size
;
float
*
out_ptr2_base
=
output
+
b
*
out_batch_size
+
(
m
+
2
)
*
out_image_size
;
float
*
out_ptr3_base
=
output
+
b
*
out_batch_size
+
(
m
+
3
)
*
out_image_size
;
#endif
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
const
float
*
in_ptr_base
=
input
+
b
*
in_batch_size
+
c
*
in_image_size
;
const
float
*
filter_ptr0
=
filter
+
m
*
in_channels
*
7
+
c
*
7
;
#if defined(MACE_ENABLE_NEON)
const
float
*
filter_ptr1
=
filter
+
(
m
+
1
)
*
in_channels
*
7
+
c
*
7
;
const
float
*
filter_ptr2
=
filter
+
(
m
+
2
)
*
in_channels
*
7
+
c
*
7
;
const
float
*
filter_ptr3
=
filter
+
(
m
+
3
)
*
in_channels
*
7
+
c
*
7
;
/* load filter (4 outch x 4 height x 1 width) */
float32x4_t
vf00
,
vf01
;
float32x4_t
vf10
,
vf11
;
float32x4_t
vf20
,
vf21
;
float32x4_t
vf30
,
vf31
;
vf00
=
vld1q_f32
(
filter_ptr0
);
vf01
=
vld1q_f32
(
filter_ptr0
+
3
);
vf10
=
vld1q_f32
(
filter_ptr1
);
vf11
=
vld1q_f32
(
filter_ptr1
+
3
);
vf20
=
vld1q_f32
(
filter_ptr2
);
vf21
=
vld1q_f32
(
filter_ptr2
+
3
);
vf30
=
vld1q_f32
(
filter_ptr3
);
vf31
=
vld1q_f32
(
filter_ptr3
+
3
);
for
(
index_t
h
=
0
;
h
+
3
<
out_height
;
h
+=
4
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
++
w
)
{
// load output
index_t
out_offset
=
h
*
out_width
+
w
;
// output (4 outch x 4 height x 1 width): vo_outch_height
float32x4_t
vo0
=
{
out_ptr0_base
[
out_offset
],
out_ptr0_base
[
out_offset
+
out_width
],
out_ptr0_base
[
out_offset
+
2
*
out_width
],
out_ptr0_base
[
out_offset
+
3
*
out_width
]};
float32x4_t
vo1
=
{
out_ptr1_base
[
out_offset
],
out_ptr1_base
[
out_offset
+
out_width
],
out_ptr1_base
[
out_offset
+
2
*
out_width
],
out_ptr1_base
[
out_offset
+
3
*
out_width
]};
float32x4_t
vo2
=
{
out_ptr2_base
[
out_offset
],
out_ptr2_base
[
out_offset
+
out_width
],
out_ptr2_base
[
out_offset
+
2
*
out_width
],
out_ptr2_base
[
out_offset
+
3
*
out_width
]};
float32x4_t
vo3
=
{
out_ptr3_base
[
out_offset
],
out_ptr3_base
[
out_offset
+
out_width
],
out_ptr3_base
[
out_offset
+
2
*
out_width
],
out_ptr3_base
[
out_offset
+
3
*
out_width
]};
// input offset
index_t
in_offset
=
h
*
in_width
+
w
;
// input (3 slide)
float32x4_t
vi0
=
{
in_ptr_base
[
in_offset
],
in_ptr_base
[
in_offset
+
in_width
],
in_ptr_base
[
in_offset
+
2
*
in_width
],
in_ptr_base
[
in_offset
+
3
*
in_width
]};
float32x4_t
vi4
=
{
in_ptr_base
[
in_offset
+
4
*
in_width
],
in_ptr_base
[
in_offset
+
5
*
in_width
],
in_ptr_base
[
in_offset
+
6
*
in_width
],
in_ptr_base
[
in_offset
+
7
*
in_width
]};
float32x4_t
vi8
=
{
in_ptr_base
[
in_offset
+
8
*
in_width
],
in_ptr_base
[
in_offset
+
9
*
in_width
]};
float32x4_t
vi1
=
vextq_f32
(
vi0
,
vi4
,
1
);
float32x4_t
vi2
=
vextq_f32
(
vi0
,
vi4
,
2
);
float32x4_t
vi3
=
vextq_f32
(
vi0
,
vi4
,
3
);
float32x4_t
vi5
=
vextq_f32
(
vi4
,
vi8
,
1
);
float32x4_t
vi6
=
vextq_f32
(
vi4
,
vi8
,
2
);
#if defined(__aarch64__)
/* outch 0 */
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi0
,
vf00
,
0
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi1
,
vf00
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi2
,
vf00
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi3
,
vf00
,
3
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi4
,
vf01
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi5
,
vf01
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi6
,
vf01
,
3
);
/* outch 1 */
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi0
,
vf10
,
0
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi1
,
vf10
,
1
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi2
,
vf10
,
2
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi3
,
vf10
,
3
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi4
,
vf11
,
1
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi5
,
vf11
,
2
);
vo1
=
vfmaq_laneq_f32
(
vo1
,
vi6
,
vf11
,
3
);
/* outch 2 */
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi0
,
vf20
,
0
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi1
,
vf20
,
1
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi2
,
vf20
,
2
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi3
,
vf20
,
3
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi4
,
vf21
,
1
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi5
,
vf21
,
2
);
vo2
=
vfmaq_laneq_f32
(
vo2
,
vi6
,
vf21
,
3
);
/* outch 3 */
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi0
,
vf30
,
0
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi1
,
vf30
,
1
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi2
,
vf30
,
2
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi3
,
vf30
,
3
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi4
,
vf31
,
1
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi5
,
vf31
,
2
);
vo3
=
vfmaq_laneq_f32
(
vo3
,
vi6
,
vf31
,
3
);
#else
/* outch 0 */
vo0
=
vmlaq_lane_f32
(
vo0
,
vi0
,
vget_low_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi1
,
vget_low_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi2
,
vget_high_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi3
,
vget_high_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi4
,
vget_low_f32
(
vf01
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi5
,
vget_high_f32
(
vf01
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi6
,
vget_high_f32
(
vf01
),
1
);
/* outch 1 */
vo1
=
vmlaq_lane_f32
(
vo1
,
vi0
,
vget_low_f32
(
vf10
),
0
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi1
,
vget_low_f32
(
vf10
),
1
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi2
,
vget_high_f32
(
vf10
),
0
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi3
,
vget_high_f32
(
vf10
),
1
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi4
,
vget_low_f32
(
vf11
),
1
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi5
,
vget_high_f32
(
vf11
),
0
);
vo1
=
vmlaq_lane_f32
(
vo1
,
vi6
,
vget_high_f32
(
vf11
),
1
);
/* outch 2 */
vo2
=
vmlaq_lane_f32
(
vo2
,
vi0
,
vget_low_f32
(
vf20
),
0
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi1
,
vget_low_f32
(
vf20
),
1
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi2
,
vget_high_f32
(
vf20
),
0
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi3
,
vget_high_f32
(
vf20
),
1
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi4
,
vget_low_f32
(
vf21
),
1
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi5
,
vget_high_f32
(
vf21
),
0
);
vo2
=
vmlaq_lane_f32
(
vo2
,
vi6
,
vget_high_f32
(
vf21
),
1
);
/* outch 3 */
vo3
=
vmlaq_lane_f32
(
vo3
,
vi0
,
vget_low_f32
(
vf30
),
0
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi1
,
vget_low_f32
(
vf30
),
1
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi2
,
vget_high_f32
(
vf30
),
0
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi3
,
vget_high_f32
(
vf30
),
1
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi4
,
vget_low_f32
(
vf31
),
1
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi5
,
vget_high_f32
(
vf31
),
0
);
vo3
=
vmlaq_lane_f32
(
vo3
,
vi6
,
vget_high_f32
(
vf31
),
1
);
#endif
out_ptr0_base
[
out_offset
]
=
vo0
[
0
];
out_ptr0_base
[
out_offset
+
out_width
]
=
vo0
[
1
];
out_ptr0_base
[
out_offset
+
2
*
out_width
]
=
vo0
[
2
];
out_ptr0_base
[
out_offset
+
3
*
out_width
]
=
vo0
[
3
];
out_ptr1_base
[
out_offset
]
=
vo1
[
0
];
out_ptr1_base
[
out_offset
+
out_width
]
=
vo1
[
1
];
out_ptr1_base
[
out_offset
+
2
*
out_width
]
=
vo1
[
2
];
out_ptr1_base
[
out_offset
+
3
*
out_width
]
=
vo1
[
3
];
out_ptr2_base
[
out_offset
]
=
vo2
[
0
];
out_ptr2_base
[
out_offset
+
out_width
]
=
vo2
[
1
];
out_ptr2_base
[
out_offset
+
2
*
out_width
]
=
vo2
[
2
];
out_ptr2_base
[
out_offset
+
3
*
out_width
]
=
vo2
[
3
];
out_ptr3_base
[
out_offset
]
=
vo3
[
0
];
out_ptr3_base
[
out_offset
+
out_width
]
=
vo3
[
1
];
out_ptr3_base
[
out_offset
+
2
*
out_width
]
=
vo3
[
2
];
out_ptr3_base
[
out_offset
+
3
*
out_width
]
=
vo3
[
3
];
}
// w
}
// h
#else
for
(
index_t
oc
=
0
;
oc
<
4
;
++
oc
)
{
Conv2dCPUKHxKWCalc
(
in_ptr_base
,
filter_ptr0
+
oc
*
in_channels
*
7
,
in_width
,
7
,
1
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
1
);
}
#endif
}
// c
}
else
{
for
(
index_t
mm
=
m
;
mm
<
out_channels
;
++
mm
)
{
float
*
out_ptr0_base
=
output
+
b
*
out_batch_size
+
mm
*
out_image_size
;
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
const
float
*
in_ptr_base
=
input
+
b
*
in_batch_size
+
c
*
in_image_size
;
const
float
*
filter_ptr0
=
filter
+
mm
*
in_channels
*
7
+
c
*
7
;
#if defined(MACE_ENABLE_NEON)
/* load filter (1 outch x 4 height x 1 width) */
float32x4_t
vf00
,
vf01
;
vf00
=
vld1q_f32
(
filter_ptr0
);
vf01
=
vld1q_f32
(
filter_ptr0
+
3
);
for
(
index_t
h
=
0
;
h
+
3
<
out_height
;
h
+=
4
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
++
w
)
{
// load output
index_t
out_offset
=
h
*
out_width
+
w
;
// output (1 outch x 4 height x 1 width): vo_outch_height
float32x4_t
vo0
=
{
out_ptr0_base
[
out_offset
],
out_ptr0_base
[
out_offset
+
out_width
],
out_ptr0_base
[
out_offset
+
2
*
out_width
],
out_ptr0_base
[
out_offset
+
3
*
out_width
]};
// input offset
index_t
in_offset
=
h
*
in_width
+
w
;
// input (3 slide)
float32x4_t
vi0
=
{
in_ptr_base
[
in_offset
],
in_ptr_base
[
in_offset
+
in_width
],
in_ptr_base
[
in_offset
+
2
*
in_width
],
in_ptr_base
[
in_offset
+
3
*
in_width
]};
float32x4_t
vi4
=
{
in_ptr_base
[
in_offset
+
4
*
in_width
],
in_ptr_base
[
in_offset
+
5
*
in_width
],
in_ptr_base
[
in_offset
+
6
*
in_width
],
in_ptr_base
[
in_offset
+
7
*
in_width
]};
float32x4_t
vi8
=
{
in_ptr_base
[
in_offset
+
8
*
in_width
],
in_ptr_base
[
in_offset
+
9
*
in_width
],
in_ptr_base
[
in_offset
+
10
*
in_width
],
in_ptr_base
[
in_offset
+
11
*
in_width
]};
float32x4_t
vi1
=
vextq_f32
(
vi0
,
vi4
,
1
);
float32x4_t
vi2
=
vextq_f32
(
vi0
,
vi4
,
2
);
float32x4_t
vi3
=
vextq_f32
(
vi0
,
vi4
,
3
);
float32x4_t
vi5
=
vextq_f32
(
vi4
,
vi8
,
1
);
float32x4_t
vi6
=
vextq_f32
(
vi4
,
vi8
,
2
);
#if defined(__aarch64__)
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi0
,
vf00
,
0
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi1
,
vf00
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi2
,
vf00
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi3
,
vf00
,
3
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi4
,
vf01
,
1
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi5
,
vf01
,
2
);
vo0
=
vfmaq_laneq_f32
(
vo0
,
vi6
,
vf01
,
3
);
#else
vo0
=
vmlaq_lane_f32
(
vo0
,
vi0
,
vget_low_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi1
,
vget_low_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi2
,
vget_high_f32
(
vf00
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi3
,
vget_high_f32
(
vf00
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi4
,
vget_low_f32
(
vf01
),
1
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi5
,
vget_high_f32
(
vf01
),
0
);
vo0
=
vmlaq_lane_f32
(
vo0
,
vi6
,
vget_high_f32
(
vf01
),
1
);
#endif
out_ptr0_base
[
out_offset
]
=
vo0
[
0
];
out_ptr0_base
[
out_offset
+
out_width
]
=
vo0
[
1
];
out_ptr0_base
[
out_offset
+
2
*
out_width
]
=
vo0
[
2
];
out_ptr0_base
[
out_offset
+
3
*
out_width
]
=
vo0
[
3
];
}
// w
}
// h
#else
Conv2dCPUKHxKWCalc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
7
,
1
,
out_height
,
out_width
,
out_ptr0_base
,
1
);
#endif
}
// c
}
}
// if
}
// m
}
// b
}
}
// namespace kernels
}
// namespace mace
mace/kernels/arm/conv_2d_neon_7x7.cc
浏览文件 @
b3efb72b
...
...
@@ -153,30 +153,6 @@ namespace kernels {
vo0 = vmlaq_lane_f32(vo0, vi5, vget_high_f32(vf01), 0); \
vo0 = vmlaq_lane_f32(vo0, vi6, vget_high_f32(vf01), 1);
inline
void
Conv2dCPUK7x7Calc
(
const
float
*
in_ptr_base
,
const
float
*
filter_ptr0
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_height
,
const
index_t
out_width
,
const
index_t
out_image_size
,
float
*
out_ptr0_base
,
const
index_t
io
,
const
int
stride
)
{
for
(
index_t
ih
=
0
;
ih
<
out_height
;
++
ih
)
{
for
(
index_t
iw
=
0
;
iw
<
out_width
;
++
iw
)
{
for
(
int
i
=
0
;
i
<
7
;
++
i
)
{
for
(
int
j
=
0
;
j
<
7
;
++
j
)
{
out_ptr0_base
[
io
*
out_image_size
+
ih
*
out_width
+
iw
]
+=
in_ptr_base
[(
ih
*
stride
+
i
)
*
in_width
+
(
iw
*
stride
+
j
)]
*
filter_ptr0
[
io
*
in_channels
*
49
+
i
*
7
+
j
];
}
}
}
}
}
// Ho = 1, Wo = 4, Co = 4
void
Conv2dNeonK7x7S1
(
const
float
*
input
,
const
float
*
filter
,
...
...
@@ -268,11 +244,11 @@ void Conv2dNeonK7x7S1(const float *input,
}
// w
}
// h
#else
for
(
index_t
io
=
0
;
io
<
4
;
++
io
)
{
Conv2dCPUK
7x7Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
io
,
1
);
}
// for
for
(
index_t
oc
=
0
;
oc
<
4
;
++
oc
)
{
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
+
oc
*
in_channels
*
49
,
in_width
,
7
,
7
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
1
);
}
#endif
}
// c
}
else
{
...
...
@@ -322,9 +298,9 @@ void Conv2dNeonK7x7S1(const float *input,
}
// w
}
// h
#else
Conv2dCPUK
7x7Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
0
,
1
);
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
7
,
7
,
out_height
,
out_width
,
out_ptr0_base
,
1
);
#endif
}
// c
}
// mm
...
...
@@ -429,11 +405,11 @@ void Conv2dNeonK7x7S2(const float *input,
}
// w
}
// h
#else
for
(
index_t
io
=
0
;
io
<
4
;
++
io
)
{
Conv2dCPUK
7x7Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
io
,
2
);
}
// for
for
(
index_t
oc
=
0
;
oc
<
4
;
++
oc
)
{
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
+
oc
*
in_channels
*
49
,
in_width
,
7
,
7
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
2
);
}
#endif
}
// c
}
else
{
...
...
@@ -488,9 +464,9 @@ void Conv2dNeonK7x7S2(const float *input,
}
// w
}
// h
#else
Conv2dCPUK
7x7Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
0
,
2
);
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
7
,
7
,
out_height
,
out_width
,
out_ptr0_base
,
2
);
#endif
}
// c
}
// mm
...
...
@@ -595,11 +571,11 @@ void Conv2dNeonK7x7S3(const float *input,
}
// w
}
// h
#else
for
(
index_t
io
=
0
;
io
<
4
;
++
io
)
{
Conv2dCPUK
7x7Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
io
,
3
);
}
// for
for
(
index_t
oc
=
0
;
oc
<
4
;
++
oc
)
{
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
+
oc
*
in_channels
*
49
,
in_width
,
7
,
7
,
out_height
,
out_width
,
out_ptr0_base
+
oc
*
out_image_size
,
3
);
}
#endif
}
// c
}
else
{
...
...
@@ -654,9 +630,9 @@ void Conv2dNeonK7x7S3(const float *input,
}
// w
}
// h
#else
Conv2dCPUK
7x7Calc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
in_channels
,
out_height
,
out_width
,
out_image_size
,
out_ptr0_base
,
0
,
3
);
Conv2dCPUK
HxKWCalc
(
in_ptr_base
,
filter_ptr0
,
in_width
,
7
,
7
,
out_height
,
out_width
,
out_ptr0_base
,
3
);
#endif
}
// c
}
// mm
...
...
mace/kernels/conv_2d.h
浏览文件 @
b3efb72b
...
...
@@ -357,6 +357,10 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
;
bool
use_neon_5x5_s1
=
filter_h
==
5
&&
filter_w
==
5
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
;
bool
use_neon_1x7_s1
=
filter_h
==
1
&&
filter_w
==
7
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
;
bool
use_neon_7x1_s1
=
filter_h
==
7
&&
filter_w
==
1
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
;
bool
use_neon_7x7_s1
=
filter_h
==
7
&&
filter_w
==
7
&&
stride_h
==
1
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
;
bool
use_neon_7x7_s2
=
filter_h
==
7
&&
filter_w
==
7
...
...
@@ -414,7 +418,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
}
else
if
(
use_neon_3x3_s1
)
{
tile_h
=
2
;
tile_w
=
4
;
}
else
if
(
use_neon_15x1_s1
)
{
}
else
if
(
use_neon_
7x1_s1
||
use_neon_
15x1_s1
)
{
tile_h
=
4
;
tile_w
=
1
;
}
else
{
...
...
@@ -566,6 +570,22 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
extra_output_shape
,
pad_output
);
};
}
else
if
(
use_neon_1x7_s1
)
{
conv_func
=
[
=
](
const
float
*
pad_input
,
float
*
pad_output
)
{
Conv2dNeonK1x7S1
(
pad_input
,
filter_data
,
extra_input_shape
,
extra_output_shape
,
pad_output
);
};
}
else
if
(
use_neon_7x1_s1
)
{
conv_func
=
[
=
](
const
float
*
pad_input
,
float
*
pad_output
)
{
Conv2dNeonK7x1S1
(
pad_input
,
filter_data
,
extra_input_shape
,
extra_output_shape
,
pad_output
);
};
}
else
if
(
use_neon_7x7_s1
)
{
conv_func
=
[
=
](
const
float
*
pad_input
,
float
*
pad_output
)
{
Conv2dNeonK7x7S1
(
pad_input
,
...
...
mace/kernels/gemm.cc
浏览文件 @
b3efb72b
...
...
@@ -388,8 +388,7 @@ inline void GemmTile(const float *A,
"v22"
,
"v23"
,
"v24"
,
"v25"
,
"v26"
"v25"
);
w
=
(
width
>>
2
)
<<
2
;
...
...
mace/ops/conv_2d_benchmark.cc
浏览文件 @
b3efb72b
...
...
@@ -168,6 +168,10 @@ BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, 1, SAME, 1024);
BM_CONV_2D
(
64
,
32
,
34
,
34
,
3
,
3
,
1
,
1
,
VALID
,
32
);
BM_CONV_2D
(
1
,
32
,
34
,
34
,
3
,
3
,
1
,
1
,
VALID
,
32
);
BM_CONV_2D
(
1
,
192
,
17
,
17
,
1
,
7
,
1
,
1
,
SAME
,
192
);
BM_CONV_2D
(
1
,
192
,
17
,
17
,
7
,
1
,
1
,
1
,
SAME
,
192
);
BM_CONV_2D
(
1
,
160
,
17
,
17
,
7
,
1
,
1
,
1
,
SAME
,
192
);
BM_CONV_2D
(
1
,
32
,
256
,
256
,
1
,
15
,
1
,
1
,
SAME
,
2
);
BM_CONV_2D
(
1
,
32
,
256
,
256
,
15
,
1
,
1
,
1
,
SAME
,
2
);
BM_CONV_2D
(
1
,
64
,
64
,
64
,
15
,
1
,
1
,
1
,
SAME
,
2
);
...
...
mace/ops/conv_2d_test.cc
浏览文件 @
b3efb72b
...
...
@@ -776,6 +776,36 @@ TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv3x3S12) {
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfAlignedConv5x5S12
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
32
,
32
},
{
5
,
5
,
3
,
64
},
{
1
,
1
});
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
32
,
32
},
{
5
,
5
,
3
,
63
},
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfAlignedConv1x7S1
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
17
,
17
},
{
1
,
7
,
192
,
192
},
{
1
,
1
});
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
17
,
17
},
{
1
,
7
,
192
,
191
},
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfAlignedConv7x1S1
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
17
,
17
},
{
7
,
1
,
192
,
192
},
{
1
,
1
});
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
17
,
17
},
{
7
,
1
,
160
,
192
},
{
1
,
1
});
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
17
,
17
},
{
7
,
1
,
160
,
191
},
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfAlignedConv7x7S12
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
32
,
32
},
{
7
,
7
,
3
,
64
},
{
1
,
1
});
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
32
,
32
},
{
7
,
7
,
3
,
63
},
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfAlignedConv15x1S12
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
32
,
32
},
{
15
,
1
,
256
,
2
},
{
1
,
1
});
...
...
@@ -792,11 +822,6 @@ TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x15S12) {
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfAlignedConv7x75S12
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
32
,
32
},
{
7
,
7
,
3
,
64
},
{
1
,
1
});
}
TEST_F
(
Conv2dOpTest
,
OPENCLHalfUnalignedConv1x1S12
)
{
TestHalfComplexConvNxNS12
<
DeviceType
::
GPU
>
({
107
,
113
},
{
1
,
1
,
5
,
7
},
{
1
,
1
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录