PaddlePaddle / Paddle-Lite
Commit 58a9e04b
Authored by zhangyang0701 on Sep 12, 2018; committed by GitHub on Sep 12, 2018.

Merge pull request #952 from yangfei963158659/develop

repair bug of pool3x3

Parents: cf65f22b, 48cfa1ea
Showing 1 changed file with 403 additions and 224 deletions.

src/operators/math/pool_3x3.cpp (+403, -224)

@@ -31,251 +31,428 @@ using std::min;
using std::vector;

void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
#if __ARM_NEON
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int input_channel = static_cast<int>(input->dims()[1]);
  const int input_height = static_cast<int>(input->dims()[2]);
  const int input_width = static_cast<int>(input->dims()[3]);
  const int output_height = static_cast<int>(output->dims()[2]);
  const int output_width = static_cast<int>(output->dims()[3]);
  const int hxw = input_height * input_width;
  const int l = input_height;

  // 1/9 for full 3x3 windows, 1/6 for edge windows, 1/4 for corner windows.
  const float coef = 1.0 / 9.0;
  const float coef1 = 1.0 / 6.0;
  const float coef2 = 1.0 / 4.0;
  float32x4_t v_coef = vdupq_n_f32(coef);
  float32x4_t v_coef1 = vdupq_n_f32(coef1);

  for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
    for (int c = 0; c < input_channel; c++) {
      const float *input_data = input->data<float>() + c * hxw;
      float *output_data = output->data<float>() + c * hxw;

      // interior rows: the 3x3 window lies fully inside the input
      for (int i = 1; i < output_height - 1; i++) {
        float *output_ptr;
        float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4,
            tmp5, out0;
        for (int m = 1; m < output_width - 4; m += 4) {
          output_ptr = output_data + i * output_width + m;
          in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
          in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
          in2 = vld1q_f32(input_data + i * input_width + m - 1);
          in3 = vld1q_f32(input_data + i * input_width + m + 3);
          in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
          in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);

          tmp0 = vextq_f32(in0, in1, 1);
          tmp1 = vextq_f32(in0, in1, 2);
          tmp2 = vextq_f32(in2, in3, 1);
          tmp3 = vextq_f32(in2, in3, 2);
          tmp4 = vextq_f32(in4, in5, 1);
          tmp5 = vextq_f32(in4, in5, 2);

          out0 = in0;
          out0 = vaddq_f32(out0, tmp0);
          out0 = vaddq_f32(out0, tmp1);
          out0 = vaddq_f32(out0, in2);
          out0 = vaddq_f32(out0, tmp2);
          out0 = vaddq_f32(out0, tmp3);
          out0 = vaddq_f32(out0, in4);
          out0 = vaddq_f32(out0, tmp4);
          out0 = vaddq_f32(out0, tmp5);

          vst1q_f32(output_ptr, vmulq_f32(out0, v_coef));
        }
        // scalar tail of the row
        int m;
        for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
        }
        for (int j = m; j < output_width - 1; j++) {
          output_data[i * output_width + j] =
              input_data[(i - 1) * input_width + j - 1] +
              input_data[(i - 1) * input_width + j] +
              input_data[(i - 1) * input_width + j + 1] +
              input_data[(i)*input_width + j - 1] +
              input_data[(i)*input_width + j] +
              input_data[(i)*input_width + j + 1] +
              input_data[(i + 1) * input_width + j - 1] +
              input_data[(i + 1) * input_width + j] +
              input_data[(i + 1) * input_width + j + 1];
          output_data[i * output_width + j] =
              output_data[i * output_width + j] * coef;
        }
      }

      // four corner points: only a 2x2 patch of the window is inside
      output_data[0] =
          input_data[0] + input_data[1] + input_data[l] + input_data[l + 1];
      output_data[l - 1] = input_data[l - 2] + input_data[l - 1] +
                           input_data[2 * l - 2] + input_data[2 * l - 1];
      output_data[(l - 1) * l] =
          input_data[(l - 2) * l] + input_data[(l - 2) * l + 1] +
          input_data[(l - 1) * l] + input_data[(l - 1) * l + 1];
      output_data[l * l - 1] =
          input_data[(l - 2) * (l + 1)] + input_data[(l - 2) * (l + 1) + 1] +
          input_data[l * l - 2] + input_data[l * l - 1];
      output_data[0] = output_data[0] * coef2;
      output_data[l - 1] = output_data[l - 1] * coef2;
      output_data[(l - 1) * l] = output_data[(l - 1) * l] * coef2;
      output_data[l * l - 1] = output_data[l * l - 1] * coef2;

      // left side & right side: 3x2 windows (6 taps)
      for (int i = 1; i < l - 1; ++i) {
        output_data[i * l] = input_data[i * l - l] + input_data[i * l - l + 1] +
                             input_data[i * l] + input_data[i * l + 1] +
                             input_data[i * l + l] + input_data[i * l + l + 1];
        output_data[i * l + l - 1] =
            input_data[i * l + l - 1 - l - 1] + input_data[i * l + l - 1 - l] +
            input_data[i * l + l - 1 - 1] + input_data[i * l + l - 1] +
            input_data[i * l + l - 1 + l - 1] + input_data[i * l + l - 1 + l];
        output_data[i * l] = output_data[i * l] * coef1;
        output_data[i * l + l - 1] = output_data[i * l + l - 1] * coef1;
      }

      // top row: 2x3 windows (6 taps)
      int m;
      for (m = 1; m < output_width - 4; m += 4) {
        float *output_ptr = output_data + m;
        float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
        in0 = vld1q_f32(input_data + m - 1);
        in1 = vld1q_f32(input_data + m + 3);
        in2 = vld1q_f32(input_data + input_width + m - 1);
        in3 = vld1q_f32(input_data + input_width + m + 3);
        tmp0 = vextq_f32(in0, in1, 1);
        tmp1 = vextq_f32(in0, in1, 2);
        tmp2 = vextq_f32(in2, in3, 1);
        tmp3 = vextq_f32(in2, in3, 2);
        out0 = in0;
        out0 = vaddq_f32(out0, tmp0);
        out0 = vaddq_f32(out0, tmp1);
        out0 = vaddq_f32(out0, in2);
        out0 = vaddq_f32(out0, tmp2);
        out0 = vaddq_f32(out0, tmp3);
        vst1q_f32(output_ptr, vmulq_f32(out0, v_coef1));
      }
      for (m = 1; (m + 3) < output_width - 1; m += 4) {
      }
      for (int j = m; j < output_width - 1; j++) {
        output_data[j] = input_data[j - 1] + input_data[j] + input_data[j + 1] +
                         input_data[input_width + j - 1] +
                         input_data[input_width + j] +
                         input_data[input_width + j + 1];
        output_data[j] = output_data[j] * coef1;
      }

      // bottom row: 2x3 windows (6 taps)
      for (m = 1; m < output_width - 4; m += 4) {
        float *output_ptr =
            output_data + (output_height - 1) * output_width + m;
        float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
        in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
        in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
        in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
        in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
        tmp0 = vextq_f32(in0, in1, 1);
        tmp1 = vextq_f32(in0, in1, 2);
        tmp2 = vextq_f32(in2, in3, 1);
        tmp3 = vextq_f32(in2, in3, 2);
        out0 = in0;
        out0 = vaddq_f32(out0, tmp0);
        out0 = vaddq_f32(out0, tmp1);
        out0 = vaddq_f32(out0, in2);
        out0 = vaddq_f32(out0, tmp2);
        out0 = vaddq_f32(out0, tmp3);
        vst1q_f32(output_ptr, vmulq_f32(out0, v_coef1));
      }
      for (m = 1; (m + 3) < output_width - 1; m += 4) {
      }
      for (int j = m; j < output_width - 1; j++) {
        output_data[(output_height - 1) * input_width + j] =
            input_data[(output_height - 2) * input_width + j - 1] +
            input_data[(output_height - 2) * input_width + j] +
            input_data[(output_height - 2) * input_width + j + 1] +
            input_data[(output_height - 1) * input_width + j - 1] +
            input_data[(output_height - 1) * input_width + j] +
            input_data[(output_height - 1) * input_width + j + 1];
        output_data[(output_height - 1) * output_width + j] =
            output_data[(output_height - 1) * output_width + j] * coef1;
      }
    }
  }
  // const int batch_size = input->dims()[0];
  //
  // const int h_in = input->dims()[2];
  //
  // const int w_in = input->dims()[3];
  //
  // const int output_channels = output->dims()[1];
  //
  // const int h_out = output->dims()[2];
  // const int w_out = output->dims()[3];
  // const int outputdata_channel_stride = h_out * w_out;
  // const int inputdata_channel_stride = h_in * w_in;
  // const int input_batch_stride = output_channels * inputdata_channel_stride;
  // const int output_batch_stride = output_channels *
  // outputdata_channel_stride; float *out_data = output->data<float>(); const
  // float *input_data = input->data<float>();
  //
  // const float coef = 1.0 / 9.0;
  // for (int k = 0; k < batch_size; ++k) {
  //#pragma omp parallel for
  // for (int c = 0; c < output_channels; ++c) {
  // const float *input_seg = input_data + c * inputdata_channel_stride;
  // float *output_seg = out_data + c * outputdata_channel_stride;
  // // four corner point
  // output_seg[0] = (input_seg[0] + input_seg[1] + input_seg[w_in] +
  // input_seg[w_in + 1]) *
  // coef;
  // output_seg[w_out - 1] =
  // (input_seg[w_in - 2] + input_seg[w_in - 1] + input_seg[w_in * 2 -
  // 2] +
  // input_seg[2 * w_in - 1]) *
  // coef;
  // output_seg[(h_out - 1) * w_out] =
  // (input_seg[(h_in - 2) * w_in] + input_seg[(h_in - 2) * w_in + 1] +
  // input_seg[(h_in - 1) * w_in] + input_seg[(h_in - 1) * w_in + 1])
  // *
  // coef;
// output_seg[h_out * w_out - 1] =
// (input_seg[h_in * w_in - 1] + input_seg[h_in * w_in - 2] +
// input_seg[(h_in - 1) * w_in - 1] +
// input_seg[(h_in - 1) * w_in - 2]) *
// coef;
// // left side & right side
// for (int i = 1; i < h_in - 1; ++i) {
// output_seg[i * w_out] =
// (input_seg[i * w_in - w_in] + input_seg[i * w_in - w_in + 1] +
// input_seg[i * w_in] + input_seg[i * w_in + 1] +
// input_seg[i * w_in + w_in] + input_seg[i * w_in + w_in + 1]) *
// coef;
// output_seg[i * w_out + w_out - 1] =
// (input_seg[i * w_in - w_in + w_in - 2] +
// input_seg[i * w_in - w_in + 1 + w_in - 2] +
// input_seg[i * w_in + w_in - 2] +
// input_seg[i * w_in + 1 + w_in - 2] +
// input_seg[i * w_in + w_in + w_in - 2] +
// input_seg[i * w_in + w_in + 1 + w_in - 2]) *
// coef;
// }
// // top 1 row & bottom 1 row
// const float *input_tmp = input_seg;
//
// float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
// tmp3, tmp4, tmp5, sum, out0;
// float32x4_t v_coef = vdupq_n_f32(coef);
// in0 = vld1q_f32(input_tmp);
// in2 = vld1q_f32(input_tmp + w_in);
// const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
// in4 = vld1q_f32(input_tmp_end);
// in6 = vld1q_f32(input_tmp_end + w_in);
// int c_mid = w_out - 2;
// auto output_ptr = output_seg + 1;
// for (; c_mid > 3; c_mid -= 4) {
// in1 = vld1q_f32(input_tmp + 4);
// in3 = vld1q_f32(input_tmp + w_in + 4);
//
// tmp0 = vextq_f32(in0, in1, 1);
// tmp1 = vextq_f32(in0, in1, 2);
//
// tmp2 = vextq_f32(in2, in3, 1);
// tmp3 = vextq_f32(in2, in3, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
//
// vst1q_f32(output_ptr, vmulq_f32(sum, v_coef));
//
// in5 = vld1q_f32(input_tmp_end + 4);
// in7 = vld1q_f32(input_tmp_end + w_in + 4);
//
// tmp0 = vextq_f32(in4, in5, 1);
// tmp1 = vextq_f32(in4, in5, 2);
// tmp2 = vextq_f32(in6, in7, 1);
// tmp3 = vextq_f32(in6, in7, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
//
// vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef));
//
// // can optimize to each 8 stride.
// input_tmp += 4;
// input_tmp_end += 4;
// output_ptr += 4;
// in0 = in1;
// in2 = in3;
// in4 = in5;
// in6 = in7;
// }
// // top right remain
// float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]);
// float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]);
//
// tmp0 = vextq_f32(in0, pad0, 1);
// tmp1 = vextq_f32(in0, pad0, 2);
// tmp2 = vextq_f32(in2, pad1, 2);
// tmp3 = vextq_f32(in2, pad1, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// out0 = vmulq_f32(sum, v_coef);
//
// for (int i = 0; i < c_mid; ++i) {
// if (i == 0) {
// vst1q_lane_f32(output_ptr + i, out0, 0);
// }
// if (i == 1) {
// vst1q_lane_f32(output_ptr + i, out0, 1);
// }
// if (i == 2) {
// vst1q_lane_f32(output_ptr + i, out0, 2);
// }
// }
//
// // bottom_right remain
// float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]);
// float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]);
//
// tmp0 = vextq_f32(in4, pad2, 1);
// tmp1 = vextq_f32(in4, pad2, 2);
// tmp2 = vextq_f32(in6, pad3, 2);
// tmp3 = vextq_f32(in6, pad3, 2);
//
// sum = vaddq_f32(in4, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in6);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// out0 = vmulq_f32(sum, v_coef);
//
// for (int i = 0; i < c_mid; ++i) {
// if (i == 0) {
// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0);
// }
// if (i == 1) {
// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1);
// }
// if (i == 2) {
// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2);
// }
// }
// // mid
// for (int j = 0; j < h_out - 2; ++j) {
// output_ptr = output_seg + w_out * (j + 1) + 1;
// input_tmp = input_seg + j * w_in;
//
// in0 = vld1q_f32(input_tmp);
// in2 = vld1q_f32(input_tmp + w_in);
// in4 = vld1q_f32(input_tmp + 2 * w_in);
// c_mid = w_out - 2;
// for (; c_mid > 3; c_mid -= 4) {
// in1 = vld1q_f32(input_tmp + 4);
// in3 = vld1q_f32(input_tmp + w_in + 4);
// in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
//
// tmp0 = vextq_f32(in0, in1, 1);
// tmp1 = vextq_f32(in0, in1, 2);
// tmp2 = vextq_f32(in2, in3, 1);
// tmp3 = vextq_f32(in2, in3, 2);
// tmp4 = vextq_f32(in4, in5, 1);
// tmp5 = vextq_f32(in4, in5, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// sum = vaddq_f32(sum, in4);
// sum = vaddq_f32(sum, tmp4);
// sum = vaddq_f32(sum, tmp5);
//
// out0 = vmulq_f32(sum, v_coef);
// vst1q_f32(output_ptr, out0);
// output_ptr += 4;
// input_tmp += 4;
// in0 = in1;
// in2 = in3;
// in4 = in5;
// }
// // mid remain
// float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]);
// float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
// float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
//
// tmp0 = vextq_f32(in0, pad0, 1);
// tmp1 = vextq_f32(in0, pad0, 2);
// tmp2 = vextq_f32(in2, pad1, 1);
// tmp3 = vextq_f32(in2, pad1, 2);
// tmp4 = vextq_f32(in4, pad2, 1);
// tmp5 = vextq_f32(in4, pad2, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// sum = vaddq_f32(sum, in4);
// sum = vaddq_f32(sum, tmp4);
// sum = vaddq_f32(sum, tmp5);
// out0 = vmulq_f32(sum, v_coef);
//
// for (int i = 0; i < c_mid; ++i) {
// if (i == 0) {
// vst1q_lane_f32(output_ptr + i, out0, 0);
// }
// if (i == 1) {
// vst1q_lane_f32(output_ptr + i, out0, 1);
// }
// if (i == 2) {
// vst1q_lane_f32(output_ptr + i, out0, 2);
// }
// }
// }
// // input_data += inputdata_channel_stride;
// // out_data += outputdata_channel_stride;
// }
// input_data += input_batch_stride;
// out_data += output_batch_stride;
// }
#endif
}
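The three coefficients above encode how many input cells a stride-1, pad-1 window actually covers: 9 in the interior, 6 along an edge, 4 in a corner. A minimal scalar sketch of that computation, as a standalone illustration with hypothetical names rather than code from this commit:

#include <algorithm>
#include <vector>

// 3x3 average pooling, stride 1, padding 1, over one h x w channel.
// Each output averages only the taps that fall inside the input, so the
// divisor is 9 in the interior, 6 on edges and 4 in corners.
std::vector<float> AvgPool3x3S1P1(const std::vector<float> &in, int h, int w) {
  std::vector<float> out(h * w, 0.f);
  for (int i = 0; i < h; ++i) {
    for (int j = 0; j < w; ++j) {
      const int hs = std::max(i - 1, 0), he = std::min(i + 2, h);
      const int ws = std::max(j - 1, 0), we = std::min(j + 2, w);
      float sum = 0.f;
      for (int y = hs; y < he; ++y) {
        for (int x = ws; x < we; ++x) {
          sum += in[y * w + x];
        }
      }
      out[i * w + j] = sum / ((he - hs) * (we - ws));
    }
  }
  return out;
}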
@@ -662,6 +839,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,

          wstart = max(wstart, 0);
          hend = min(hend, input_height);
          wend = min(wend, input_width);
          const float *pos1 = input_seg + hstart * input_width + wstart;
          const float *pos2 = input_seg + (hstart + 1) * input_width + wstart;
          const float *pos3 = input_seg + (hstart + 2) * input_width + wstart;
@@ -674,7 +852,8 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,

              sum += input_seg[h * input_width + w];
            }
          }
-         output_seg[ph * output_width + pw] = sum / 9.0;
+         output_seg[ph * output_width + pw] =
+             sum / ((hend - hstart) * (wend - wstart) * 1.0);
        } else {
#if __aarch64__
#else
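The functional change in Pool3x3Avg is the divisor: a padded border window was previously averaged as sum / 9.0 even though only a 2x2 or 2x3 patch of the input contributed, so a corner output came out at 4/9 of its correct value. Dividing by the clamped window area (hend - hstart) * (wend - wstart) gives the exclusive-padding average. A minimal standalone sketch of the corrected step, with a hypothetical helper name rather than the library code:

#include <algorithm>

// Average of one 3x3 window whose top-left corner is (hstart, wstart) in
// padded coordinates; the window is clamped to the input and divided by the
// number of input cells it actually covers.
float WindowAverage(const float *input, int input_height, int input_width,
                    int hstart, int wstart) {
  const int hend = std::min(hstart + 3, input_height);
  const int wend = std::min(wstart + 3, input_width);
  hstart = std::max(hstart, 0);
  wstart = std::max(wstart, 0);
  float sum = 0.f;
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      sum += input[h * input_width + w];
    }
  }
  return sum / ((hend - hstart) * (wend - wstart));  // was sum / 9.0
}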