diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp index d1430239bc17a2ae2593f6b2b4e227d81df4f22a..9180c19b118251b8312a32355fcd1917133a9d10 100644 --- a/paddle/function/DepthwiseConvOp.cpp +++ b/paddle/function/DepthwiseConvOp.cpp @@ -99,8 +99,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); @@ -162,8 +161,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& output = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& input = outputs[0].shape(); @@ -225,8 +223,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& output = inputs[0].shape(); const TensorShape& input = inputs[1].shape(); const TensorShape& filter = outputs[0].shape(); diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu index 51aed9ffcf079496fd6dc616ef47e6c0178bd8c3..bb7b97df5a6d353d5ea27fe6418af756c7bdb39a 100644 --- a/paddle/function/DepthwiseConvOpGpu.cu +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -20,58 +20,58 @@ namespace paddle { // CUDA kernel to compute the depthwise convolution forward pass template -__global__ +__global__ void ConvolutionDepthwiseForward(const int nthreads, const T* const inputData, const T* const filterData, const int batchSize, const int outputChannels, const int outputHeight, - const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, - const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH, - const int strideW, const int paddingH, const int paddingW, - T* const outputData) { + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const outputData) { int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - - if(index < nthreads) { + + if (index < nthreads) { const int batch = index / outputChannels / outputHeight / outputWidth; const int c_out = (index / outputHeight / outputWidth) % outputChannels; const int h_out = (index / outputWidth) % outputHeight; const int w_out = index % outputWidth; - const int c_in = c_out / filterMultiplier; + const int c_in = c_out / filterMultiplier; const T* weight = filterData + c_out * filterHeight * filterWidth; T value = 0; const int h_in_start = -paddingH + h_out * strideH; const int w_in_start = -paddingW + w_out * strideW; const int h_in_end = -paddingH + h_out * strideH + filterHeight - 1; const int w_in_end = -paddingW + w_out * strideW + filterWidth - 1; - if ((h_in_start >= 0) && (h_in_end < inputHeight) - &&(w_in_start >= 0) && (w_in_end < inputWidth)) { + if ((h_in_start >= 0) && (h_in_end < inputHeight) + && (w_in_start >= 0) && (w_in_end < inputWidth)) { for (int kh = 0; kh < filterHeight; ++kh) { for (int kw = 0; kw < filterWidth; ++kw) { const int h_in = -paddingH + h_out * strideH + kh; const int w_in = -paddingW + w_out * strideW + kw; - const int offset = ((batch * inputChannels + c_in) * inputHeight + h_in) - * inputWidth + w_in; + const int offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; value += (*weight) * inputData[offset]; ++weight; - } - } - }else{ + } + } + } else { for (int kh = 0; kh < filterHeight; ++kh) { for (int kw = 0; kw < filterWidth; ++kw) { const int h_in = -paddingH + h_out * strideH + kh; const int w_in = -paddingW + w_out * strideW + kw; if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) { - const int offset = ((batch * inputChannels + c_in) * inputHeight + h_in) - * inputWidth + w_in; + const int offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; value += (*weight) * inputData[offset]; } ++weight; } } - } + } outputData[index] = value; } } @@ -82,21 +82,21 @@ __global__ void ConvolutionDepthwiseInputBackward(const int nthreads, const T* const top_diff, const T* const weight_data, const int num, const int outputChannels, const int outputHeight, - const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, - const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH, - const int strideW, const int paddingH, const int paddingW, - T* const bottom_diff) { + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const bottom_diff) { int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; - if(index < nthreads) { + if (index < nthreads) { const int batch = index / inputChannels / inputHeight / inputWidth; const int c_in = (index / inputHeight / inputWidth) % inputChannels; const int h_in = (index / inputWidth) % inputHeight; const int w_in = index % inputWidth; - const int c_out_start = c_in * filterMultiplier; + const int c_out_start = c_in * filterMultiplier; T value = 0; - for(int c_out = c_out_start; c_out < c_out_start + filterMultiplier; c_out ++){ - //weight bixu c_out + for (int c_out = c_out_start; + c_out < c_out_start + filterMultiplier; c_out ++) { const T* weight = weight_data + c_out * filterHeight * filterWidth; for (int kh = 0; kh < filterHeight; ++kh) { for (int kw = 0; kw < filterWidth; ++kw) { @@ -105,11 +105,12 @@ void ConvolutionDepthwiseInputBackward(const int nthreads, if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { const int h_out = h_out_s / strideH; const int w_out = w_out_s / strideW; - // TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize + // TODO(zhaolong) : the 'if' affect the effectiveness, + // it needs to optimize if ((h_out >= 0) && (h_out < outputHeight) && (w_out >= 0) && (w_out < outputWidth)) { - const int offset = ((batch * outputChannels + c_out) * outputHeight + h_out) - * outputWidth + w_out; + const int offset = ((batch * outputChannels + c_out) + * outputHeight + h_out) * outputWidth + w_out; value += (*weight) * top_diff[offset]; } } @@ -127,10 +128,10 @@ __global__ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, const T* const top_diff, const T* const inputData, const int num, const int outputChannels, const int outputHeight, - const int outputWidth, const int inputChannels, const int inputHeight, const int inputWidth, - const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH, - const int strideW, const int paddingH, const int paddingW, - T* const buffer_data) { + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const buffer_data) { int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if (index < nthreads) { @@ -143,13 +144,14 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, const int w_in = -paddingW + w_out * strideW + kw; if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) { - const int c_out = index / filterHeight / filterWidth / outputHeight / outputWidth; - const int c_in = c_out / filterMultiplier; + const int c_out = index / + (filterHeight * filterWidth * outputHeight * outputWidth); + const int c_in = c_out / filterMultiplier; const int batch = num_i; - const int top_offset = ((batch * outputChannels + c_out) * outputHeight + h_out) - * outputWidth + w_out; - const int bottom_offset = ((batch * inputChannels + c_in) * inputHeight + h_in) - * inputWidth + w_in; + const int top_offset = ((batch * outputChannels + c_out) * + outputHeight + h_out) * outputWidth + w_out; + const int bottom_offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; } else { buffer_data[index] = 0; @@ -160,13 +162,13 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, template class DepthwiseConvFunctor{ public: - void operator()(const T* inputData, + void operator()(const T* inputData, const T* filterData, int batchSize, int outputChannels, int outputHeight, int outputWidth, - int inputChannels, + int inputChannels, int inputHeight, int inputWidth, int filterMultiplier, @@ -177,7 +179,6 @@ public: int paddingH, int paddingW, T* outputData){ - int outputSize = batchSize * outputChannels * outputHeight * outputWidth; size_t blocks = (outputSize + 1024 -1) / 1024; @@ -188,14 +189,14 @@ public: ConvolutionDepthwiseForward <<< grid, threads, 0, STREAM_DEFAULT >>>( - outputSize, - inputData, + outputSize, + inputData, filterData, batchSize, outputChannels, outputHeight, outputWidth, - inputChannels, + inputChannels, inputHeight, inputWidth, filterMultiplier, @@ -229,7 +230,6 @@ public: int paddingH, int paddingW, T* inputGrad){ - int inputSize = batchSize * inputChannels * inputHeight * inputWidth; size_t blocks = (inputSize + 1024 -1) / 1024; @@ -249,7 +249,7 @@ public: outputChannels, outputHeight, outputWidth, - inputChannels, + inputChannels, inputHeight, inputWidth, filterMultiplier, @@ -284,17 +284,18 @@ public: int paddingW, T* colData, T* filterGrad){ - - int colDataSize = outputChannels * filterHeight * filterWidth * outputHeight * outputWidth; + int colDataSize = outputChannels * filterHeight * filterWidth + * outputHeight * outputWidth; size_t blocks = (colDataSize + 1024 -1) / 1024; size_t blockX = 512; size_t blockY = (blocks+512-1)/512; dim3 threads(1024, 1); dim3 grid(blockX, blockY); - BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth, 1, filterGrad, false, true); + BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth, + 1, filterGrad, false, true); - for(int i = 0; i < batchSize; i++) { + for (int i = 0; i < batchSize; i++) { ConvolutionDepthwiseFilterBackward <<< grid, threads, 0, STREAM_DEFAULT >>>( i, @@ -305,24 +306,23 @@ public: outputChannels, outputHeight, outputWidth, - inputChannels, + inputChannels, inputHeight, inputWidth, - filterMultiplier, + filterMultiplier, filterHeight, filterWidth, strideH, strideW, paddingH, paddingW, - colData - ); + colData); int K = outputHeight * outputWidth; int M = colDataSize / K; BaseMatrix colMatrix(M, K, colData, false, true); - filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); - } + filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); + } } }; @@ -330,7 +330,7 @@ public: template class DepthwiseConvGradInputFunctor; template class DepthwiseConvFunctor; template class DepthwiseConvGradFilterFunctor; -#else +#else template class DepthwiseConvGradInputFunctor; template class DepthwiseConvFunctor; template class DepthwiseConvGradFilterFunctor;