未验证 提交 f36dcf99 编写于 作者: S suiyang 提交者: GitHub

Merge pull request #1220 from Eclipsess/develop

Fix #1219: fix bug in 2x2 max pooling
...@@ -76,7 +76,7 @@ void PoolCompute(const PoolParam<CPU> &param) { ...@@ -76,7 +76,7 @@ void PoolCompute(const PoolParam<CPU> &param) {
} }
} }
} else if (0 && ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 && } else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] && strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) { paddings[1] == 0) {
#if __ARM_NEON #if __ARM_NEON
......
...@@ -58,7 +58,7 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings, ...@@ -58,7 +58,7 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
const float *in_ptr1 = input_data + i * input_batch_stride + const float *in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width; c * input_channel_stride + ph * input_width;
const float *in_ptr2 = in_ptr1 + input_width; const float *in_ptr2 = in_ptr1 + input_width;
if (ph + 1 >= input_height) { if (ph != input_height && ph + 1 >= input_height) {
in_ptr2 = static_cast<float *>( in_ptr2 = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * input_width)); paddle_mobile::memory::Alloc(sizeof(float) * input_width));
memset(static_cast<void *>(const_cast<float *>(in_ptr2)), -FLT_MAX, memset(static_cast<void *>(const_cast<float *>(in_ptr2)), -FLT_MAX,
...@@ -122,19 +122,30 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings, ...@@ -122,19 +122,30 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
#endif #endif
if (_w2 != 0) { if (_w2 != 0) {
in_ptr1 += 16 * w1 + 4 * w2; in_ptr1 = input_data + i * input_batch_stride +
in_ptr2 += 16 * w1 + 4 * w2; c * input_channel_stride + ph * input_width + 16 * w1 +
out_ptr += 8 * w1 + 2 * w2; 4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) { if (_w2 == 1) {
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; *out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
} else if (_w2 == 2) { } else if (_w2 == 2) {
float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
*out_ptr = (temp > temp1) ? temp : temp1; *out_ptr = (temp > temp1) ? temp : temp1;
} else if (_w2 == 3) { } else if (_w2 == 3) {
float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
float temp1 = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; in_ptr1++;
*out_ptr++ = (temp > temp1) ? temp : temp1; in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
*out_ptr = (temp > temp1) ? temp : temp1;
out_ptr++;
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; *out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
} }
} }
...@@ -173,7 +184,7 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings, ...@@ -173,7 +184,7 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
int w2 = _w1 / 4; int w2 = _w1 / 4;
int _w2 = _w1 % 4; int _w2 = _w1 % 4;
float quarter = 1 / 4; float quarter = 0.25;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < input_height; ph += 2) { for (int ph = 0; ph < input_height; ph += 2) {
...@@ -250,25 +261,32 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings, ...@@ -250,25 +261,32 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
#endif #endif
if (_w2 != 0) { if (_w2 != 0) {
in_ptr1 += 16 * w1 + 4 * w2; in_ptr1 = input_data + i * input_batch_stride +
in_ptr2 += 16 * w1 + 4 * w2; c * input_channel_stride + ph * input_width + 16 * w1 +
out_ptr += 8 * w1 + 2 * w2; 4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) { if (_w2 == 1) {
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2); *out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
} else if (_w2 == 2) { } else if (_w2 == 2) {
float temp = 0; float temp = 0;
temp += *in_ptr1++;
temp += *in_ptr2++;
temp += *in_ptr1; temp += *in_ptr1;
temp += *in_ptr2; temp += *in_ptr2;
*out_ptr = 0.5 * temp; in_ptr1++;
in_ptr2++;
temp += *in_ptr1;
temp += *in_ptr2;
*out_ptr = 0.25 * temp;
} else if (_w2 == 3) { } else if (_w2 == 3) {
float temp = 0; float temp = 0;
temp += *in_ptr1++; temp += *in_ptr1++;
temp += *in_ptr2++; temp += *in_ptr2++;
temp += *in_ptr1++; temp += *in_ptr1++;
temp += *in_ptr2++; temp += *in_ptr2++;
*out_ptr++ = 0.5 * temp; *out_ptr = 0.25 * temp;
out_ptr++;
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2); *out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册