diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu
index a91ead240416e8aba3259fa785f6dfda596be44e..3699b1e8ae9d8f813439eaeaa760c4a9f6e100a0 100644
--- a/paddle/cuda/src/hl_cuda_cnn.cu
+++ b/paddle/cuda/src/hl_cuda_cnn.cu
@@ -51,8 +51,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         if (maxval < inputData[h * width + w]) {
-          maxval = inputData[h * width + w];
           max_index = h * width + w;
+          maxval = inputData[max_index];
         }
       }
     }
diff --git a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
index 44fc2b91ec334ff1618328e44a5e8bd44d82c4b9..16438886df94cab9d29d05924bb047e6c7f1f6fa 100644
--- a/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
+++ b/paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
@@ -105,15 +105,13 @@ TEST(Layer, maxPoolingWithMaskOutputLayerFwd) {
   maskMat->setData(maskData);
   doOneMaxPoolingWithMaskOutputTest(
       inputMat, "max-pool-with-mask", useGpu, maskMat);
-  /*
-  #ifdef PADDLE_WITH_CUDA
-  useGpu = true;
-  inputMat = Matrix::create(1, 25, false, useGpu);
-  maskMat = Matrix::create(1, 4, false, useGpu);
-  inputMat->copyFrom(inputData, 25);
-  maskMat->copyFrom(maskData, 4);
-  doOneMaxPoolingWithMaskOutputTest(
-      inputMat, "max-pool-with-mask", useGpu, maskMat);
-  #endif
-  */
+#ifdef PADDLE_WITH_CUDA
+  useGpu = true;
+  inputMat = Matrix::create(1, 25, false, useGpu);
+  maskMat = Matrix::create(1, 4, false, useGpu);
+  inputMat->copyFrom(inputData, 25);
+  maskMat->copyFrom(maskData, 4);
+  doOneMaxPoolingWithMaskOutputTest(
+      inputMat, "max-pool-with-mask", useGpu, maskMat);
+#endif
 }
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 743922cd9bd65b311cf47659be83f55712b8d5ac..41ee5089677f2565dfd16b1bca7885db5583d910 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -2021,7 +2021,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
           int wstart = pw * strideW - paddingW;
           int wend = std::min(wstart + sizeX, imgSizeW);
           wstart = std::max(wstart, 0);
-          if (maskMatP == NULL) {
+          if (maskData == NULL) {
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
                 outData[ph * outputW + pw] = std::max(
@@ -2044,7 +2044,7 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
       inputData += inLength;
       outData += outLength;
 
-      if (maskMatP != NULL) maskData += outLength;
+      if (maskData != NULL) maskData += outLength;
     }
   }
 }
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index f7ab7a5ca0a369e976c783fdc793abd69a683c3c..e21071f5b020156713f31aa4fc7266aed5cd9cfa 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -2701,7 +2701,7 @@ def img_pool_layer(input,
 
     assert type(pool_type) in [AvgPooling, MaxPooling, MaxWithMaskPooling, CudnnAvgPooling,
                                CudnnMaxPooling], \
-        "only (Cudnn)AvgPooling, (Cudnn)MaxPooling MaxWithMaskPooling are supported"
+        "only (Cudnn)AvgPooling, (Cudnn)MaxPooling, MaxWithMaskPooling are supported"
 
     type_name = pool_type.name + '-projection' \
         if (