diff --git a/paddle/cuda/src/hl_cuda_cnn.cu b/paddle/cuda/src/hl_cuda_cnn.cu index 0992286f360fb8be22e3c35b632e4b7163036277..1516accaae17fbeff4f4e48584940ec3e9873897 100644 --- a/paddle/cuda/src/hl_cuda_cnn.cu +++ b/paddle/cuda/src/hl_cuda_cnn.cu @@ -381,57 +381,45 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad, CHECK_SYNC("hl_avgpool_backward failed"); } -__global__ void KeCMRNormFillScale(size_t nthreads, const real* in, +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, real* scale, size_t channels, size_t height, size_t width, size_t size, real alpha) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + in += offset; scale += offset; - size_t head = 0; - size_t pre_pad = (size - 1) / 2; - size_t post_pad = size - pre_pad - 1; - real accum_scale = 0; - // fill the scale at [n, :, h, w] - // accumulate values - while (head < post_pad) { - accum_scale += in[head * step] * in[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_scale += in[head * step] * in[head * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // both add and subtract - while (head < channels) { - accum_scale += in[head * step] * in[head * step]; - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_scale -= in[(head - size) * step] * in[(head - size) * step]; - scale[(head - post_pad) * step] = 1. + accum_scale * alpha; - ++head; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. + accum * alpha; + } + ++index; } } } - __global__ void KeCMRNormOutput(size_t nthreads, const real* in, - const real* scale, real negative_beta, - real* out) { - size_t index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { out[index] = in[index] * pow(scale[index], negative_beta); } } @@ -440,84 +428,60 @@ void hl_CMRNorm_forward(size_t frameCnt, const real* in, real* scale, real* out, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - - KeCMRNormFillScale<<>> - (threadsNum, in, scale, channels, height, width, sizeX, alpha); - - threadsNum = frameCnt * height * width *channels; - blocksX = (threadsNum + 1024 -1) / 1024; - dim3 threads2(1024, 1); - dim3 grid2(blocksX, blocksY); - KeCMRNormOutput<<>> - (threadsNum, in, scale, beta, out); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, in, scale, channels, height, width, sizeX, alpha); + + size_t inputSize = frameCnt * height * width *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, in, scale, beta, out); CHECK_SYNC("hl_CMRNorm_forward"); } -__global__ void KeCMRNormDiff(size_t nthreads, const real* bottom_data, +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, const real* top_data, const real* scale, const real* top_diff, size_t channels, size_t height, size_t width, size_t size, real negative_beta, real cache_ratio, real* bottom_diff ) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - if (index < nthreads) { - // find out the local offset - size_t w = index % width; - size_t h = (index / width) % height; - size_t n = index / width / height; - size_t offset = (n * channels * height + h) * width + w; - size_t step = height * width; + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; bottom_data += offset; top_data += offset; scale += offset; top_diff += offset; bottom_diff += offset; - int head = 0; - int pre_pad = size - (size + 1) / 2; - int post_pad = size - pre_pad - 1; - real accum_ratio = 0; - // accumulate values - while (head < post_pad) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - ++head; - } - // until we reach size, nothing needs to be subtracted - while (head < size) { - accum_ratio += top_diff[head * step] * - top_data[head * step] / scale[head * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // both add and subtract - while (head < channels) { - accum_ratio += top_diff[head * step] * top_data[head * step] / - scale[head * step]; - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; - } - // subtract only - while (head < channels + post_pad) { - accum_ratio -= top_diff[(head - size) * step] * - top_data[(head - size) * step] / scale[(head - size) * step]; - bottom_diff[(head - post_pad) * step] += - top_diff[(head - post_pad) * step] * - pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * - bottom_data[(head - post_pad) * step] * accum_ratio; - ++head; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; } } } @@ -528,14 +492,12 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV, real *inDiff, size_t channels, size_t height, size_t width, size_t sizeX, real alpha, real beta) { - size_t threadsNum = frameCnt * height * width; - size_t blocksX = (threadsNum + 1024 - 1) / 1024; - size_t blocksY = 1; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - KeCMRNormDiff <<>> - (threadsNum, inV, outV, scale, outDiff, channels, - height, width, sizeX, alpha, beta, inDiff); + size_t imageSize = frameCnt * height * width; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inV, outV, scale, outDiff, channels, + height, width, sizeX, alpha, beta, inDiff); CHECK_SYNC("hl_CMRNorm_backward"); } diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 7983d9fe64c61648a2939ddc610a0f819e338577..8ade15daac8609b715c8b3072d8095ae62a040ca 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1021,11 +1021,10 @@ void testNormLayer(const string& normType, bool trans, bool useGpu) { testLayerGrad(config, "norm", 100, trans, useGpu); } -#ifndef PADDLE_ONLY_CPU TEST(Layer, NormLayer) { testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ true); + testNormLayer("cmrnorm-projection", /* trans= */ false, /* useGpu= */ false); } -#endif void setPoolConfig(TestConfig* config, PoolConfig* pool, diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c69e074a76399db923a5c64243f1d3690858810d..2cde11dd479dc0d150c5d7ce5c0c5c1cbf40e449 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -2227,52 +2227,43 @@ void CpuMatrix::crossMapNormalFwd(Matrix& input, size_t sizeX, float scale, float pow) { - size_t num = input.getHeight(); + CHECK(isContiguous()); + CHECK(input.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(getHeight(), input.getHeight()); + CHECK_EQ(getWidth(), input.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + + size_t numSample = input.getHeight(); + size_t numCols = input.getWidth(); size_t height = imgSizeH; size_t width = imgSizeW; - size_t numCols = input.getWidth(); - CHECK(height * width * channels == input.getWidth()); - CHECK(denoms.getHeight() == input.getHeight() && - denoms.getWidth() == input.getWidth() && input.getHeight() == height_ && - input.getWidth() == width_); - real* imgData = input.getData(); - real* diffData = input.getData(); - real* targetData = getData(); - size_t halfSize = sizeX / 2; - size_t imgPixels = height * width; - - // use integral vector to implement the sum in local window - real* integralData = - (real*)malloc((channels + sizeX + 1) * sizeof(real)); // NOLINT // TODO: - for (size_t i = 0; i <= halfSize; i++) { - integralData[i] = 0; - } - for (size_t i = 0; i < num; i++) { - real* targetPtr = targetData + i * numCols; - real* imgPtr = imgData + i * numCols; - real* diffPtr = diffData + i * numCols; - for (size_t m = 0; m < height; m++) { - for (size_t n = 0; n < width; n++) { - for (size_t c = 0; c < channels; c++) { - integralData[c + halfSize + 1] = - integralData[c + halfSize] + _square(*(diffPtr + c * imgPixels)); - } - for (size_t k = channels + halfSize + 1; k <= channels + sizeX; k++) { - integralData[k] = integralData[channels + halfSize]; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. + CpuMatrix tmp(channels, height * width); + CpuMatrix tmp2(tmp.getData(), 1, channels * height * width); + denoms.zero(); + const int start = -((int)sizeX - 1) / 2; + const int end = (int)sizeX + start; + for (size_t i = 0; i < numSample; i++) { + input.subMatrix(i, 1)->square2(tmp2); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + for (int c = 0; c < (int)channels; c++) { + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + subDen.subMatrix(c, 1)->add(*tmp.subMatrix(c + s, 1)); } - for (size_t k = 0; k < channels; k += 1) { - real a = integralData[k + sizeX] - integralData[k]; - a = scale * a + 1; - targetPtr[k * imgPixels] = imgPtr[k * imgPixels] * _pow(a, -pow); - } - diffPtr++; - targetPtr++; - imgPtr++; } } } - free(integralData); - integralData = NULL; + + denoms.add(scale, (real)1); + this->pow2(denoms, -pow); + this->dotMul(input); } void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, @@ -2282,19 +2273,63 @@ void CpuMatrix::crossMapNormalBwd(Matrix& localGrad, size_t channels, size_t imgSizeH, size_t imgSizeW, - size_t size, + size_t sizeX, float scale, float pow) { - LOG(FATAL) << "Not implemented"; - - CHECK(imgSizeH * imgSizeW * channels == preOutV.getWidth()); - CHECK(denoms.getHeight() == preOutV.getHeight() && - denoms.getWidth() == preOutV.getWidth() && - preOutV.getHeight() == height_ && preOutV.getWidth() == width_); - CHECK(denoms.getHeight() == localGrad.getHeight() && - denoms.getWidth() == localGrad.getWidth()); - - // NOLINT // TODO: + CHECK(isContiguous()); + CHECK(localGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(preOutV.isContiguous()); + CHECK(localOutV.isContiguous()); + CHECK_EQ(getHeight(), localGrad.getHeight()); + CHECK_EQ(getWidth(), localGrad.getWidth()); + CHECK_EQ(getHeight(), denoms.getHeight()); + CHECK_EQ(getWidth(), denoms.getWidth()); + CHECK_EQ(getHeight(), preOutV.getHeight()); + CHECK_EQ(getWidth(), preOutV.getWidth()); + CHECK_EQ(getHeight(), localOutV.getHeight()); + CHECK_EQ(getWidth(), localOutV.getWidth()); + + size_t numSample = getHeight(); + size_t numCols = getWidth(); + size_t height = imgSizeH; + size_t width = imgSizeW; + CHECK(height * width * channels == numCols); + + // TODO(hedaoyuan) After commit TensorExpress code, + // Reconstruction this code to remove the temporary memory. + CpuMatrix tmp(1, height * width); + + const int start = -((int)sizeX) / 2; + const int end = (int)sizeX + start; + const real ratio = -(real)2 * scale * pow; + for (size_t i = 0; i < numSample; i++) { + CpuMatrix inputDiff( + this->subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix outDiff( + localGrad.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix input( + preOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix output( + localOutV.subMatrix(i, 1)->getData(), channels, height * width); + CpuMatrix subDen( + denoms.subMatrix(i, 1)->getData(), channels, height * width); + + for (int c = 0; c < (int)channels; c++) { + tmp.pow2(*subDen.subMatrix(c, 1), -pow); + inputDiff.subMatrix(c, 1) + ->addDotMul(tmp, *outDiff.subMatrix(c, 1), (real)1, (real)1); + for (int s = start; s < end; s++) { + if (c + s >= 0 && c + s < (int)channels) { + tmp.dotMul(*outDiff.subMatrix(c + s, 1), *output.subMatrix(c + s, 1)); + tmp.mulScalar(ratio); + tmp.dotDiv(tmp, *subDen.subMatrix(c + s, 1)); + tmp.dotMul(*input.subMatrix(c, 1)); + inputDiff.subMatrix(c, 1)->add(tmp); + } + } + } + } } /** diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 713792d82b3c569d26375780cc19fa0bd6cca391..5233a9af40155328d0a2695edf1805ff031ad0a7 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1261,6 +1261,121 @@ TEST(Matrix, MaxOutFwdBwd) { } } } +void testCrossMapNormalFwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + int width = imgSizeH * imgSizeW * channels; + MatrixPtr input = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denorms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr target = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr inputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denormsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr targetGpu = GpuMatrix::create(numSamples, width, false, true); + + input->randomizeUniform(); + target->randomizeUniform(); + inputGpu->copyFrom(*input); + targetGpu->copyFrom(*target); + + target->crossMapNormalFwd( + *input, imgSizeH, imgSizeW, *denorms, channels, sizeX, scale, pow); + targetGpu->crossMapNormalFwd( + *inputGpu, imgSizeH, imgSizeW, *denormsGpu, channels, sizeX, scale, pow); + + TensorCheckErr(*target, *targetGpu); + TensorCheckErr(*denorms, *denormsGpu); +} + +TEST(Matrix, crossMapNormalFwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalFwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} + +void testCrossMapNormalBwd( + int numSamples, int channels, int imgSizeH, int imgSizeW, int sizeX) { + float scale = 1.5; + float pow = 0.5; + size_t width = imgSizeH * imgSizeW * channels; + MatrixPtr localGrad = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr denoms = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr output = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr preOutV = CpuMatrix::create(numSamples, width, false, false); + MatrixPtr localOutV = CpuMatrix::create(numSamples, width, false, false); + + localGrad->randomizeUniform(); + denoms->randomizeUniform(); + preOutV->randomizeUniform(); + localOutV->randomizeUniform(); + output->randomizeUniform(); + denoms->add(0.01); + + MatrixPtr localGradGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr denomsGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr outputGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr preOutVGpu = GpuMatrix::create(numSamples, width, false, true); + MatrixPtr localOutVGpu = GpuMatrix::create(numSamples, width, false, true); + + localGradGpu->copyFrom(*localGrad); + denomsGpu->copyFrom(*denoms); + preOutVGpu->copyFrom(*preOutV); + localOutVGpu->copyFrom(*localOutV); + outputGpu->copyFrom(*output); + + output->crossMapNormalBwd(*localGrad, + *denoms, + *preOutV, + *localOutV, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + outputGpu->crossMapNormalBwd(*localGradGpu, + *denomsGpu, + *preOutVGpu, + *localOutVGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + TensorCheckErr(*output, *outputGpu); +} + +TEST(Matrix, crossMapNormalBwd) { + for (auto numSamples : {5, 32}) { + for (auto channels : {1, 5, 32}) { + for (auto imgSizeH : {5, 33, 100}) { + for (auto imgSizeW : {5, 32, 96}) { + for (auto sizeX : {1, 2, 3, 5, 7}) { + VLOG(3) << " numSamples=" << numSamples << " channels=" << channels + << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW + << " sizeX=" << sizeX; + testCrossMapNormalBwd( + numSamples, channels, imgSizeH, imgSizeW, sizeX); + } + } + } + } + } +} int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv);