diff --git a/paddle/math/cross_map_normal_op.cpp b/paddle/math/cross_map_normal_op.cpp index 3eb51b5998fc4495a1d71b6aa677ec912f4cfa15..be242926aff161a36f61c1101a80272ccbace168 100644 --- a/paddle/math/cross_map_normal_op.cpp +++ b/paddle/math/cross_map_normal_op.cpp @@ -17,15 +17,16 @@ limitations under the License. */ namespace paddle { // NCHW -void CrossMapNormal::operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { +template <> +void CrossMapNormal::operator()(CpuMatrix& outputs, + CpuMatrix& denoms, + CpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { CHECK(outputs.isContiguous()); CHECK(inputs.isContiguous()); CHECK(denoms.isContiguous()); @@ -58,17 +59,18 @@ void CrossMapNormal::operator()(CpuMatrix& outputs, outputs = inputs * denoms.pow(-pow); } -void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, - size_t channels, - size_t imgSizeH, - size_t imgSizeW, - size_t sizeX, - real scale, - real pow) { +template <> +void CrossMapNormalGrad::operator()(CpuMatrix& inputsGrad, + CpuMatrix& inputsValue, + CpuMatrix& outputsGrad, + CpuMatrix& outputsValue, + CpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { CHECK(inputsGrad.isContiguous()); CHECK(outputsGrad.isContiguous()); CHECK(denoms.isContiguous()); diff --git a/paddle/math/cross_map_normal_op.h b/paddle/math/cross_map_normal_op.h index 2f996072528a0a48a14d787f59c56dd1a8d9308a..c2bb95f6b11fba35c396c9f5a04b4603991a4dfb 100644 --- a/paddle/math/cross_map_normal_op.h +++ b/paddle/math/cross_map_normal_op.h @@ -18,10 +18,30 @@ limitations under the License. */ namespace paddle { +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2, +}; + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +template struct CrossMapNormal { - void operator()(CpuMatrix& outputs, - CpuMatrix& denoms, - CpuMatrix& inputs, + void operator()(typename MatrixT::type& outputs, + typename MatrixT::type& denoms, + typename MatrixT::type& inputs, size_t channels, size_t imgSizeH, size_t imgSizeW, @@ -30,12 +50,13 @@ struct CrossMapNormal { real pow); }; +template struct CrossMapNormalGrad { - void operator()(CpuMatrix& inputsGrad, - CpuMatrix& inputsValue, - CpuMatrix& outputsGrad, - CpuMatrix& outputsValue, - CpuMatrix& denoms, + void operator()(typename MatrixT::type& inputsGrad, + typename MatrixT::type& inputsValue, + typename MatrixT::type& outputsGrad, + typename MatrixT::type& outputsValue, + typename MatrixT::type& denoms, size_t channels, size_t imgSizeH, size_t imgSizeW, diff --git a/paddle/math/cross_map_normal_op_gpu.cu b/paddle/math/cross_map_normal_op_gpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..0a154d97ac02f698f32168c4cea65062d689e9b1 --- /dev/null +++ b/paddle/math/cross_map_normal_op_gpu.cu @@ -0,0 +1,194 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "hl_base.h" +#include "cross_map_normal_op.h" + +namespace paddle { + +__global__ void KeCMRNormFillScale(size_t imageSize, const real* in, + real* scale, size_t channels, + size_t height, size_t width, size_t size, + real alpha) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + + in += offset; + scale += offset; + const int step = height * width; + const int pre_pad = (size - 1) / 2; + const int post_pad = size - pre_pad - 1; + + real accum = 0; + int index = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += in[index * step] * in[index * step]; + } + if (index >= size) { + accum -= in[(index - size) * step] * in[(index - size) * step]; + } + if (index >= post_pad) { + scale[(index - post_pad) * step] = 1. + accum * alpha; + } + ++index; + } + } +} + +__global__ void KeCMRNormOutput(size_t inputSize, const real* in, + const real* scale, real negative_beta, + real* out) { + const int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < inputSize) { + out[index] = in[index] * pow(scale[index], negative_beta); + } +} + +template <> +void CrossMapNormal::operator()(GpuMatrix& outputs, + GpuMatrix& denoms, + GpuMatrix& inputs, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(outputs.isContiguous()); + CHECK(inputs.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK_EQ(outputs.getHeight(), inputs.getHeight()); + CHECK_EQ(outputs.getWidth(), inputs.getWidth()); + CHECK_EQ(outputs.getHeight(), denoms.getHeight()); + CHECK_EQ(outputs.getWidth(), denoms.getWidth()); + + size_t numSample = inputs.getHeight(); + size_t numCols = inputs.getWidth(); + CHECK(imgSizeH * imgSizeW * channels == numCols); + + real* inputsData = inputs.getData(); + real* denomsData = denoms.getData(); + real* outputsData = outputs.getData(); + + size_t imageSize = numSample * imgSizeH * imgSizeW; + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormFillScale<<>> + (imageSize, inputsData, denomsData, + channels, imgSizeH, imgSizeW, sizeX, scale); + + size_t inputSize = numSample * imgSizeH * imgSizeW *channels; + blockSize = 1024; + gridSize = (inputSize + 1024 - 1) / 1024; + KeCMRNormOutput<<>> + (inputSize, inputsData, denomsData, -pow, outputsData); + + CHECK_SYNC("CrossMapNormalFwd"); +} + +__global__ void KeCMRNormDiff(size_t imageSize, const real* bottom_data, + const real* top_data, const real* scale, + const real* top_diff, size_t channels, + size_t height, size_t width, size_t size, + real negative_beta, real cache_ratio, + real* bottom_diff ) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < imageSize) { + const int w = idx % width; + const int h = (idx / width) % height; + const int n = idx / width / height; + const int offset = (n * channels * height + h) * width + w; + bottom_data += offset; + top_data += offset; + scale += offset; + top_diff += offset; + bottom_diff += offset; + + const int step = height * width; + const int pre_pad = size - (size + 1) / 2; + const int post_pad = size - pre_pad - 1; + + int index = 0; + real accum = 0; + while (index < channels + post_pad) { + if (index < channels) { + accum += top_diff[index * step] * top_data[index * step] / + scale[index * step]; + } + if (index >= size) { + accum -= top_diff[(index - size) * step] * + top_data[(index - size) * step] / scale[(index - size) * step]; + } + if (index >= post_pad) { + bottom_diff[(index - post_pad) * step] += + top_diff[(index - post_pad) * step] * + pow(scale[(index - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(index - post_pad) * step] * accum; + } + ++index; + } + } +} + +template <> +void CrossMapNormalGrad::operator()(GpuMatrix& inputsGrad, + GpuMatrix& inputsValue, + GpuMatrix& outputsGrad, + GpuMatrix& outputsValue, + GpuMatrix& denoms, + size_t channels, + size_t imgSizeH, + size_t imgSizeW, + size_t sizeX, + real scale, + real pow) { + CHECK(inputsGrad.isContiguous()); + CHECK(outputsGrad.isContiguous()); + CHECK(denoms.isContiguous()); + CHECK(inputsValue.isContiguous()); + CHECK(outputsValue.isContiguous()); + CHECK_EQ(inputsGrad.getHeight(), outputsGrad.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsGrad.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), denoms.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), denoms.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), inputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), inputsValue.getWidth()); + CHECK_EQ(inputsGrad.getHeight(), outputsValue.getHeight()); + CHECK_EQ(inputsGrad.getWidth(), outputsValue.getWidth()); + + size_t numSample = inputsGrad.getHeight(); + size_t numCols = inputsGrad.getWidth(); + CHECK(imgSizeH * imgSizeW * channels == numCols); + + size_t imageSize = numSample * imgSizeH * imgSizeW; + real* inputsGradData = inputsGrad.getData(); + real* inputsData = inputsValue.getData(); + real* denomsData = denoms.getData(); + real* outputsGradData = outputsGrad.getData(); + real* outputsData = outputsValue.getData(); + + int blockSize = 1024; + int gridSize = (imageSize + 1024 - 1) / 1024; + KeCMRNormDiff <<>> + (imageSize, inputsData, outputsData, denomsData, outputsGradData, channels, + imgSizeH, imgSizeW, sizeX, -pow, 2.0f * pow * scale, inputsGradData); + CHECK_SYNC("KeCMRNormDiff"); +} + +} // namespace paddle diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 9bb1fdbdab83a69037211f2677b4bfe3395a04a5..8d7a4fb94d0a1df1849c78d6c192504db7fc121c 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1280,11 +1280,25 @@ void testCrossMapNormalFwd( inputsGpu.copyFrom(inputs); outputsGpu.copyFrom(outputs); - CrossMapNormal cross; - cross( + CrossMapNormal cpuCross; + cpuCross( outputs, denoms, inputs, channels, imgSizeH, imgSizeW, sizeX, scale, pow); + + CrossMapNormal gpuCross; + gpuCross(outputsGpu, + denomsGpu, + inputsGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + +#if 0 outputsGpu.crossMapNormalFwd( inputsGpu, imgSizeH, imgSizeW, denomsGpu, channels, sizeX, scale, pow); +#endif TensorCheckErr(outputs, outputsGpu); TensorCheckErr(denoms, denomsGpu); @@ -1339,29 +1353,31 @@ void testCrossMapNormalBwd( outputsValueGpu.copyFrom(outputsValue); inputsGradGpu.copyFrom(inputsGrad); - CrossMapNormalGrad cross; - cross(inputsGrad, - inputsValue, - outputsGrad, - outputsValue, - denoms, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); - - inputsGradGpu.crossMapNormalBwd(outputsGradGpu, - denomsGpu, - inputsValueGpu, - outputsValueGpu, - channels, - imgSizeH, - imgSizeW, - sizeX, - scale, - pow); + CrossMapNormalGrad cpuCross; + cpuCross(inputsGrad, + inputsValue, + outputsGrad, + outputsValue, + denoms, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); + + CrossMapNormalGrad gpuCross; + gpuCross(inputsGradGpu, + inputsValueGpu, + outputsGradGpu, + outputsValueGpu, + denomsGpu, + channels, + imgSizeH, + imgSizeW, + sizeX, + scale, + pow); TensorCheckErr(inputsGrad, inputsGradGpu); }