提交 5229df52 编写于 作者: X xzl

ignore im2col if not necessary in conv 1 * 1

上级 f70e8077
......@@ -109,6 +109,13 @@ protected:
return filter[filter.ndims() - 1];
}
// determine whether im2col needs to be performed
inline bool isSkipIm2col(const TensorShape& filter) const {
return (getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 &&
strideH() == 1 && strideW() == 1 && paddingH() == 0 &&
paddingW() == 0);
}
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
......
......@@ -66,16 +66,23 @@ public:
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real *colBuffer, *colData = NULL;
if (!skipIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
......@@ -86,15 +93,18 @@ public:
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
colBuffer = inputData + g * inputOffset;
if (!skipIm2col) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
colBuffer = colData;
}
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
......@@ -106,7 +116,7 @@ public:
1.0f,
filterData + g * filterOffset,
K,
colData,
colBuffer,
N,
beta,
outputData + g * outputOffset,
......@@ -159,19 +169,27 @@ public:
real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real *colBuffer, *colData = NULL;
if (!skipIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
......@@ -182,6 +200,12 @@ public:
int K = outputChannels / groups_;
int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth;
colBuffer = colData;
real scale = 0.0f;
if (skipIm2col) {
colBuffer = inputGrad + g * inputOffset;
scale = 1.0f;
}
gemm(CblasTrans,
CblasNoTrans,
M,
......@@ -192,17 +216,19 @@ public:
M,
outputGrad + g * outputOffset,
N,
0.0f,
colData,
scale,
colBuffer,
N);
col2im(inputGrad + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
if (!skipIm2col) {
col2im(inputGrad + g * inputOffset,
imShape,
colBuffer,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
}
}
inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth;
......@@ -255,16 +281,23 @@ public:
real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
real* colData = reinterpret_cast<real*>(memory_->getBuf());
TensorShape colShape;
real *colBuffer, *colData = NULL;
if (!skipIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(colShape.getElements());
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
......@@ -274,15 +307,18 @@ public:
size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
colBuffer = inputData + g * inputOffset;
if (!skipIm2col) {
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW());
colBuffer = colData;
}
int M = outputChannels / groups_;
int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth;
......@@ -294,7 +330,7 @@ public:
1.0f,
outputGrad + g * outputOffset,
K,
colData,
colBuffer,
K,
i == 0 ? beta : 1.0f,
filterGrad + g * filterOffset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册