未验证 提交 f36dcf99 编写于 作者: S suiyang 提交者: GitHub

Merge pull request #1220 from Eclipsess/develop

fix #1219 fix pool max 2x2 bug
......@@ -76,7 +76,7 @@ void PoolCompute(const PoolParam<CPU> &param) {
}
}
} else if (0 && ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
} else if (ksize[0] == 2 && ksize[0] == ksize[1] && strides[0] == 2 &&
strides[0] == strides[1] && paddings[0] == paddings[1] &&
paddings[1] == 0) {
#if __ARM_NEON
......
......@@ -58,7 +58,7 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
const float *in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width;
const float *in_ptr2 = in_ptr1 + input_width;
if (ph + 1 >= input_height) {
if (ph != input_height && ph + 1 >= input_height) {
in_ptr2 = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * input_width));
memset(static_cast<void *>(const_cast<float *>(in_ptr2)), -FLT_MAX,
......@@ -122,19 +122,30 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
#endif
if (_w2 != 0) {
in_ptr1 += 16 * w1 + 4 * w2;
in_ptr2 += 16 * w1 + 4 * w2;
out_ptr += 8 * w1 + 2 * w2;
in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width + 16 * w1 +
4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) {
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
} else if (_w2 == 2) {
float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++;
float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
*out_ptr = (temp > temp1) ? temp : temp1;
} else if (_w2 == 3) {
float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++;
float temp1 = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++;
*out_ptr++ = (temp > temp1) ? temp : temp1;
float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
*out_ptr = (temp > temp1) ? temp : temp1;
out_ptr++;
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
}
}
......@@ -173,7 +184,7 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
int w2 = _w1 / 4;
int _w2 = _w1 % 4;
float quarter = 1 / 4;
float quarter = 0.25;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < input_height; ph += 2) {
......@@ -250,25 +261,32 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
#endif
if (_w2 != 0) {
in_ptr1 += 16 * w1 + 4 * w2;
in_ptr2 += 16 * w1 + 4 * w2;
out_ptr += 8 * w1 + 2 * w2;
in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width + 16 * w1 +
4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) {
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
} else if (_w2 == 2) {
float temp = 0;
temp += *in_ptr1++;
temp += *in_ptr2++;
temp += *in_ptr1;
temp += *in_ptr2;
*out_ptr = 0.5 * temp;
in_ptr1++;
in_ptr2++;
temp += *in_ptr1;
temp += *in_ptr2;
*out_ptr = 0.25 * temp;
} else if (_w2 == 3) {
float temp = 0;
temp += *in_ptr1++;
temp += *in_ptr2++;
temp += *in_ptr1++;
temp += *in_ptr2++;
*out_ptr++ = 0.5 * temp;
*out_ptr = 0.25 * temp;
out_ptr++;
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册