From f8ef8c174c442f14662a94e59fcda6587498c8a5 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 13 Jun 2017 21:07:20 +0800 Subject: [PATCH] Add the GPU version implementation of ImageExpandGrad function. --- paddle/function/Im2ColOpGpu.cu | 107 +++++++++++++++++---- paddle/function/ImageExpandOp.cpp | 1 + paddle/gserver/layers/BlockExpandLayer.cpp | 33 +++---- 3 files changed, 103 insertions(+), 38 deletions(-) diff --git a/paddle/function/Im2ColOpGpu.cu b/paddle/function/Im2ColOpGpu.cu index 1dac2585db7..bddd8ffc7c0 100644 --- a/paddle/function/Im2ColOpGpu.cu +++ b/paddle/function/Im2ColOpGpu.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "Im2Col.h" +#include "hl_device_functions.cuh" namespace paddle { @@ -25,30 +26,29 @@ void im2colOCF(const T* imData, T* colData, int strideHeight, int strideWidth, int paddingHeight, int paddingWidth, int outputHeight, int outputWidth) { - int idx = threadIdx.x; - int idy = threadIdx.y; int swId = blockIdx.x; int shId = blockIdx.y; - for (int channelId = threadIdx.z; channelId < inputChannels; channelId += blockDim.z) { - int widthOffset = idx + swId * strideWidth - paddingWidth; - int heightOffset = idy + shId * strideHeight - paddingHeight; - int imOffset = widthOffset + heightOffset * inputWidth - + channelId * inputHeight * inputWidth; - - int colOffset = idx + idy * filterWidth - + channelId * filterHeight * filterWidth - + (shId * outputWidth + swId) - * (inputChannels * filterHeight * filterWidth); - - if (idx < filterWidth && idy < filterHeight) { - if (heightOffset >= inputHeight || heightOffset < 0 || - widthOffset >= inputWidth || widthOffset < 0) { - colData[colOffset] = T(0); - } else { - colData[colOffset] = imData[imOffset]; + for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { + for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { + int widthOffset = idx + swId * strideWidth - paddingWidth; + int heightOffset = idy + shId * strideHeight - paddingHeight; + int imOffset = widthOffset + heightOffset * inputWidth + + channelId * inputHeight * inputWidth; + + int colOffset = idx + idy * filterWidth + + channelId * filterHeight * filterWidth + + (shId * outputWidth + swId) + * (inputChannels * filterHeight * filterWidth); + + if (heightOffset >= inputHeight || heightOffset < 0 || + widthOffset >= inputWidth || widthOffset < 0) { + colData[colOffset] = T(0); + } else { + colData[colOffset] = imData[imOffset]; + } } } } @@ -105,6 +105,41 @@ public: } }; +template +__global__ +void col2imOCF(T* imData, const T* colData, + int inputChannels, + int inputHeight, int inputWidth, + int filterHeight, int filterWidth, + int strideHeight, int strideWidth, + int paddingHeight, int paddingWidth, + int outputHeight, int outputWidth) { + int swId = blockIdx.x; + int shId = blockIdx.y; + for (int channelId = threadIdx.z; + channelId < inputChannels; + channelId += blockDim.z) { + for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) { + for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) { + int widthOffset = idx + swId * strideWidth - paddingWidth; + int heightOffset = idy + shId * strideHeight - paddingHeight; + int imOffset = widthOffset + heightOffset * inputWidth + + channelId * inputHeight * inputWidth; + + int colOffset = idx + idy * filterWidth + + channelId * filterHeight * filterWidth + + (shId * outputWidth + swId) + * (inputChannels * filterHeight * filterWidth); + + if (heightOffset >= 0 && heightOffset 
< inputHeight && + widthOffset >= 0 && widthOffset < inputWidth) { + paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]); + } + } + } + } +} + /* * imShape = [inputChannels, inputHeight, inputWidth] * colShape = @@ -121,10 +156,44 @@ public: int strideWidth, int paddingHeight, int paddingWidth) { + int inputChannels = imShape[0]; + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[3]; + int filterWidth = colShape[4]; + int outputHeight = colShape[0]; + int outputWidth = colShape[1]; + + int blockDimX = 0; + int blockDimY = 0; + if (filterHeight <= 4 && filterWidth <= 4) { + blockDimX = 4; + blockDimY = 4; + } else if (filterHeight <= 8 && filterWidth <= 8) { + blockDimX = 8; + blockDimY = 8; + } else if (filterHeight <= 16 && filterWidth <= 16) { + blockDimX = 16; + blockDimY = 16; + } else { + blockDimX = 32; + blockDimY = 32; + } + + int blockDimZ = 1024 / blockDimX / blockDimY; + dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels)); + dim3 grid(outputWidth, outputHeight); + col2imOCF<<< grid, threads, 0, STREAM_DEFAULT >>> + (imData, colData, inputChannels, inputHeight, inputWidth, + filterHeight, filterWidth, strideHeight, strideWidth, + paddingHeight, paddingWidth, outputHeight, outputWidth); + CHECK_SYNC("Col2ImFunctor GPU failed"); } }; template class Im2ColFunctor; template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; } // namespace paddle diff --git a/paddle/function/ImageExpandOp.cpp b/paddle/function/ImageExpandOp.cpp index fe4c8fefcf5..f227f6d0e10 100644 --- a/paddle/function/ImageExpandOp.cpp +++ b/paddle/function/ImageExpandOp.cpp @@ -293,6 +293,7 @@ REGISTER_TYPED_FUNC(ImageExpand, CPU, ImageExpandForward); REGISTER_TYPED_FUNC(ImageExpandGrad, CPU, ImageExpandBackward); #ifndef PADDLE_ONLY_CPU REGISTER_TYPED_FUNC(ImageExpand, GPU, ImageExpandForward); +REGISTER_TYPED_FUNC(ImageExpandGrad, GPU, ImageExpandBackward); #endif } // namespace paddle diff --git a/paddle/gserver/layers/BlockExpandLayer.cpp b/paddle/gserver/layers/BlockExpandLayer.cpp index 1889b347c2d..a5e644a4ae3 100644 --- a/paddle/gserver/layers/BlockExpandLayer.cpp +++ b/paddle/gserver/layers/BlockExpandLayer.cpp @@ -46,14 +46,12 @@ bool BlockExpandLayer::init(const LayerMap& layerMap, .set("strides", strides) .set("paddings", paddings) .set("blocks", blocks)); - if (!useGpu_) { - createFunction(backward_, - "ImageExpandGrad", - FuncConfig() - .set("strides", strides) - .set("paddings", paddings) - .set("blocks", blocks)); - } + createFunction(backward_, + "ImageExpandGrad", + FuncConfig() + .set("strides", strides) + .set("paddings", paddings) + .set("blocks", blocks)); return true; } @@ -110,14 +108,16 @@ void BlockExpandLayer::forward(PassType passType) { } void BlockExpandLayer::backward(const UpdateCallback& callback) { - size_t blockNum = outputH_ * outputW_; - size_t blockSize = blockH_ * blockW_ * channels_; /* Calculate the input layers error */ - MatrixPtr preGrad = inputLayers_[0]->getOutputGrad(); - if (!preGrad) { - return; + if (getInputGrad(0)) { + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getOutputGrad(), outputShape_); + outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO); + backward_[0]->calc(inputs, outputs); } +#if 0 if (useGpu_) { MatrixPtr grad = getOutputGrad(); MatrixPtr gradTrans = Matrix::create(blockSize, blockNum, false, useGpu_); @@ -155,13 +155,8 @@ void BlockExpandLayer::backward(const UpdateCallback& callback) { 1.0, 1.0); } - } else { - BufferArgs inputs; 
- BufferArgs outputs; - inputs.addArg(*getOutputGrad(), outputShape_); - outputs.addArg(*getInputGrad(0), inputShape_, ADD_TO); - backward_[0]->calc(inputs, outputs); } +#endif } } // namespace paddle -- GitLab
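
Note (not part of the patch): for readers who want to try the new backward kernel outside of Paddle's build, below is a minimal standalone sketch that mirrors the index math and the launch heuristic of col2imOCF added above. The names col2imOCFSketch and launchCol2ImOCFSketch are illustrative stand-ins; plain CUDA atomicAdd on float and an explicit cudaStream_t replace the patch's paddle::paddleAtomicAdd and STREAM_DEFAULT, and the CHECK_SYNC-style error handling is omitted. The atomic accumulation is the key point: whenever the stride is smaller than the filter size, neighboring output windows overlap in the image, so several threads may add into the same image element.

// col2im_ocf_sketch.cu -- illustrative only; mirrors the index math of
// col2imOCF in Im2ColOpGpu.cu, using plain CUDA atomicAdd for float.
#include <cuda_runtime.h>
#include <algorithm>

__global__ void col2imOCFSketch(float* imData, const float* colData,
                                int inputChannels,
                                int inputHeight, int inputWidth,
                                int filterHeight, int filterWidth,
                                int strideHeight, int strideWidth,
                                int paddingHeight, int paddingWidth,
                                int outputHeight, int outputWidth) {
  int swId = blockIdx.x;  // output column handled by this block
  int shId = blockIdx.y;  // output row handled by this block
  for (int channelId = threadIdx.z; channelId < inputChannels;
       channelId += blockDim.z) {
    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
        int widthOffset = idx + swId * strideWidth - paddingWidth;
        int heightOffset = idy + shId * strideHeight - paddingHeight;
        int imOffset = widthOffset + heightOffset * inputWidth +
                       channelId * inputHeight * inputWidth;
        int colOffset = idx + idy * filterWidth +
                        channelId * filterHeight * filterWidth +
                        (shId * outputWidth + swId) *
                            (inputChannels * filterHeight * filterWidth);
        // Overlapping filter windows may scatter into the same image
        // element, so the accumulation has to be atomic.
        if (heightOffset >= 0 && heightOffset < inputHeight &&
            widthOffset >= 0 && widthOffset < inputWidth) {
          atomicAdd(imData + imOffset, colData[colOffset]);
        }
      }
    }
  }
}

// Host-side launch mirroring the block-size heuristic in the patch:
// the thread block covers (up to) one filter window, the remaining
// threads are spread over channels, and the grid covers the output map.
void launchCol2ImOCFSketch(float* imData, const float* colData,
                           int inputChannels, int inputHeight, int inputWidth,
                           int filterHeight, int filterWidth,
                           int strideHeight, int strideWidth,
                           int paddingHeight, int paddingWidth,
                           int outputHeight, int outputWidth,
                           cudaStream_t stream) {
  int blockDimX = 32, blockDimY = 32;
  if (filterHeight <= 4 && filterWidth <= 4) {
    blockDimX = blockDimY = 4;
  } else if (filterHeight <= 8 && filterWidth <= 8) {
    blockDimX = blockDimY = 8;
  } else if (filterHeight <= 16 && filterWidth <= 16) {
    blockDimX = blockDimY = 16;
  }
  int blockDimZ = 1024 / blockDimX / blockDimY;
  dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
  dim3 grid(outputWidth, outputHeight);
  col2imOCFSketch<<<grid, threads, 0, stream>>>(
      imData, colData, inputChannels, inputHeight, inputWidth,
      filterHeight, filterWidth, strideHeight, strideWidth,
      paddingHeight, paddingWidth, outputHeight, outputWidth);
}

Because each block owns exactly one output position (swId, shId), no synchronization beyond the per-element atomics is needed; this matches the structure of the kernel added in Im2ColOpGpu.cu, differing only in the hypothetical names and the atomic/stream helpers noted above.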