diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 0e84cb37392839d112448b0b3c12b042e7df838e..bcd6dfe1fda6b1243007b0c26a6e0087eedcc10c 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2157,26 +2157,20 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, int wend = wstart + sizeX; wstart = wstart < 0 ? 0 : wstart; wend = wend < (int)imgSizeW ? wend : (int)imgSizeW; - if (maskData == NULL) { - real tmp = -(real)FLT_MAX; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - tmp = tmp < inputData[h * imgSizeW + w] - ? inputData[h * imgSizeW + w] - : tmp; - } - } - outData[ph * outputW + pw] = tmp; - } else { - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) { - outData[ph * outputW + pw] = inputData[h * imgSizeW + w]; - maskData[ph * outputW + pw] = h * imgSizeW + w; - } + + real maxval = -(real)FLT_MAX; + int max_index = -1; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + if (maxval < inputData[h * imgSizeW + w]) { + maxval = inputData[h * imgSizeW + w]; + max_index = h * imgSizeW + w; } } } + + outData[ph * outputW + pw] = maxval; + if (maskData != NULL) maskData[ph * outputW + pw] = max_index; } } // compute offset