提交 94cee3d6 编写于 作者: Z Zhaolong Xing 提交者: GitHub

Merge pull request #3163 from NHZlX/fix_conv_1x1

ignore im2col if not necessary in conv 1 * 1
...@@ -109,6 +109,13 @@ protected: ...@@ -109,6 +109,13 @@ protected:
return filter[filter.ndims() - 1]; return filter[filter.ndims() - 1];
} }
// determine whether im2col needs to be performed
inline bool isNeedIm2col(const TensorShape& filter) const {
return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 &&
strideH() == 1 && strideW() == 1 && paddingH() == 0 &&
paddingW() == 0);
}
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
......
...@@ -66,16 +66,23 @@ public: ...@@ -66,16 +66,23 @@ public:
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
TensorShape colShape;
real* colData = NULL;
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight, filterHeight,
filterWidth, filterWidth,
outputHeight, outputHeight,
outputWidth}); outputWidth});
resizeBuffer<Device>(colShape.getElements()); resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf()); colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
...@@ -86,6 +93,7 @@ public: ...@@ -86,6 +93,7 @@ public:
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
if (needIm2col) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
imShape, imShape,
colData, colData,
...@@ -94,7 +102,9 @@ public: ...@@ -94,7 +102,9 @@ public:
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW());
} else {
colData = inputData + g * inputOffset;
}
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
...@@ -159,19 +169,27 @@ public: ...@@ -159,19 +169,27 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>(); real* inputGrad = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
TensorShape colShape;
real* colData = NULL;
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight, filterHeight,
filterWidth, filterWidth,
outputHeight, outputHeight,
outputWidth}); outputWidth});
resizeBuffer<Device>(colShape.getElements()); resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf()); colData = reinterpret_cast<real*>(memory_->getBuf());
}
Col2ImFunctor<kCFO, Device, real> col2im; Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -182,6 +200,11 @@ public: ...@@ -182,6 +200,11 @@ public:
int K = outputChannels / groups_; int K = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth; int M = inputChannels / groups_ * filterHeight * filterWidth;
real scale = 0.0f;
if (!needIm2col) {
colData = inputGrad + g * inputOffset;
scale = 1.0f;
}
gemm(CblasTrans, gemm(CblasTrans,
CblasNoTrans, CblasNoTrans,
M, M,
...@@ -192,9 +215,10 @@ public: ...@@ -192,9 +215,10 @@ public:
M, M,
outputGrad + g * outputOffset, outputGrad + g * outputOffset,
N, N,
0.0f, scale,
colData, colData,
N); N);
if (needIm2col) {
col2im(inputGrad + g * inputOffset, col2im(inputGrad + g * inputOffset,
imShape, imShape,
colData, colData,
...@@ -204,6 +228,7 @@ public: ...@@ -204,6 +228,7 @@ public:
paddingH(), paddingH(),
paddingW()); paddingW());
} }
}
inputGrad += inputChannels * inputHeight * inputWidth; inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth; outputGrad += outputChannels * outputHeight * outputWidth;
} }
...@@ -255,16 +280,23 @@ public: ...@@ -255,16 +280,23 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>(); real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>(); real* filterGrad = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
TensorShape colShape;
real* colData = NULL;
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight, filterHeight,
filterWidth, filterWidth,
outputHeight, outputHeight,
outputWidth}); outputWidth});
resizeBuffer<Device>(colShape.getElements()); resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf()); colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm; GemmFunctor<Device, real> gemm;
...@@ -274,6 +306,7 @@ public: ...@@ -274,6 +306,7 @@ public:
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
if (needIm2col) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
imShape, imShape,
colData, colData,
...@@ -282,7 +315,9 @@ public: ...@@ -282,7 +315,9 @@ public:
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW());
} else {
colData = inputData + g * inputOffset;
}
int M = outputChannels / groups_; int M = outputChannels / groups_;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth; int N = inputChannels / groups_ * filterHeight * filterWidth;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册