Commit f7eb7352 authored by: hjchen2

Change 'val * (1.f / count)' to 'val / count' to fix the precision of the average pooling calculation

Parent 32917513
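
The precision point behind this commit: 'val * (1.f / count)' rounds twice (once computing the reciprocal, once multiplying), while 'val / count' rounds once. A standalone demo of the effect (not part of the diff; the values are illustrative):

#include <cstdio>
int main() {
  float val = 55.f;  // e.g. the running sum over a pooling window
  int count = 9;     // e.g. a 3x3 window
  float divided = val / count;         // single rounding
  float scaled = val * (1.f / count);  // reciprocal rounds, product rounds again
  std::printf("val / count       = %.9g\n", divided);
  std::printf("val * (1.f/count) = %.9g\n", scaled);
  std::printf("bitwise equal: %s\n", divided == scaled ? "yes" : "no");
  return 0;
}

The two expressions can differ by an ulp, which is the discrepancy the pooling change below removes.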
......@@ -25,6 +25,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......@@ -106,9 +107,9 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), false, biase_data);
math::matmul<float, float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), false, biase_data);
}
}
}
......
......@@ -25,15 +25,15 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename P, typename S>
template <typename Itype, typename Otype>
void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
int32_t axis = param.Axis();
S *bias_data = bias.data<S>();
Otype *bias_data = bias.data<Otype>();
Tensor *output = param.Output();
output->mutable_data<P>();
output->mutable_data<Otype>();
float alpha = 1.0f;
float beta = 1.0f;
......@@ -64,7 +64,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<P>(col_shape);
col.mutable_data<Itype>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......@@ -83,8 +83,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
int32_t in_step = static_cast<int32_t>(input->dims()[1]) / groups;
int32_t out_step = static_cast<int32_t>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, P> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, P> im2col;
math::Vol2ColFunctor<CPU, Itype> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, Itype> im2col;
for (int32_t i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
......@@ -112,8 +112,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul(filter_slice, false, col_matrix, false, alpha, &out_slice,
beta, true, bias_data);
math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
&out_slice, beta, true, bias_data);
}
}
}
......
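For context on templating the column buffer on Itype above: im2col copies raw input elements into the buffer, so the buffer must share the input element type (float or int8_t), while accumulation happens in the output type inside the GEMM. A minimal im2col sketch (hypothetical, single channel, stride 1, no padding; not the framework's implementation):

#include <vector>

template <typename Itype>
std::vector<Itype> im2col_sketch(const Itype *in, int h, int w, int k) {
  int out_h = h - k + 1, out_w = w - k + 1;
  // Column matrix of shape (k*k, out_h*out_w), row-major: one row per
  // kernel offset (kh, kw), one column per output position (oh, ow).
  std::vector<Itype> col(static_cast<size_t>(k) * k * out_h * out_w);
  size_t idx = 0;
  for (int kh = 0; kh < k; ++kh)
    for (int kw = 0; kw < k; ++kw)
      for (int oh = 0; oh < out_h; ++oh)
        for (int ow = 0; ow < out_w; ++ow)
          col[idx++] = in[(oh + kh) * w + (ow + kw)];
  return col;
}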
......@@ -106,9 +106,10 @@ inline void GemmConv(const ConvParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice, static_cast<float>(0),
false, static_cast<Otype *>(nullptr));
math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), false,
static_cast<Otype *>(nullptr));
}
}
}
......
......@@ -93,8 +93,8 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmul(filter_slice, true, in_slice, false, static_cast<P>(1.0),
&col_matrix, static_cast<P>(0.0));
math::matmul<P, P>(filter_slice, true, in_slice, false,
static_cast<P>(1.0), &col_matrix, static_cast<P>(0.0));
if (data_dim == 2U) {
col2im(col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
......
......@@ -23,20 +23,16 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename P, typename S>
template <typename Itype, typename Otype>
void FusionFcCompute(const FusionFcParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *input_z = param.InputZ();
S *input_z_data = input_z->data<S>();
Otype *input_z_data = input_z->data<Otype>();
int axis = param.Axis();
Tensor *out = param.Out();
// int m = out->dims()[0];
// int n = out->dims()[1];
auto *out_data = out->mutable_data<P>();
auto *out_data = out->mutable_data<Itype>();
float alpha = 1.0f;
float beta = 1.0f;
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
......@@ -59,11 +55,11 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
// The size of bias_data matches the second dimension of out
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
memory::Copy(out_data + i * classes, input_z_data, sizeof(Otype) * classes);
}
math::matmul<float>(x_matrix, false, y_matrix, false, alpha, out, beta,
false);
math::matmul<Itype, Otype>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out, static_cast<float>(1),
false);
}
} // namespace operators
......
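The loop above broadcasts the bias row into every output row before the GEMM, which then runs with beta = 1 so the matrix product is accumulated on top of the bias. A minimal sketch of the broadcast (simplified; assumes a row-major output of shape [rows, classes]):

#include <cstring>

void broadcast_bias(float *out, const float *bias, int rows, int classes) {
  // Copy the bias vector into each output row; the subsequent GEMM with
  // beta = 1 adds x * y onto this pre-filled buffer.
  for (int i = 0; i < rows; ++i) {
    std::memcpy(out + i * classes, bias, sizeof(float) * classes);
  }
}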
......@@ -73,14 +73,14 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul<float, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
math::matmul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
} else {
out->mutable_data<float>();
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
math::matmul<float, float>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
}
if (out_dim.size() != 2) {
out->Resize(out_dim);
......
......@@ -41,10 +41,10 @@ void set_constant(framework::Tensor *tensor, float value) {
}
template <>
void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
float *bias) {
void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, float *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......
......@@ -24,24 +24,24 @@ namespace math {
void set_constant(framework::Tensor *tensor, float value);
template <typename T>
template <typename Itype, typename Otype>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
float *bias = nullptr);
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu = false,
Otype *bias = nullptr);
template <typename T, typename S>
template <typename Itype, typename Otype>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
S *bias = nullptr, bool addOnRow = false);
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu, Otype *bias,
bool addOnRow);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group, float *bias = nullptr);
int group, T *bias = nullptr);
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
......
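For orientation, a minimal sketch of what the Itype/Otype split in these declarations buys (simplified, hypothetical signature; raw pointers instead of framework::Tensor): the input element type and the output/bias element type become independent, so int8 inputs can accumulate into int32 outputs.

#include <cstdint>

template <typename Itype, typename Otype>
void matmul_sketch(const Itype *a, const Itype *b, Otype *out,
                   int m, int n, int k, const Otype *bias) {
  // a is m x k, b is k x n, both row-major; accumulation happens in Otype.
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      Otype acc = bias ? bias[j] : Otype(0);
      for (int p = 0; p < k; ++p) {
        acc += static_cast<Otype>(a[i * k + p]) *
               static_cast<Otype>(b[p * n + j]);
      }
      out[i * n + j] = acc;
    }
  }
}

// Usage mirroring the two instantiations in this commit:
//   matmul_sketch<float, float>(...);     // fp32 path
//   matmul_sketch<int8_t, int32_t>(...);  // quantized path, int32 accumulation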
......@@ -22,10 +22,11 @@ namespace operators {
namespace math {
template <>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu, int32_t *bias,
bool addOnRow) {
void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, int32_t *bias,
bool addOnRow) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
......@@ -93,6 +94,16 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
#endif
}
}
template <>
void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, int32_t *bias) {
matmul<int8_t, int32_t>(matrix_a, trans_a, matrix_b, trans_b, alpha,
matrix_out, beta, relu, bias, false);
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
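A note on the pattern above: since the header now declares the addOnRow variant without default arguments, this file supplies a thin overload that forwards with addOnRow = false. The same pattern in isolation (hypothetical names):

#include <cstdint>

// Full version: explicit addOnRow flag.
void gemm_int8(const int8_t *a, const int8_t *b, int32_t *c, bool add_on_row);

// Thin forwarding overload standing in for the removed default argument.
inline void gemm_int8(const int8_t *a, const int8_t *b, int32_t *c) {
  gemm_int8(a, b, c, /*add_on_row=*/false);
}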
......@@ -53,7 +53,7 @@ struct PoolingVal<Avg> {
++count;
return *this;
}
inline float Value() { return (count > 0) ? val * (1.f / count) : 0.f; }
inline float Value() { return (count > 0) ? val / count : 0.f; }
};
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
......
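The accumulator pattern behind PoolingVal<Avg>, in isolation (a minimal sketch, not the framework's struct): sum, count, and divide exactly once when the value is read.

struct AvgVal {
  float val = 0.f;
  int count = 0;
  AvgVal &operator+=(float x) { val += x; ++count; return *this; }
  // Divide once at read time: 'val / count' rounds a single time, unlike
  // 'val * (1.f / count)', which also rounds the reciprocal.
  float Value() const { return (count > 0) ? val / count : 0.f; }
};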
......@@ -288,360 +288,363 @@ struct Pooling3x3<P, 1> {
int valid_w_end = valid_w_start + valid_w;
float avg = 1.f / 9;
#pragma omp parallel for
for (int c = 0; c < output->dims()[1]; ++c) {
const float *input_ptr = input_data + c * image_size;
float *output_ptr = output_data + c * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 3; h += 4) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
float *output_ptr3 = output_ptr2 + output_w;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(y0.val[1], y2.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
vst1q_f32(output_ptr2, y2.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1]));
x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr3, y0.val[0]);
vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
input_ptr3 += 6;
input_ptr4 += 6;
input_ptr5 += 6;
output_ptr0 += 6;
output_ptr1 += 6;
output_ptr2 += 6;
output_ptr3 += 6;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
for (int c = 0; c < output->dims()[1]; ++c) {
int channel = batch * output->dims()[1] + c;
const float *input_ptr = input_data + channel * image_size;
float *output_ptr = output_data + channel * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// remain w
if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]);
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
vst1q_f32(output_ptr2, y2.val[0]);
x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr3, y0.val[0]);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
input_ptr3 += 4;
input_ptr4 += 4;
input_ptr5 += 4;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
output_ptr3 += 4;
remain -= 4;
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
float m3 = PoolPre<P>(input_ptr3[r], input_ptr3[r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[r + 2]);
float m4 = PoolPre<P>(input_ptr4[r], input_ptr4[r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[r + 2]);
float m5 = PoolPre<P>(input_ptr5[r], input_ptr5[r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
output_ptr3[r] = PoolPost<P>(m3, avg);
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 1>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
}
// remain h
int start_h = valid_h_start + (valid_h & 0xFFFC);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
output_ptr0 += 6;
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 1>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// remain w
if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
output_ptr0 += 4;
remain -= 4;
// valid
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 3; h += 4) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
float *output_ptr3 = output_ptr2 + output_w;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y2.val[1], y1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y2.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(y0.val[1], y2.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[1] = vPoolPreq_f32<P>(x0.val[1], y2.val[1]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
y2.val[1] = vPoolPostq_f32<P>(y2.val[1], post);
vst1q_f32(output_ptr2, y2.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y2.val[1]));
x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr3, y0.val[0]);
vst1_f32(output_ptr3 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
input_ptr3 += 6;
input_ptr4 += 6;
input_ptr5 += 6;
output_ptr0 += 6;
output_ptr1 += 6;
output_ptr2 += 6;
output_ptr3 += 6;
}
// remain width
if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y2.val[0], y1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y2.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr3);
x0.val[1] = vld1q_f32(input_ptr3 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y2.val[0] = vPoolPreq_f32<P>(y0.val[0], y2.val[0]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]);
x0.val[0] = vld1q_f32(input_ptr4);
x0.val[1] = vld1q_f32(input_ptr4 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y2.val[0] = vPoolPreq_f32<P>(x0.val[0], y2.val[0]);
y2.val[0] = vPoolPostq_f32<P>(y2.val[0], post);
vst1q_f32(output_ptr2, y2.val[0]);
x0.val[0] = vld1q_f32(input_ptr5);
x0.val[1] = vld1q_f32(input_ptr5 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr3, y0.val[0]);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
input_ptr3 += 4;
input_ptr4 += 4;
input_ptr5 += 4;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
output_ptr3 += 4;
remain -= 4;
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
float m3 = PoolPre<P>(input_ptr3[r], input_ptr3[r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[r + 2]);
float m4 = PoolPre<P>(input_ptr4[r], input_ptr4[r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[r + 2]);
float m5 = PoolPre<P>(input_ptr5[r], input_ptr5[r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m1, m2), m3);
m2 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m3 = PoolPre<P>(PoolPre<P>(m3, m4), m5);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
output_ptr3[r] = PoolPost<P>(m3, avg);
}
}
// remain height
int start_h = valid_h_start + (valid_h & 0xFFFC);
for (int h = start_h; h < valid_h_end; ++h) {
const float *input_ptr0 = input_ptr + (h - padding_h) * input_w;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x1.val[1] = vextq_f32(x0.val[1], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x2.val[1] = vextq_f32(x0.val[1], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 6;
input_ptr1 += 6;
input_ptr2 += 6;
output_ptr0 += 6;
}
// remain width
if (remain >= 4) {
x0.val[0] = vld1q_f32(input_ptr0);
x0.val[1] = vld1q_f32(input_ptr0 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[0] = vld1q_f32(input_ptr1);
x0.val[1] = vld1q_f32(input_ptr1 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0.val[0] = vld1q_f32(input_ptr2);
x0.val[1] = vld1q_f32(input_ptr2 + 4);
x1.val[0] = vextq_f32(x0.val[0], x0.val[1], 1);
x2.val[0] = vextq_f32(x0.val[0], x0.val[1], 2);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x1.val[0]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 4;
input_ptr1 += 4;
input_ptr2 += 4;
output_ptr0 += 4;
remain -= 4;
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0, avg);
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[r], input_ptr0[r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[r + 2]);
float m1 = PoolPre<P>(input_ptr1[r], input_ptr1[r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[r + 2]);
float m2 = PoolPre<P>(input_ptr2[r], input_ptr2[r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0, avg);
}
}
}
}
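Stepping back, the Pooling3x3 body above follows one decomposition: border outputs whose 3x3 windows overlap the padding go through the generic per-row/per-column helpers, while the interior "valid" region, where every window is fully in-bounds, is tiled and vectorized. A compilable outline of that split (hypothetical helper callbacks, not the framework's API):

#include <functional>

void pool_regions(int output_h, int output_w, int valid_h_start,
                  int valid_h_end, int valid_w_start, int valid_w_end,
                  const std::function<void(int)> &normal_row,
                  const std::function<void(int)> &valid_col,
                  const std::function<void()> &interior) {
  for (int h = 0; h < valid_h_start; ++h) normal_row(h);      // top
  for (int w = 0; w < valid_w_start; ++w) valid_col(w);       // left
  for (int w = valid_w_end; w < output_w; ++w) valid_col(w);  // right
  for (int h = valid_h_end; h < output_h; ++h) normal_row(h); // bottom
  interior();  // fully in-bounds region, NEON-tiled in the real code
}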
......@@ -671,339 +674,342 @@ struct Pooling3x3<P, 2> {
int valid_w_end = valid_w_start + valid_w;
float avg = 1.f / 9;
#pragma omp parallel for
for (int c = 0; c < output->dims()[1]; ++c) {
const float *input_ptr = input_data + c * image_size;
float *output_ptr = output_data + c * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// valid
int input_w_start = 2 * valid_w_start - padding_w;
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
const float *input_ptr6 = input_ptr5 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1);
x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2);
x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
x0 = vld2q_f32(input_ptr3);
x1 = vld2q_f32(input_ptr3 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]);
x0 = vld2q_f32(input_ptr4);
x1 = vld2q_f32(input_ptr4 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
x0 = vld2q_f32(input_ptr5);
x1 = vld2q_f32(input_ptr5 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr6);
x1 = vld2q_f32(input_ptr6 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr2, y0.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 12;
input_ptr1 += 12;
input_ptr2 += 12;
input_ptr3 += 12;
input_ptr4 += 12;
input_ptr5 += 12;
input_ptr6 += 12;
output_ptr0 += 6;
output_ptr1 += 6;
output_ptr2 += 6;
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
for (int c = 0; c < output->dims()[1]; ++c) {
int channel = batch * output->dims()[1] + c;
const float *input_ptr = input_data + channel * image_size;
float *output_ptr = output_data + channel * out_image_size;
// top
for (int h = 0; h < valid_h_start; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// remain w
if (remain >= 4) {
x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
x0 = vld2q_f32(input_ptr3);
x1.val[0] = vdupq_n_f32(input_ptr3[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
x0 = vld2q_f32(input_ptr4);
x1.val[0] = vdupq_n_f32(input_ptr4[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]);
x0 = vld2q_f32(input_ptr5);
x1.val[0] = vdupq_n_f32(input_ptr5[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr6);
x1.val[0] = vdupq_n_f32(input_ptr6[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr2, y0.val[0]);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
input_ptr4 += 8;
input_ptr5 += 8;
input_ptr6 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
remain -= 4;
// left
for (int w = 0; w < valid_w_start; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
#endif // __ARM_NEON__
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
float m3 = PoolPre<P>(input_ptr3[2 * r], input_ptr3[2 * r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[2 * r + 2]);
float m4 = PoolPre<P>(input_ptr4[2 * r], input_ptr4[2 * r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[2 * r + 2]);
float m5 = PoolPre<P>(input_ptr5[2 * r], input_ptr5[2 * r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[2 * r + 2]);
float m6 = PoolPre<P>(input_ptr6[2 * r], input_ptr6[2 * r + 1]);
m6 = PoolPre<P>(m6, input_ptr6[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
// right
for (int w = valid_w_end; w < output_w; ++w) {
Pooling3x3ValidCol<P, 2>(input_ptr, valid_h_start, valid_h_end, w,
input_h, input_w, padding_h, padding_w,
output_w, output_ptr);
}
}
// remain h
int start_h = valid_h_start + valid_h / 3 * 3;
for (int h = start_h; h < valid_h_end; ++h) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1);
x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2);
x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 12;
input_ptr1 += 12;
input_ptr2 += 12;
output_ptr0 += 6;
// bottom
for (int h = valid_h_end; h < output_h; ++h) {
Pooling3x3NormalRow<P, 2>(input_ptr, h, input_h, input_w, padding_h,
padding_w, output_w, output_ptr);
}
// remain w
if (remain >= 4) {
x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
output_ptr0 += 4;
remain -= 4;
// valid
int input_w_start = 2 * valid_w_start - padding_w;
int output_w_tiles = valid_w / 6;
int output_w_remain = valid_w - output_w_tiles * 6;
for (int h = valid_h_start; h < valid_h_end - 2; h += 3) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
const float *input_ptr3 = input_ptr2 + input_w;
const float *input_ptr4 = input_ptr3 + input_w;
const float *input_ptr5 = input_ptr4 + input_w;
const float *input_ptr6 = input_ptr5 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
float *output_ptr1 = output_ptr0 + output_w;
float *output_ptr2 = output_ptr1 + output_w;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2;
float32x4x2_t y0, y1, y2;
float32x4_t post = vdupq_n_f32(1.f / 9);
for (int loop = 0; loop < output_w_tiles; ++loop) {
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1);
x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2);
x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(y1.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
x0 = vld2q_f32(input_ptr3);
x1 = vld2q_f32(input_ptr3 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(x0.val[1], y1.val[1]);
x0 = vld2q_f32(input_ptr4);
x1 = vld2q_f32(input_ptr4 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[1] = vPoolPreq_f32<P>(y0.val[1], y1.val[1]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
y1.val[1] = vPoolPostq_f32<P>(y1.val[1], post);
vst1q_f32(output_ptr1, y1.val[0]);
vst1_f32(output_ptr1 + 4, vget_low_f32(y1.val[1]));
x0 = vld2q_f32(input_ptr5);
x1 = vld2q_f32(input_ptr5 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr6);
x1 = vld2q_f32(input_ptr6 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr2, y0.val[0]);
vst1_f32(output_ptr2 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 12;
input_ptr1 += 12;
input_ptr2 += 12;
input_ptr3 += 12;
input_ptr4 += 12;
input_ptr5 += 12;
input_ptr6 += 12;
output_ptr0 += 6;
output_ptr1 += 6;
output_ptr2 += 6;
}
// remain width
if (remain >= 4) {
x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(y1.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
x0 = vld2q_f32(input_ptr3);
x1.val[0] = vdupq_n_f32(input_ptr3[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(x0.val[0], y1.val[0]);
x0 = vld2q_f32(input_ptr4);
x1.val[0] = vdupq_n_f32(input_ptr4[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y1.val[0] = vPoolPreq_f32<P>(y0.val[0], y1.val[0]);
y1.val[0] = vPoolPostq_f32<P>(y1.val[0], post);
vst1q_f32(output_ptr1, y1.val[0]);
x0 = vld2q_f32(input_ptr5);
x1.val[0] = vdupq_n_f32(input_ptr5[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr6);
x1.val[0] = vdupq_n_f32(input_ptr6[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr2, y0.val[0]);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
input_ptr4 += 8;
input_ptr5 += 8;
input_ptr6 += 8;
output_ptr0 += 4;
output_ptr1 += 4;
output_ptr2 += 4;
remain -= 4;
}
#endif // __ARM_NEON__
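// Scalar tail for the three-row block: output row 0 reduces input rows 0-2,
// row 1 rows 2-4, row 2 rows 4-6 (rows 2 and 4 are shared), and each output
// column r reduces input columns 2r..2r+2.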
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
float m3 = PoolPre<P>(input_ptr3[2 * r], input_ptr3[2 * r + 1]);
m3 = PoolPre<P>(m3, input_ptr3[2 * r + 2]);
float m4 = PoolPre<P>(input_ptr4[2 * r], input_ptr4[2 * r + 1]);
m4 = PoolPre<P>(m4, input_ptr4[2 * r + 2]);
float m5 = PoolPre<P>(input_ptr5[2 * r], input_ptr5[2 * r + 1]);
m5 = PoolPre<P>(m5, input_ptr5[2 * r + 2]);
float m6 = PoolPre<P>(input_ptr6[2 * r], input_ptr6[2 * r + 1]);
m6 = PoolPre<P>(m6, input_ptr6[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
m1 = PoolPre<P>(PoolPre<P>(m2, m3), m4);
m2 = PoolPre<P>(PoolPre<P>(m4, m5), m6);
output_ptr0[r] = PoolPost<P>(m0, avg);
output_ptr1[r] = PoolPost<P>(m1, avg);
output_ptr2[r] = PoolPost<P>(m2, avg);
}
}
// remain height
int start_h = valid_h_start + valid_h / 3 * 3;
for (int h = start_h; h < valid_h_end; ++h) {
size_t offset = (2 * h - padding_h) * input_w + input_w_start;
const float *input_ptr0 = input_ptr + offset;
const float *input_ptr1 = input_ptr0 + input_w;
const float *input_ptr2 = input_ptr1 + input_w;
float *output_ptr0 = output_ptr + h * output_w + valid_w_start;
int remain = output_w_remain;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
float32x4x2_t x0, x1, x2, y0;
float32x4_t post = vdupq_n_f32(1.f / 9);
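// post holds the 3x3 averaging factor (1/9); vPoolPostq_f32<P> applies it
// for average pooling and should leave max-pooling results unchanged.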
for (int loop = 0; loop < output_w_tiles; ++loop) {
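// vld2q_f32 de-interleaves 8 floats into even lanes (val[0]) and odd lanes
// (val[1]); the stride-2 window at output column j covers even[j], odd[j]
// and even[j+1], with even[j+1] produced by vextq_f32.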
x0 = vld2q_f32(input_ptr0);
x1 = vld2q_f32(input_ptr0 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
x0 = vld2q_f32(input_ptr1);
x1 = vld2q_f32(input_ptr1 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
x0 = vld2q_f32(input_ptr2);
x1 = vld2q_f32(input_ptr2 + 8);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x2.val[1] = vextq_f32(x1.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[1] = vPoolPreq_f32<P>(x1.val[0], x1.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0.val[1] = vPoolPreq_f32<P>(x0.val[1], x2.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[1] = vPoolPreq_f32<P>(x0.val[1], y0.val[1]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
y0.val[1] = vPoolPostq_f32<P>(y0.val[1], post);
vst1q_f32(output_ptr0, y0.val[0]);
vst1_f32(output_ptr0 + 4, vget_low_f32(y0.val[1]));
input_ptr0 += 12;
input_ptr1 += 12;
input_ptr2 += 12;
output_ptr0 += 6;
}
// remain width
if (remain >= 4) {
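// Same 4-column tail as in the three-row block: broadcast the 9th input so
// vextq_f32 can form the last window without over-reading.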
x0 = vld2q_f32(input_ptr0);
x1.val[0] = vdupq_n_f32(input_ptr0[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
x0 = vld2q_f32(input_ptr1);
x1.val[0] = vdupq_n_f32(input_ptr1[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
x0 = vld2q_f32(input_ptr2);
x1.val[0] = vdupq_n_f32(input_ptr2[8]);
x2.val[0] = vextq_f32(x0.val[0], x1.val[0], 1);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x0.val[1]);
x0.val[0] = vPoolPreq_f32<P>(x0.val[0], x2.val[0]);
y0.val[0] = vPoolPreq_f32<P>(x0.val[0], y0.val[0]);
y0.val[0] = vPoolPostq_f32<P>(y0.val[0], post);
vst1q_f32(output_ptr0, y0.val[0]);
input_ptr0 += 8;
input_ptr1 += 8;
input_ptr2 += 8;
output_ptr0 += 4;
remain -= 4;
}
#endif // __ARM_NEON__
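// Scalar tail: output column r reduces input columns 2r..2r+2 across the
// three input rows.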
for (int r = 0; r < remain; ++r) {
float m0 = PoolPre<P>(input_ptr0[2 * r], input_ptr0[2 * r + 1]);
m0 = PoolPre<P>(m0, input_ptr0[2 * r + 2]);
float m1 = PoolPre<P>(input_ptr1[2 * r], input_ptr1[2 * r + 1]);
m1 = PoolPre<P>(m1, input_ptr1[2 * r + 2]);
float m2 = PoolPre<P>(input_ptr2[2 * r], input_ptr2[2 * r + 1]);
m2 = PoolPre<P>(m2, input_ptr2[2 * r + 2]);
m0 = PoolPre<P>(PoolPre<P>(m0, m1), m2);
output_ptr0[r] = PoolPost<P>(m0, avg);
}
}
}
}
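For reference, a minimal scalar sketch of the same 3x3, stride-2, no-padding pooling. This is a hypothetical standalone check, not part of the patch; the avg/max branches play the roles of PoolPre<P> and PoolPost<P> above.

// Scalar reference for 3x3, stride-2, no-padding pooling on one channel.
#include <algorithm>
#include <vector>

std::vector<float> Pool3x3S2Ref(const std::vector<float> &in, int h, int w,
                                bool avg) {
  const int out_h = (h - 3) / 2 + 1;
  const int out_w = (w - 3) / 2 + 1;
  std::vector<float> out(out_h * out_w);
  for (int oh = 0; oh < out_h; ++oh) {
    for (int ow = 0; ow < out_w; ++ow) {
      float acc = avg ? 0.f : in[2 * oh * w + 2 * ow];
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          const float v = in[(2 * oh + kh) * w + (2 * ow + kw)];
          acc = avg ? (acc + v) : std::max(acc, v);
        }
      }
      // Average pooling divides by the window size at the end.
      out[oh * out_w + ow] = avg ? (acc / 9) : acc;
    }
  }
  return out;
}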
......
......@@ -73,14 +73,14 @@ int main() {
// float
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(
paddle_mobile::operators::math::matmul<float, float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
auto time_start0 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(
paddle_mobile::operators::math::matmul<float, float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
......@@ -91,14 +91,14 @@ int main() {
// int8_t without bias
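// matmul is parameterized as <input type, output type>: int8_t inputs
// accumulate into an int32_t output tensor.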
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, int32_t>(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
auto time_start1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, int32_t>(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
......@@ -109,13 +109,13 @@ int main() {
// int8_t with bias, column element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
auto time_start2 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
......@@ -126,13 +126,13 @@ int main() {
// int8_t with bias, row element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
auto time_start3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
......@@ -143,13 +143,13 @@ int main() {
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
auto time_start4 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
......
......@@ -59,7 +59,8 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
attrs["ksize"].Set<vector<int>>(std::vector<int>({kernel_h, kernel_w}));
attrs["strides"].Set<vector<int>>(std::vector<int>({stride_h, stride_w}));
attrs["paddings"].Set<vector<int>>(std::vector<int>({pad_h, pad_w}));
attrs["ceil_mode"].Set<bool>(false);
attrs["ceil_mode"].Set<bool>(true);
// attrs["ceil_mode"].Set<bool>(false);
attrs["global_pooling"].Set<bool>(false);
auto *op = new operators::PoolOp<CPU, float>("pool2d", inputs, outputs, attrs,
......@@ -116,57 +117,57 @@ int main(int argc, char *argv[]) {
int in_channels = atoi(argv[1]);
int in_height = atoi(argv[2]);
int in_width = atoi(argv[3]);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=5, stride=1";
paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=1, stride=1";
paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=2, stride=1";
paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=5, stride=1";
paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=0, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=1, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=2, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=5, stride=1";
// paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=0, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=1, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=2, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=5, stride=1";
// paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=max, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=0, stride=2";
paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=1, stride=2";
paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width);
LOG(paddle_mobile::kLOG_INFO)
<< "float, pooling_type=avg, kernel=3, pad=5, stride=2";
paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=1, stride=2";
// paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=2, stride=2";
// paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=max, kernel=3, pad=5, stride=2";
// paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width);
//
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=0, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=1, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=2, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width);
// LOG(paddle_mobile::kLOG_INFO)
// << "float, pooling_type=avg, kernel=3, pad=5, stride=2";
// paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
......