From d99faf310865fe500083f0db53063e53efd2731f Mon Sep 17 00:00:00 2001
From: hedaoyuan
Date: Tue, 6 Jun 2017 12:51:30 +0800
Subject: [PATCH] Add the calculation implementation of
 GemmConvGradInputFunction.

---
 paddle/function/ConvOpTest.cpp   |   7 +-
 paddle/function/GemmConvOp.cpp   | 142 +++++++++++++++++++++++++++----
 paddle/function/GemmConvOp.h     |  18 ++++
 paddle/function/GemmConvOpGpu.cu |  93 ++++++++++++++++++++
 4 files changed, 242 insertions(+), 18 deletions(-)

diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp
index e2997df0128..2fa0b365465 100644
--- a/paddle/function/ConvOpTest.cpp
+++ b/paddle/function/ConvOpTest.cpp
@@ -78,12 +78,10 @@ public:
       test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
       test.run();
     } else if (type == BACKWARD_INPUT_TEST) {
-#if 0
       test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
       test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
       test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input));
       test.run();
-#endif
     } else if (type == BACKWARD_FILTER_TEST) {
       test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
       test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
@@ -111,6 +109,11 @@ TEST(Forward, GEMM2) {
       "GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST);
 }
 
+TEST(BackwardInput, GEMM) {
+  ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
+      "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", BACKWARD_INPUT_TEST);
+}
+
 TEST(BackwardFilter, GEMM) {
   ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
       "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST);
diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp
index 414c7a885b6..bb7bc647792 100644
--- a/paddle/function/GemmConvOp.cpp
+++ b/paddle/function/GemmConvOp.cpp
@@ -44,22 +44,62 @@ public:
     for (int c = 0; c < channelsCol; ++c) {
       int wOffset = c % filterWidth;
       int hOffset = (c / filterWidth) % filterHeight;
-      int c_im = c / filterHeight / filterWidth;
+      int c_im = c / filterWidth / filterHeight;
       for (int h = 0; h < outputHeight; ++h) {
         for (int w = 0; w < outputWidth; ++w) {
-          // no c_im*height to Exclude the channel number
-          int imgRowIdx = h * strideHeight + hOffset;
-          int imgColIdx = w * strideWidth + wOffset;
-          if ((imgRowIdx - paddingHeight) < 0 ||
-              (imgRowIdx - paddingHeight) >= inputHeight ||
-              (imgColIdx - paddingWidth) < 0 ||
-              (imgColIdx - paddingWidth) >= inputWidth) {
+          int imRowIdx = h * strideHeight + hOffset;
+          int imColIdx = w * strideWidth + wOffset;
+          if ((imRowIdx - paddingHeight) < 0 ||
+              (imRowIdx - paddingHeight) >= inputHeight ||
+              (imColIdx - paddingWidth) < 0 ||
+              (imColIdx - paddingWidth) >= inputWidth) {
             colData[(c * outputHeight + h) * outputWidth + w] = T(0);
           } else {
-            imgRowIdx += c_im * inputHeight - paddingHeight;
-            imgColIdx -= paddingWidth;
+            imRowIdx += c_im * inputHeight - paddingHeight;
+            imColIdx -= paddingWidth;
             colData[(c * outputHeight + h) * outputWidth + w] =
-                imData[imgRowIdx * inputWidth + imgColIdx];
+                imData[imRowIdx * inputWidth + imColIdx];
+          }
+        }
+      }
+    }
+  }
+};
+
+template <class T>
+class Col2ImFunctor<DEVICE_TYPE_CPU, T> {
+public:
+  void operator()(const T* colData,
+                  int inputChannels,
+                  int inputHeight,
+                  int inputWidth,
+                  int filterHeight,
+                  int filterWidth,
+                  int strideHeight,
+                  int strideWidth,
+                  int paddingHeight,
+                  int paddingWidth,
+                  int outputHeight,
+                  int outputWidth,
+                  T* imData) {
+    int channelsCol = inputChannels * filterHeight * filterWidth;
+
+    for (int c = 0; c < channelsCol; ++c) {
+      int wOffset = c % filterWidth;
+      int hOffset = (c / filterWidth) % filterHeight;
+      int c_im = c / filterWidth / filterHeight;
+      for (int h = 0; h < outputHeight; ++h) {
+        for (int w = 0; w < outputWidth; ++w) {
+          int imRowIdx = h * strideHeight + hOffset;
+          int imColIdx = w * strideWidth + wOffset;
+          if ((imRowIdx - paddingHeight) >= 0 &&
+              (imRowIdx - paddingHeight) < inputHeight &&
+              (imColIdx - paddingWidth) >= 0 &&
+              (imColIdx - paddingWidth) < inputWidth) {
+            imRowIdx += c_im * inputHeight - paddingHeight;
+            imColIdx -= paddingWidth;
+            imData[imRowIdx * inputWidth + imColIdx] +=
+                colData[(c * outputHeight + h) * outputWidth + w];
           }
         }
       }
@@ -171,10 +211,74 @@ public:
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(numInputs_, inputs.size());
     CHECK_EQ(numOutputs_, outputs.size());
-    const TensorShape& outputGrad = inputs[0].shape();
+    // CHECK_EQ(outputs[0].getArgType(), ADD_TO);
+    const TensorShape& output = inputs[0].shape();
     const TensorShape& filter = inputs[1].shape();
-    const TensorShape& inputGrad = outputs[0].shape();
-    check(inputGrad, filter, outputGrad);
+    const TensorShape& input = outputs[0].shape();
+    check(input, filter, output);
+
+    size_t batchSize = input[0];
+    size_t inputChannels = input[1];
+    size_t inputHeight = input[2];
+    size_t inputWidth = input[3];
+    size_t filterHeight = filter[2];
+    size_t filterWidth = filter[3];
+    size_t outputChannels = output[1];
+    size_t outputHeight = output[2];
+    size_t outputWidth = output[3];
+
+    real* outputGrad = inputs[0].data<real>();
+    real* filterData = inputs[1].data<real>();
+    real* inputGrad = outputs[0].data<real>();
+
+    size_t size = inputChannels / groups_ * filterHeight * filterWidth *
+                  outputHeight * outputWidth;
+    resizeBuffer<Device>(size);
+    real* colData = reinterpret_cast<real*>(memory_->getBuf());
+
+    Col2ImFunctor<Device, real> col2im;
+    GemmFunctor<Device, real> gemm;
+    size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
+    size_t outputOffset =
+        (outputChannels / groups_) * outputHeight * outputWidth;
+    size_t filterOffset = filter.getElements() / groups_;
+
+    for (size_t i = 0; i < batchSize; i++) {
+      for (size_t g = 0; g < groups_; g++) {
+        int K = outputChannels / groups_;
+        int N = outputHeight * outputWidth;
+        int M = inputChannels / groups_ * filterHeight * filterWidth;
+        gemm(CblasTrans,
+             CblasNoTrans,
+             M,
+             N,
+             K,
+             1.0f,
+             filterData + g * filterOffset,
+             M,
+             outputGrad + g * outputOffset,
+             N,
+             0.0f,
+             colData,
+             N);
+
+        col2im(colData,
+               inputChannels / groups_,
+               inputHeight,
+               inputWidth,
+               filterHeight,
+               filterWidth,
+               strideH(),
+               strideW(),
+               paddingH(),
+               paddingW(),
+               outputHeight,
+               outputWidth,
+               inputGrad + g * inputOffset);
+      }
+      inputGrad += inputChannels * inputHeight * inputWidth;
+      outputGrad += outputChannels * outputHeight * outputWidth;
+    }
   }
 };
 
@@ -191,12 +295,18 @@ public:
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(numInputs_, inputs.size());
     CHECK_EQ(numOutputs_, outputs.size());
-    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
     const TensorShape& output = inputs[0].shape();
     const TensorShape& input = inputs[1].shape();
     const TensorShape& filter = outputs[0].shape();
     check(input, filter, output);
+    real beta;
+    if (outputs[0].getArgType() == ADD_TO) {
+      beta = 1.0;
+    } else {
+      beta = 0.0;
+    }
+
     size_t batchSize = input[0];
     size_t inputChannels = input[1];
     size_t inputHeight = input[2];
     size_t inputWidth = input[3];
@@ -251,7 +361,7 @@ public:
              K,
             colData,
             K,
-             1.0f,
+             i == 0 ? beta : 1.0f,
             filterGrad + g * filterOffset,
             N);
       }
diff --git a/paddle/function/GemmConvOp.h b/paddle/function/GemmConvOp.h
index 652a64afba4..9f11cce597a 100644
--- a/paddle/function/GemmConvOp.h
+++ b/paddle/function/GemmConvOp.h
@@ -41,4 +41,22 @@ public:
                   T* colData);
 };
 
+template <DeviceType Device, class T>
+class Col2ImFunctor {
+public:
+  void operator()(const T* colData,
+                  int inputChannels,
+                  int inputHeight,
+                  int inputWidth,
+                  int filterHeight,
+                  int filterWidth,
+                  int strideHeight,
+                  int strideWidth,
+                  int paddingHeight,
+                  int paddingWidth,
+                  int outputHeight,
+                  int outputWidth,
+                  T* imData);
+};
+
 }  // namespace paddle
diff --git a/paddle/function/GemmConvOpGpu.cu b/paddle/function/GemmConvOpGpu.cu
index 06b9904261c..2a1795ff0fb 100644
--- a/paddle/function/GemmConvOpGpu.cu
+++ b/paddle/function/GemmConvOpGpu.cu
@@ -87,7 +87,100 @@ public:
   }
 };
 
+template <class T>
+__global__
+void col2im(size_t n, const T* data_col, size_t height,
+            size_t width, size_t channels,
+            size_t blockH, size_t blockW,
+            size_t strideH, size_t strideW,
+            size_t paddingH, size_t paddingW,
+            size_t height_col, size_t width_col,
+            T* data_im) {
+  size_t index =
+    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
+  if (index < n) {
+    T val = 0;
+    int w = int(index % width);
+    int h = int((index / width) % height);
+    int c = int(index / (width * height));
+    if ((w - (int)paddingW) >= 0 &&
+        (w - (int)paddingW) < (width - 2 * paddingW) &&
+        (h - (int)paddingH) >= 0 &&
+        (h - paddingH) < (height - 2 * paddingH)) {
+      // compute the start and end of the output
+      int w_col_start =
+        (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
+      int w_col_end =
+        min((int)(w / (int)strideW + 1), (int)(width_col));
+      int h_col_start =
+        (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
+      int h_col_end = min(int(h / strideH + 1), int(height_col));
+      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+          // the col location: [c * width * height + h_out, w_out]
+          int c_col = int(c * blockH * blockW) +
+            (h - h_col * (int)strideH) * (int)blockW +
+            (w - w_col * (int)strideW);
+          val += data_col[(c_col * height_col + h_col) * width_col + w_col];
+        }
+      }
+      h -= paddingH;
+      w -= paddingW;
+      data_im[c * ((width - 2 * paddingW) * (height - 2 * paddingH)) +
+              h * (width - 2 * paddingW) + w] += val;
+    }
+  }
+}
+
+template <class T>
+class Col2ImFunctor<DEVICE_TYPE_GPU, T> {
+public:
+  void operator()(const T* colData,
+                  int inputChannels,
+                  int inputHeight,
+                  int inputWidth,
+                  int filterHeight,
+                  int filterWidth,
+                  int strideHeight,
+                  int strideWidth,
+                  int paddingHeight,
+                  int paddingWidth,
+                  int outputHeight,
+                  int outputWidth,
+                  T* imData) {
+    size_t numKernels = inputChannels * (inputHeight + 2 * paddingHeight) *
+                        (inputWidth + 2 * paddingWidth);
+
+    size_t blocks = (numKernels + 1024 - 1) / 1024;
+    size_t blockX = 512;
+    size_t blockY = (blocks + 512 - 1) / 512;
+    dim3 threads(1024, 1);
+    dim3 grid(blockX, blockY);
+
+    // To avoid involving atomic operations, we will launch one kernel per
+    // bottom dimension, and then in the kernel add up the top dimensions.
+    col2im<T><<< grid, threads, 0, STREAM_DEFAULT >>>
+        (numKernels,
+         colData,
+         inputHeight + 2 * paddingHeight,
+         inputWidth + 2 * paddingWidth,
+         inputChannels,
+         filterHeight,
+         filterWidth,
+         strideHeight,
+         strideWidth,
+         paddingHeight,
+         paddingWidth,
+         outputHeight,
+         outputWidth,
+         imData);
+    CHECK_SYNC("Col2ImFunctor GPU failed");
+  }
+};
+
 template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
 template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
+template class Col2ImFunctor<DEVICE_TYPE_GPU, float>;
+template class Col2ImFunctor<DEVICE_TYPE_GPU, double>;
 
 }  // namespace paddle
-- 
GitLab
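Editor's note (illustration only, not part of the patch): the backward-input pass added above computes colData = filter^T * outputGrad with GemmFunctor and then scatter-adds colData back into the image with the new Col2ImFunctor. The sketch below reproduces that data flow in plain C++ for a single image, a single group, stride 1 and no padding; the helper names (gemmTransA, the tiny 3x3 input / 2x2 filter shapes) are ad-hoc, and a naive triple loop stands in for GemmFunctor.

#include <cstdio>
#include <vector>

// colGrad = filter^T * outputGrad, with A stored as K x M (row-major) and
// B as K x N, matching gemm(CblasTrans, CblasNoTrans, M, N, K, ...) above.
static void gemmTransA(const float* a, const float* b, float* c,
                       int m, int n, int k) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float sum = 0.f;
      for (int p = 0; p < k; ++p) sum += a[p * m + i] * b[p * n + j];
      c[i * n + j] = sum;
    }
}

// Scatter-add each column entry back to its image location, mirroring the
// CPU Col2ImFunctor added in GemmConvOp.cpp.
static void col2im(const float* col, int inC, int inH, int inW,
                   int kH, int kW, int strideH, int strideW,
                   int padH, int padW, int outH, int outW, float* im) {
  int channelsCol = inC * kH * kW;
  for (int c = 0; c < channelsCol; ++c) {
    int wOff = c % kW;
    int hOff = (c / kW) % kH;
    int cIm = c / kW / kH;
    for (int h = 0; h < outH; ++h)
      for (int w = 0; w < outW; ++w) {
        int r = h * strideH + hOff - padH;
        int s = w * strideW + wOff - padW;
        if (r >= 0 && r < inH && s >= 0 && s < inW)
          im[(cIm * inH + r) * inW + s] += col[(c * outH + h) * outW + w];
      }
  }
}

int main() {
  // Tiny example: 1 input channel, 3x3 input, one 2x2 filter, stride 1, pad 0.
  const int inC = 1, inH = 3, inW = 3, outC = 1, kH = 2, kW = 2;
  const int outH = 2, outW = 2;
  std::vector<float> filter(outC * inC * kH * kW, 1.f);    // all-ones filter
  std::vector<float> outputGrad(outC * outH * outW, 1.f);  // all-ones gradient
  std::vector<float> colGrad(inC * kH * kW * outH * outW, 0.f);
  std::vector<float> inputGrad(inC * inH * inW, 0.f);

  const int M = inC * kH * kW, N = outH * outW, K = outC;
  gemmTransA(filter.data(), outputGrad.data(), colGrad.data(), M, N, K);
  col2im(colGrad.data(), inC, inH, inW, kH, kW, 1, 1, 0, 0, outH, outW,
         inputGrad.data());

  // Each input pixel accumulates one contribution per filter window covering
  // it, so the center of the 3x3 input receives 4.
  for (int i = 0; i < inH * inW; ++i) printf("%.0f ", inputGrad[i]);
  printf("\n");  // expected: 1 2 1 2 4 2 1 2 1
  return 0;
}

Note the division of labor in the patch itself: the CPU Col2ImFunctor scatter-adds every column entry into imData, while the GPU col2im kernel inverts the loops, with one thread per (padded) image pixel gathering the column entries that touch it, which is what the "avoid involving atomic operations" comment refers to.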