提交 fa10677a 编写于 作者: X xzl

modify skipIm2col to need2col, delete useless variable colBuffer

上级 5229df52
...@@ -110,10 +110,10 @@ protected: ...@@ -110,10 +110,10 @@ protected:
} }
// determine whether im2col needs to be performed // determine whether im2col needs to be performed
inline bool isSkipIm2col(const TensorShape& filter) const { inline bool isNeedIm2col(const TensorShape& filter) const {
return (getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 &&
strideH() == 1 && strideW() == 1 && paddingH() == 0 && strideH() == 1 && strideW() == 1 && paddingH() == 0 &&
paddingW() == 0); paddingW() == 0);
} }
std::vector<size_t> strides_; std::vector<size_t> strides_;
......
...@@ -66,15 +66,15 @@ public: ...@@ -66,15 +66,15 @@ public:
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter); bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape; TensorShape colShape;
real *colBuffer, *colData = NULL; real* colData = NULL;
if (!skipIm2col) { if (needIm2col) {
colShape = TensorShape({inputChannels / groups_, colShape = TensorShape({inputChannels / groups_,
filterHeight, filterHeight,
filterWidth, filterWidth,
...@@ -93,8 +93,7 @@ public: ...@@ -93,8 +93,7 @@ public:
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
colBuffer = inputData + g * inputOffset; if (needIm2col) {
if (!skipIm2col) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
imShape, imShape,
colData, colData,
...@@ -103,7 +102,8 @@ public: ...@@ -103,7 +102,8 @@ public:
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW());
colBuffer = colData; } else {
colData = inputData + g * inputOffset;
} }
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
...@@ -116,7 +116,7 @@ public: ...@@ -116,7 +116,7 @@ public:
1.0f, 1.0f,
filterData + g * filterOffset, filterData + g * filterOffset,
K, K,
colBuffer, colData,
N, N,
beta, beta,
outputData + g * outputOffset, outputData + g * outputOffset,
...@@ -169,15 +169,15 @@ public: ...@@ -169,15 +169,15 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* inputGrad = outputs[0].data<real>(); real* inputGrad = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter); bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape; TensorShape colShape;
real *colBuffer, *colData = NULL; real* colData = NULL;
if (!skipIm2col) { if (needIm2col) {
colShape = TensorShape({inputChannels / groups_, colShape = TensorShape({inputChannels / groups_,
filterHeight, filterHeight,
filterWidth, filterWidth,
...@@ -200,10 +200,9 @@ public: ...@@ -200,10 +200,9 @@ public:
int K = outputChannels / groups_; int K = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int M = inputChannels / groups_ * filterHeight * filterWidth; int M = inputChannels / groups_ * filterHeight * filterWidth;
colBuffer = colData;
real scale = 0.0f; real scale = 0.0f;
if (skipIm2col) { if (!needIm2col) {
colBuffer = inputGrad + g * inputOffset; colData = inputGrad + g * inputOffset;
scale = 1.0f; scale = 1.0f;
} }
gemm(CblasTrans, gemm(CblasTrans,
...@@ -217,12 +216,12 @@ public: ...@@ -217,12 +216,12 @@ public:
outputGrad + g * outputOffset, outputGrad + g * outputOffset,
N, N,
scale, scale,
colBuffer, colData,
N); N);
if (!skipIm2col) { if (needIm2col) {
col2im(inputGrad + g * inputOffset, col2im(inputGrad + g * inputOffset,
imShape, imShape,
colBuffer, colData,
colShape, colShape,
strideH(), strideH(),
strideW(), strideW(),
...@@ -281,15 +280,15 @@ public: ...@@ -281,15 +280,15 @@ public:
real* outputGrad = inputs[0].data<real>(); real* outputGrad = inputs[0].data<real>();
real* inputData = inputs[1].data<real>(); real* inputData = inputs[1].data<real>();
real* filterGrad = outputs[0].data<real>(); real* filterGrad = outputs[0].data<real>();
bool skipIm2col = isSkipIm2col(filter); bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape; TensorShape colShape;
real *colBuffer, *colData = NULL; real* colData = NULL;
if (!skipIm2col) { if (needIm2col) {
colShape = TensorShape({inputChannels / groups_, colShape = TensorShape({inputChannels / groups_,
filterHeight, filterHeight,
filterWidth, filterWidth,
...@@ -307,8 +306,7 @@ public: ...@@ -307,8 +306,7 @@ public:
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
colBuffer = inputData + g * inputOffset; if (needIm2col) {
if (!skipIm2col) {
im2col(inputData + g * inputOffset, im2col(inputData + g * inputOffset,
imShape, imShape,
colData, colData,
...@@ -317,7 +315,8 @@ public: ...@@ -317,7 +315,8 @@ public:
strideW(), strideW(),
paddingH(), paddingH(),
paddingW()); paddingW());
colBuffer = colData; } else {
colData = inputData + g * inputOffset;
} }
int M = outputChannels / groups_; int M = outputChannels / groups_;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
...@@ -330,7 +329,7 @@ public: ...@@ -330,7 +329,7 @@ public:
1.0f, 1.0f,
outputGrad + g * outputOffset, outputGrad + g * outputOffset,
K, K,
colBuffer, colData,
K, K,
i == 0 ? beta : 1.0f, i == 0 ? beta : 1.0f,
filterGrad + g * filterOffset, filterGrad + g * filterOffset,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册