diff --git a/src/operators/math/pooling.cpp b/src/operators/math/pooling.cpp index 3bb6c27c5af2cf3537ffeb17647ec9cc23fb1355..b4aba52b9b0aaab613a8669f05b3bb7ece70c933 100644 --- a/src/operators/math/pooling.cpp +++ b/src/operators/math/pooling.cpp @@ -60,12 +60,18 @@ void Pooling
::operator()(const framework::Tensor &input, wstart = std::max(wstart, 0); PoolingVal
val;
+ // std::cout << "output[" << ph * output_width + pw << "]:"
+ // << std::endl;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
val += input_ptr[h * input_width + w];
+ // std::cout << "input[" << h << "][" << w << "] = "
+ // << input_ptr[h * input_width + w] << std::endl;
}
}
- output_data[ph * output_width + pw] = val.Value();
+ output_ptr[ph * output_width + pw] = val.Value();
+ // std::cout << "output[" << ph * output_width + pw << "] = "
+ // << val.Value() << std::endl;
}
}
}
diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h
index 9407270a474f0852a2a2d314026cf8b218d34584..9fcfdf811f7ea19f9b6466a8527eccaeab3cf5a8 100644
--- a/src/operators/math/pooling.h
+++ b/src/operators/math/pooling.h
@@ -35,7 +35,7 @@ struct PoolingVal {
float val;
int count;
PoolingVal() {
- val = std::numeric_limits &operator+=(const float &x) {
diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp
index 3918001d76800c1636d569a923468e5ea236873e..8f1defa6cbb222e49257b4400aea3ac4faf40cbd 100644
--- a/src/operators/math/pooling3x3.cpp
+++ b/src/operators/math/pooling3x3.cpp
@@ -161,9 +161,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -172,9 +172,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -185,9 +185,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
y2.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -204,9 +204,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -223,9 +223,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -242,9 +242,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -270,12 +270,14 @@ struct Pooling3x3 {
// remain w
if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0);
+ x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1);
+ x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
@@ -283,6 +285,7 @@ struct Pooling3x3 {
y0.val[0] = vPoolPreq_f32 (y1.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2);
+ x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
@@ -293,6 +296,7 @@ struct Pooling3x3 {
vst1q_f32(output_ptr0, y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr3);
+ x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
@@ -303,6 +307,7 @@ struct Pooling3x3 {
vst1q_f32(output_ptr1, y1.val[0]);
x0.val[0] = vld1q_f32(input_ptr4);
+ x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
@@ -313,6 +318,7 @@ struct Pooling3x3 {
vst1q_f32(output_ptr2, y2.val[0]);
x0.val[0] = vld1q_f32(input_ptr5);
+ x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
@@ -372,9 +378,9 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
@@ -383,22 +389,26 @@ struct Pooling3x3 {
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
+ x0.val[1] = vPoolPreq_f32 (x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32 (x0.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
- x1.val[1] = vextq_f32(x0.val[0], x0.val[1], 1);
+ x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
- x2.val[1] = vextq_f32(x0.val[0], x0.val[1], 2);
+ x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32 (x0.val[1], x1.val[1]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
+ x0.val[1] = vPoolPreq_f32 (x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32 (x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32 (y0.val[0]);
@@ -414,21 +424,26 @@ struct Pooling3x3 {
// remain w
if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0);
+ x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1);
+ x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2);
+ x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x1.val[0]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32 (y0.val[0]);
vst1q_f32(output_ptr0, y0.val[0]);
@@ -540,6 +555,8 @@ struct Pooling3x3 {
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32 (x1.val[0], x1.val[1]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
+ x0.val[1] = vPoolPreq_f32 (x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32 (x0.val[1], y0.val[1]);
@@ -616,6 +633,7 @@ struct Pooling3x3 {
input_ptr3 += 12;
input_ptr4 += 12;
input_ptr5 += 12;
+ input_ptr6 += 12;
output_ptr0 += 6;
output_ptr1 += 6;
output_ptr2 += 6;
@@ -632,6 +650,7 @@ struct Pooling3x3 {
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x0.val[1]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2);
@@ -681,6 +700,7 @@ struct Pooling3x3 {
input_ptr3 += 8;
input_ptr4 += 8;
input_ptr5 += 8;
+ input_ptr6 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
@@ -738,6 +758,8 @@ struct Pooling3x3 {
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32 (x1.val[0], x1.val[1]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
+ x0.val[1] = vPoolPreq_f32 (x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32 (x0.val[1], y0.val[1]);
@@ -773,6 +795,7 @@ struct Pooling3x3 {
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32 (x0.val[0], x0.val[1]);
+ x0.val[0] = vPoolPreq_f32 (x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32 (x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2);
diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp
index ae5ff9d3f719c259308950db79706dcf54bc62c2..b38123a7d623d614fb04077daf351cad7d1aa8f2 100644
--- a/test/operators/test_pool_op.cpp
+++ b/test/operators/test_pool_op.cpp
@@ -21,20 +21,7 @@ namespace paddle_mobile {
namespace math = operators::math;
-static int PoolOutputSize(int input_size, int filter_size, int padding,
- int stride, bool ceil_mode) {
- int output_size;
- if (!ceil_mode) {
- output_size = (input_size - filter_size + 2 * padding) / stride + 1;
- } else {
- output_size =
- (input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
- }
- return output_size;
-}
-
-template