提交 26769ad7 编写于 作者: H hjchen2

Fix pooling3x3 bug

上级 afa836d9
...@@ -60,12 +60,18 @@ void Pooling<P>::operator()(const framework::Tensor &input, ...@@ -60,12 +60,18 @@ void Pooling<P>::operator()(const framework::Tensor &input,
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
PoolingVal<P> val; PoolingVal<P> val;
// std::cout << "output[" << ph * output_width + pw << "]:"
// << std::endl;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
val += input_ptr[h * input_width + w]; val += input_ptr[h * input_width + w];
// std::cout << "input[" << h << "][" << w << "] = "
// << input_ptr[h * input_width + w] << std::endl;
} }
} }
output_data[ph * output_width + pw] = val.Value(); output_ptr[ph * output_width + pw] = val.Value();
// std::cout << "output[" << ph * output_width + pw << "] = "
// << val.Value() << std::endl;
} }
} }
} }
......
...@@ -35,7 +35,7 @@ struct PoolingVal { ...@@ -35,7 +35,7 @@ struct PoolingVal {
float val; float val;
int count; int count;
PoolingVal() { PoolingVal() {
val = std::numeric_limits<float>::min(); val = -std::numeric_limits<float>::max();
count = 0; count = 0;
} }
inline PoolingVal<P> &operator+=(const float &x) { inline PoolingVal<P> &operator+=(const float &x) {
......
...@@ -161,9 +161,9 @@ struct Pooling3x3<P, 1> { ...@@ -161,9 +161,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4); x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -172,9 +172,9 @@ struct Pooling3x3<P, 1> { ...@@ -172,9 +172,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr1); x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4); x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -185,9 +185,9 @@ struct Pooling3x3<P, 1> { ...@@ -185,9 +185,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr2); x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4); x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -204,9 +204,9 @@ struct Pooling3x3<P, 1> { ...@@ -204,9 +204,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr3); x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4); x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -223,9 +223,9 @@ struct Pooling3x3<P, 1> { ...@@ -223,9 +223,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr4); x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4); x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -242,9 +242,9 @@ struct Pooling3x3<P, 1> { ...@@ -242,9 +242,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr5); x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4); x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -270,12 +270,14 @@ struct Pooling3x3<P, 1> { ...@@ -270,12 +270,14 @@ struct Pooling3x3<P, 1> {
// remain w // remain w
if (remain >= 4) { if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1); x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
...@@ -283,6 +285,7 @@ struct Pooling3x3<P, 1> { ...@@ -283,6 +285,7 @@ struct Pooling3x3<P, 1> {
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2); x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
...@@ -293,6 +296,7 @@ struct Pooling3x3<P, 1> { ...@@ -293,6 +296,7 @@ struct Pooling3x3<P, 1> {
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr3); x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
...@@ -303,6 +307,7 @@ struct Pooling3x3<P, 1> { ...@@ -303,6 +307,7 @@ struct Pooling3x3<P, 1> {
vst1q_f32(output_ptr1, y1.val[0]); vst1q_f32(output_ptr1, y1.val[0]);
x0.val[0] = vld1q_f32(input_ptr4); x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
...@@ -313,6 +318,7 @@ struct Pooling3x3<P, 1> { ...@@ -313,6 +318,7 @@ struct Pooling3x3<P, 1> {
vst1q_f32(output_ptr2, y2.val[0]); vst1q_f32(output_ptr2, y2.val[0]);
x0.val[0] = vld1q_f32(input_ptr5); x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
...@@ -372,9 +378,9 @@ struct Pooling3x3<P, 1> { ...@@ -372,9 +378,9 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4); x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
...@@ -383,22 +389,26 @@ struct Pooling3x3<P, 1> { ...@@ -383,22 +389,26 @@ struct Pooling3x3<P, 1> {
x0.val[0] = vld1q_f32(input_ptr1); x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4); x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2); x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4); x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0]);
...@@ -414,21 +424,26 @@ struct Pooling3x3<P, 1> { ...@@ -414,21 +424,26 @@ struct Pooling3x3<P, 1> {
// remain w // remain w
if (remain >= 4) { if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0); x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1); x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2); x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1); x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2); x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0]); y0.val[0] = vPoolPostq_f32<P>(y0.val[0]);
vst1q_f32(output_ptr0, y0.val[0]); vst1q_f32(output_ptr0, y0.val[0]);
...@@ -540,6 +555,8 @@ struct Pooling3x3<P, 2> { ...@@ -540,6 +555,8 @@ struct Pooling3x3<P, 2> {
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
...@@ -616,6 +633,7 @@ struct Pooling3x3<P, 2> { ...@@ -616,6 +633,7 @@ struct Pooling3x3<P, 2> {
input_ptr3 += 12; input_ptr3 += 12;
input_ptr4 += 12; input_ptr4 += 12;
input_ptr5 += 12; input_ptr5 += 12;
input_ptr6 += 12;
output_ptr0 += 6; output_ptr0 += 6;
output_ptr1 += 6; output_ptr1 += 6;
output_ptr2 += 6; output_ptr2 += 6;
...@@ -632,6 +650,7 @@ struct Pooling3x3<P, 2> { ...@@ -632,6 +650,7 @@ struct Pooling3x3<P, 2> {
x1.val[0] = vdupq_n_f32(input_ptr1[8]); x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2); x0 = vld2q_f32(input_ptr2);
...@@ -681,6 +700,7 @@ struct Pooling3x3<P, 2> { ...@@ -681,6 +700,7 @@ struct Pooling3x3<P, 2> {
input_ptr3 += 8; input_ptr3 += 8;
input_ptr4 += 8; input_ptr4 += 8;
input_ptr5 += 8; input_ptr5 += 8;
input_ptr6 += 8;
output_ptr0 += 4; output_ptr0 += 4;
output_ptr1 += 4; output_ptr1 += 4;
output_ptr2 += 4; output_ptr2 += 4;
...@@ -738,6 +758,8 @@ struct Pooling3x3<P, 2> { ...@@ -738,6 +758,8 @@ struct Pooling3x3<P, 2> {
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1); x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]); x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]); y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
...@@ -773,6 +795,7 @@ struct Pooling3x3<P, 2> { ...@@ -773,6 +795,7 @@ struct Pooling3x3<P, 2> {
x1.val[0] = vdupq_n_f32(input_ptr1[8]); x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1); x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]); x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]); y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2); x0 = vld2q_f32(input_ptr2);
......
...@@ -21,20 +21,7 @@ namespace paddle_mobile { ...@@ -21,20 +21,7 @@ namespace paddle_mobile {
namespace math = operators::math; namespace math = operators::math;
static int PoolOutputSize(int input_size, int filter_size, int padding, template <int PoolType, int Kernel, int Pad, int Stride>
int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
return output_size;
}
template <typename T, int CeilMode, int PoolType, int Kernel, int Pad,
int Stride>
int TestPoolOp(int in_channels, int in_height, int in_width) { int TestPoolOp(int in_channels, int in_height, int in_width) {
int kernel_h = Kernel; int kernel_h = Kernel;
int kernel_w = Kernel; int kernel_w = Kernel;
...@@ -42,7 +29,6 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -42,7 +29,6 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
int pad_w = Pad; int pad_w = Pad;
int stride_h = Stride; int stride_h = Stride;
int stride_w = Stride; int stride_w = Stride;
bool ceil_mode = CeilMode != 0;
std::string pooling_type = (PoolType == 0 ? "max" : "avg"); std::string pooling_type = (PoolType == 0 ? "max" : "avg");
int batch_size = 1; int batch_size = 1;
...@@ -53,14 +39,6 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -53,14 +39,6 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
framework::DDim input_shape = framework::DDim input_shape =
framework::make_ddim({batch_size, input_c, input_h, input_w}); framework::make_ddim({batch_size, input_c, input_h, input_w});
std::vector<int64_t> output_shape_v({batch_size, input_c});
output_shape_v.push_back(
PoolOutputSize(input_h, kernel_h, pad_h, stride_h, ceil_mode));
output_shape_v.push_back(
PoolOutputSize(input_w, kernel_w, pad_w, stride_w, ceil_mode));
framework::DDim output_shape = framework::make_ddim(output_shape_v);
VariableNameMap inputs; VariableNameMap inputs;
VariableNameMap outputs; VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>(); auto scope = std::make_shared<framework::Scope>();
...@@ -69,7 +47,11 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -69,7 +47,11 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
auto input_var = scope.get()->Var("input"); auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>(); auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<T>(input, input_shape, -127, 127); SetupTensor<float>(input, input_shape, -127, 127);
// for (int i = 0; i < input->numel(); ++i) {
// DLOG << "input[" << i << "] = " << input->data<float>()[i];
// }
auto output_var = scope.get()->Var("output"); auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs; framework::AttributeMap attrs;
...@@ -86,8 +68,9 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -86,8 +68,9 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
op->Init(); op->Init();
op->Run(); op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp; framework::Tensor output_cmp;
output_cmp.mutable_data<T>(output_shape); output_cmp.mutable_data<float>(output->dims());
if (pooling_type == "avg") { if (pooling_type == "avg") {
math::Pooling<Avg>()(*input, std::vector<int>{kernel_h, kernel_w}, math::Pooling<Avg>()(*input, std::vector<int>{kernel_h, kernel_w},
...@@ -100,13 +83,19 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { ...@@ -100,13 +83,19 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
} }
// compare results // compare results
auto output = output_var->template Get<framework::LoDTensor>(); const float *output_data = output->data<float>();
const T *output_data = output->data<T>(); float *output_cmp_data = output_cmp.data<float>();
T *output_cmp_data = output_cmp.data<T>();
for (int i = 0; i < output->numel(); ++i) { for (int i = 0; i < output->numel(); ++i) {
PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i], float gap = output_data[i] - output_cmp_data[i];
"output[%d] = %d, output_cmp[%d] = %d", i, // PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],
output_data[i], i, output_cmp_data[i]); // "output[%d] = %d, output_cmp[%d] = %d", i,
// output_data[i], i, output_cmp_data[i]);
if (gap > 1e-5 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
exit(1);
}
} }
delete op; delete op;
return 0; return 0;
...@@ -127,34 +116,80 @@ int main(int argc, char *argv[]) { ...@@ -127,34 +116,80 @@ int main(int argc, char *argv[]) {
int in_channels = atoi(argv[1]); int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]); int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]); int in_width = atoi(argv[3]);
// kernel = 3, pad = 1, stride = 1 LOG(paddle_mobile::kLOG_INFO)
LOG(paddle_mobile::kLOG_INFO) << "float, pooling_type=max, kernel=3, pad=0, stride=1";
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=1, stride=1"; paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
paddle_mobile::TestPoolOp<float, 0, 0, 3, 1, 1>(in_channels, in_height, LOG(paddle_mobile::kLOG_INFO)
in_width); << "float, pooling_type=max, kernel=3, pad=1, stride=1";
// kernel = 3, pad = 0, stride = 2 paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=max, kernel=3, pad=0, stride=2"; << "float, pooling_type=max, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<float, 0, 0, 3, 0, 2>(in_channels, in_height, paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width);
in_width); LOG(paddle_mobile::kLOG_INFO)
// kernel = 5, pad = 0, stride = 1 << "float, pooling_type=max, kernel=3, pad=5, stride=1";
LOG(paddle_mobile::kLOG_INFO) paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width);
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height, LOG(paddle_mobile::kLOG_INFO)
in_width); << "float, pooling_type=avg, kernel=3, pad=0, stride=1";
// kernel = 5, pad = 0, stride = 2 paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1"; << "float, pooling_type=avg, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 2>(in_channels, in_height, paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width);
in_width); LOG(paddle_mobile::kLOG_INFO)
// kernel = 7, pad = 0, stride = 1 << "float, pooling_type=avg, kernel=3, pad=2, stride=1";
LOG(paddle_mobile::kLOG_INFO) paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width);
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1"; LOG(paddle_mobile::kLOG_INFO)
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height, << "float, pooling_type=avg, kernel=3, pad=5, stride=1";
in_width); paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width);
// kernel = 7, pad = 0, stride = 4
LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4"; << "float, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height, paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width);
in_width); LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 5, pad = 0, stride = 2
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 2>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 4
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=4";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
// in_width);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册