提交 11588b36 编写于 作者: X xzl

support inputchannels != outputchannels of depthwiseconv

上级 02e04b44
...@@ -30,6 +30,7 @@ public: ...@@ -30,6 +30,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -53,6 +54,7 @@ public: ...@@ -53,6 +54,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -75,6 +77,7 @@ public: ...@@ -75,6 +77,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -122,6 +125,7 @@ public: ...@@ -122,6 +125,7 @@ public:
size_t outputChannels = output[1]; size_t outputChannels = output[1];
size_t outputHeight = output[2]; size_t outputHeight = output[2];
size_t outputWidth = output[3]; size_t outputWidth = output[3];
size_t filterMultiplier = outputChannels / groups_;
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
...@@ -137,6 +141,7 @@ public: ...@@ -137,6 +141,7 @@ public:
inputChannels, inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterMultiplier,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH(), strideH(),
...@@ -183,6 +188,7 @@ public: ...@@ -183,6 +188,7 @@ public:
size_t outputChannels = output[1]; size_t outputChannels = output[1];
size_t outputHeight = output[2]; size_t outputHeight = output[2];
size_t outputWidth = output[3]; size_t outputWidth = output[3];
size_t filterMultiplier = outputChannels / groups_;
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
...@@ -198,6 +204,7 @@ public: ...@@ -198,6 +204,7 @@ public:
inputChannels, inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterMultiplier,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH(), strideH(),
...@@ -243,13 +250,14 @@ public: ...@@ -243,13 +250,14 @@ public:
size_t outputChannels = output[1]; size_t outputChannels = output[1];
size_t outputHeight = output[2]; size_t outputHeight = output[2];
size_t outputWidth = output[3]; size_t outputWidth = output[3];
size_t filterMultiplier = outputChannels / groups_;
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>(); real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>(); real* filterGrad = outputs[0].data<real>();
int size = int size = outputChannels * filterHeight * filterWidth * outputHeight *
inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; outputWidth;
resizeBuffer<Device>(size); resizeBuffer<Device>(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf()); real* colData = reinterpret_cast<real*>(memory_->getBuf());
...@@ -264,6 +272,7 @@ public: ...@@ -264,6 +272,7 @@ public:
inputChannels, inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterMultiplier,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH(), strideH(),
......
...@@ -32,6 +32,7 @@ namespace paddle { ...@@ -32,6 +32,7 @@ namespace paddle {
* \param[in] inputChannels channels of inputData. * \param[in] inputChannels channels of inputData.
* \param[in] inputHeight height of inputData. * \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData.. * \param[in] inputWidth width of inputData..
* \param[in] filterMultiplier equals to outputChannels/groups_.
* \param[in] filterHeight height of filter. * \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter. * \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction. * \param[in] strideH stride size in height direction.
...@@ -53,6 +54,7 @@ public: ...@@ -53,6 +54,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -74,7 +76,8 @@ public: ...@@ -74,7 +76,8 @@ public:
* \param[in] outputWidth width of outputData. * \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of input data. * \param[in] inputChannels channels of input data.
* \param[in] inputHeight height of inputData. * \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData.. * \param[in] inputWidth width of inputData.
* \param[in] filterMultiplier equals to outputChannels/groups_.
* \param[in] filterHeight height of filter. * \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter. * \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction. * \param[in] strideH stride size in height direction.
...@@ -96,6 +99,7 @@ public: ...@@ -96,6 +99,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -116,7 +120,8 @@ public: ...@@ -116,7 +120,8 @@ public:
* \param[in] outputWidth width of outputData. * \param[in] outputWidth width of outputData.
* \param[in] inputChannels channels of input data. * \param[in] inputChannels channels of input data.
* \param[in] inputHeight height of inputData. * \param[in] inputHeight height of inputData.
* \param[in] inputWidth width of inputData.. * \param[in] inputWidth width of inputData.
* \param[in] filterMultiplier equals to outputChannels/groups_.
* \param[in] filterHeight height of filter. * \param[in] filterHeight height of filter.
* \param[in] filterWidth widht of filter. * \param[in] filterWidth widht of filter.
* \param[in] strideH stride size in height direction. * \param[in] strideH stride size in height direction.
...@@ -140,6 +145,7 @@ public: ...@@ -140,6 +145,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
......
...@@ -25,7 +25,7 @@ void ConvolutionDepthwiseForward(const int nthreads, ...@@ -25,7 +25,7 @@ void ConvolutionDepthwiseForward(const int nthreads,
const T* const inputData, const T* const filterData, const T* const inputData, const T* const filterData,
const int batchSize, const int outputChannels, const int outputHeight, const int batchSize, const int outputChannels, const int outputHeight,
const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth,
const int filterHeight, const int filterWidth, const int strideH, const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH,
const int strideW, const int paddingH, const int paddingW, const int strideW, const int paddingH, const int paddingW,
T* const outputData) { T* const outputData) {
...@@ -33,23 +33,25 @@ void ConvolutionDepthwiseForward(const int nthreads, ...@@ -33,23 +33,25 @@ void ConvolutionDepthwiseForward(const int nthreads,
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if(index < nthreads) { if(index < nthreads) {
const int n = index / outputChannels / outputHeight / outputWidth; const int batch = index / outputChannels / outputHeight / outputWidth;
const int c = (index / outputHeight / outputWidth) % outputChannels; const int c_out = (index / outputHeight / outputWidth) % outputChannels;
const int h = (index / outputWidth) % outputHeight; const int h_out = (index / outputWidth) % outputHeight;
const int w = index % outputWidth; const int w_out = index % outputWidth;
const T* weight = filterData + c * filterHeight * filterWidth;
const int c_in = c_out / filterMultiplier;
const T* weight = filterData + c_out * filterHeight * filterWidth;
T value = 0; T value = 0;
const int h_in_start = -paddingH + h * strideH; const int h_in_start = -paddingH + h_out * strideH;
const int w_in_start = -paddingW + w * strideW; const int w_in_start = -paddingW + w_out * strideW;
const int h_in_end = -paddingH + h * strideH + filterHeight - 1; const int h_in_end = -paddingH + h_out * strideH + filterHeight - 1;
const int w_in_end = -paddingW + w * strideW + filterWidth - 1; const int w_in_end = -paddingW + w_out * strideW + filterWidth - 1;
if ((h_in_start >= 0) && (h_in_end < inputHeight) if ((h_in_start >= 0) && (h_in_end < inputHeight)
&&(w_in_start >= 0) && (w_in_end < inputWidth)) { &&(w_in_start >= 0) && (w_in_end < inputWidth)) {
for (int kh = 0; kh < filterHeight; ++kh) { for (int kh = 0; kh < filterHeight; ++kh) {
for (int kw = 0; kw < filterWidth; ++kw) { for (int kw = 0; kw < filterWidth; ++kw) {
const int h_in = -paddingH + h * strideH + kh; const int h_in = -paddingH + h_out * strideH + kh;
const int w_in = -paddingW + w * strideW + kw; const int w_in = -paddingW + w_out * strideW + kw;
const int offset = ((n * inputChannels + c) * inputHeight + h_in) const int offset = ((batch * inputChannels + c_in) * inputHeight + h_in)
* inputWidth + w_in; * inputWidth + w_in;
value += (*weight) * inputData[offset]; value += (*weight) * inputData[offset];
++weight; ++weight;
...@@ -58,11 +60,11 @@ void ConvolutionDepthwiseForward(const int nthreads, ...@@ -58,11 +60,11 @@ void ConvolutionDepthwiseForward(const int nthreads,
}else{ }else{
for (int kh = 0; kh < filterHeight; ++kh) { for (int kh = 0; kh < filterHeight; ++kh) {
for (int kw = 0; kw < filterWidth; ++kw) { for (int kw = 0; kw < filterWidth; ++kw) {
const int h_in = -paddingH + h * strideH + kh; const int h_in = -paddingH + h_out * strideH + kh;
const int w_in = -paddingW + w * strideW + kw; const int w_in = -paddingW + w_out * strideW + kw;
if ((h_in >= 0) && (h_in < inputHeight) if ((h_in >= 0) && (h_in < inputHeight)
&& (w_in >= 0) && (w_in < inputWidth)) { && (w_in >= 0) && (w_in < inputWidth)) {
const int offset = ((n * outputChannels + c) * inputHeight + h_in) const int offset = ((batch * inputChannels + c_in) * inputHeight + h_in)
* inputWidth + w_in; * inputWidth + w_in;
value += (*weight) * inputData[offset]; value += (*weight) * inputData[offset];
} }
...@@ -81,38 +83,42 @@ void ConvolutionDepthwiseInputBackward(const int nthreads, ...@@ -81,38 +83,42 @@ void ConvolutionDepthwiseInputBackward(const int nthreads,
const T* const top_diff, const T* const weight_data, const T* const top_diff, const T* const weight_data,
const int num, const int outputChannels, const int outputHeight, const int num, const int outputChannels, const int outputHeight,
const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth,
const int filterHeight, const int filterWidth, const int strideH, const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH,
const int strideW, const int paddingH, const int paddingW, const int strideW, const int paddingH, const int paddingW,
T* const bottom_diff) { T* const bottom_diff) {
int index = int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if(index < nthreads) { if(index < nthreads) {
const int n = index / inputChannels / inputHeight / inputWidth; const int batch = index / inputChannels / inputHeight / inputWidth;
const int c = (index / inputHeight / inputWidth) % inputChannels; const int c_in = (index / inputHeight / inputWidth) % inputChannels;
const int h = (index / inputWidth) % inputHeight; const int h_in = (index / inputWidth) % inputHeight;
const int w = index % inputWidth; const int w_in = index % inputWidth;
const T* weight = weight_data + c * filterHeight * filterWidth; const int c_out_start = c_in * filterMultiplier;
T value = 0; T value = 0;
for (int kh = 0; kh < filterHeight; ++kh) { for(int c_out = c_out_start; c_out < c_out_start + filterMultiplier; c_out ++){
for (int kw = 0; kw < filterWidth; ++kw) { //weight bixu c_out
const int h_out_s = h + paddingH - kh; const T* weight = weight_data + c_out * filterHeight * filterWidth;
const int w_out_s = w + paddingW - kw; for (int kh = 0; kh < filterHeight; ++kh) {
if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { for (int kw = 0; kw < filterWidth; ++kw) {
const int h_out = h_out_s / strideH; const int h_out_s = h_in + paddingH - kh;
const int w_out = w_out_s / strideW; const int w_out_s = w_in + paddingW - kw;
// TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) {
if ((h_out >= 0) && (h_out < outputHeight) const int h_out = h_out_s / strideH;
&& (w_out >= 0) && (w_out < outputWidth)) { const int w_out = w_out_s / strideW;
const int offset = ((n * outputChannels + c) * outputHeight + h_out) // TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize
* outputWidth + w_out; if ((h_out >= 0) && (h_out < outputHeight)
value += (*weight) * top_diff[offset]; && (w_out >= 0) && (w_out < outputWidth)) {
} const int offset = ((batch * outputChannels + c_out) * outputHeight + h_out)
* outputWidth + w_out;
value += (*weight) * top_diff[offset];
}
}
++weight;
}
} }
++weight;
}
} }
bottom_diff[index] += value; bottom_diff[index] += value;
} }
} }
// CUDA kernel to compute the depthwise convolution backprop w.r.t filter. // CUDA kernel to compute the depthwise convolution backprop w.r.t filter.
...@@ -122,26 +128,27 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, ...@@ -122,26 +128,27 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
const T* const top_diff, const T* const inputData, const T* const top_diff, const T* const inputData,
const int num, const int outputChannels, const int outputHeight, const int num, const int outputChannels, const int outputHeight,
const int outputWidth, const int inputChannels, const int inputHeight, const int inputWidth, const int outputWidth, const int inputChannels, const int inputHeight, const int inputWidth,
const int filterHeight, const int filterWidth, const int strideH, const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH,
const int strideW, const int paddingH, const int paddingW, const int strideW, const int paddingH, const int paddingW,
T* const buffer_data) { T* const buffer_data) {
int index = int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < nthreads) { if (index < nthreads) {
const int h = (index / outputWidth) % outputHeight; const int h_out = (index / outputWidth) % outputHeight;
const int w = index % outputWidth; const int w_out = index % outputWidth;
const int kh = (index / filterWidth / outputHeight / outputWidth) const int kh = (index / filterWidth / outputHeight / outputWidth)
% filterHeight; % filterHeight;
const int kw = (index / outputHeight / outputWidth) % filterWidth; const int kw = (index / outputHeight / outputWidth) % filterWidth;
const int h_in = -paddingH + h * strideH + kh; const int h_in = -paddingH + h_out * strideH + kh;
const int w_in = -paddingW + w * strideW + kw; const int w_in = -paddingW + w_out * strideW + kw;
if ((h_in >= 0) && (h_in < inputHeight) if ((h_in >= 0) && (h_in < inputHeight)
&& (w_in >= 0) && (w_in < inputWidth)) { && (w_in >= 0) && (w_in < inputWidth)) {
const int c = index / filterHeight / filterWidth / outputHeight / outputWidth; const int c_out = index / filterHeight / filterWidth / outputHeight / outputWidth;
const int n = num_i; const int c_in = c_out / filterMultiplier;
const int top_offset = ((n * outputChannels + c) * outputHeight + h) const int batch = num_i;
* outputWidth + w; const int top_offset = ((batch * outputChannels + c_out) * outputHeight + h_out)
const int bottom_offset = ((n * inputChannels + c) * inputHeight + h_in) * outputWidth + w_out;
const int bottom_offset = ((batch * inputChannels + c_in) * inputHeight + h_in)
* inputWidth + w_in; * inputWidth + w_in;
buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
} else { } else {
...@@ -162,6 +169,7 @@ public: ...@@ -162,6 +169,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -190,6 +198,7 @@ public: ...@@ -190,6 +198,7 @@ public:
inputChannels, inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterMultiplier,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH, strideH,
...@@ -212,6 +221,7 @@ public: ...@@ -212,6 +221,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -242,6 +252,7 @@ public: ...@@ -242,6 +252,7 @@ public:
inputChannels, inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterMultiplier,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH, strideH,
...@@ -264,6 +275,7 @@ public: ...@@ -264,6 +275,7 @@ public:
int inputChannels, int inputChannels,
int inputHeight, int inputHeight,
int inputWidth, int inputWidth,
int filterMultiplier,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -273,14 +285,14 @@ public: ...@@ -273,14 +285,14 @@ public:
T* colData, T* colData,
T* filterGrad){ T* filterGrad){
int colDataSize = inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; int colDataSize = outputChannels * filterHeight * filterWidth * outputHeight * outputWidth;
size_t blocks = (colDataSize + 1024 -1) / 1024; size_t blocks = (colDataSize + 1024 -1) / 1024;
size_t blockX = 512; size_t blockX = 512;
size_t blockY = (blocks+512-1)/512; size_t blockY = (blocks+512-1)/512;
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
BaseMatrix filterGradMatrix(inputChannels * filterHeight * filterWidth, 1, filterGrad, false, true); BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth, 1, filterGrad, false, true);
for(int i = 0; i < batchSize; i++) { for(int i = 0; i < batchSize; i++) {
ConvolutionDepthwiseFilterBackward<T> ConvolutionDepthwiseFilterBackward<T>
...@@ -296,6 +308,7 @@ public: ...@@ -296,6 +308,7 @@ public:
inputChannels, inputChannels,
inputHeight, inputHeight,
inputWidth, inputWidth,
filterMultiplier,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH, strideH,
...@@ -304,8 +317,8 @@ public: ...@@ -304,8 +317,8 @@ public:
paddingW, paddingW,
colData colData
); );
int M = colDataSize / outputHeight / outputWidth;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
int M = colDataSize / K;
BaseMatrix colMatrix(M, K, colData, false, true); BaseMatrix colMatrix(M, K, colData, false, true);
filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0);
......
...@@ -355,7 +355,7 @@ void testDepthwiseConvLayer(const string& type, bool useGpu) { ...@@ -355,7 +355,7 @@ void testDepthwiseConvLayer(const string& type, bool useGpu) {
config.layerConfig.set_partial_sum(1); config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true); config.layerConfig.set_shared_biases(true);
config.inputDefs.push_back({INPUT_DATA, "layer_0", 2048, 96}); config.inputDefs.push_back({INPUT_DATA, "layer_0", 2048, 192 / 2});
LayerInputConfig* input = config.layerConfig.add_inputs(); LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf(); ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(2); conv->set_filter_size(2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册