提交 d775895e 编写于 作者: H hedaoyuan

Add Im2ColMobileFunctor.

上级 dbf1d75f
...@@ -206,8 +206,7 @@ public: ...@@ -206,8 +206,7 @@ public:
colData = reinterpret_cast<real*>(memory_->getBuf()); colData = reinterpret_cast<real*>(memory_->getBuf());
} }
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColMobileFunctor<real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -241,19 +240,20 @@ public: ...@@ -241,19 +240,20 @@ public:
// gemm // gemm
int M = outputChannels / groups_; int M = outputChannels / groups_;
gemm(CblasNoTrans, BlasGemm<Device, real>::compute(
CblasNoTrans, false,
M, false,
N, M,
K, N,
1.0f, K,
filterData + g * filterOffset + colHeightStart, 1.0f,
kStride, filterData + g * filterOffset + colHeightStart,
colData, kStride,
N, colData,
beta_, N,
outputData + g * outputOffset + colWidthStart, beta_,
nStride); outputData + g * outputOffset + colWidthStart,
nStride);
} }
beta_ = 1.0; beta_ = 1.0;
} }
...@@ -261,19 +261,19 @@ public: ...@@ -261,19 +261,19 @@ public:
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans, BlasGemm<Device, real>::compute(false,
CblasNoTrans, false,
M, M,
N, N,
K, K,
1.0f, 1.0f,
filterData + g * filterOffset, filterData + g * filterOffset,
K, K,
inputData + g * inputOffset, inputData + g * inputOffset,
N, N,
beta, beta,
outputData + g * outputOffset, outputData + g * outputOffset,
N); N);
} }
} }
inputData += inputChannels * inputHeight * inputWidth; inputData += inputChannels * inputHeight * inputWidth;
......
...@@ -98,4 +98,52 @@ public: ...@@ -98,4 +98,52 @@ public:
int dilationWidth = 1); int dilationWidth = 1);
}; };
template <class T>
class Im2ColMobileFunctor {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int colHeightStart,
int colHeightSize,
int colWidthStart,
int colWidthSize) {
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputWidth = colShape[4];
for (int colh = 0; colh < colHeightSize; colh++) {
int wOffset = (colHeightStart + colh) % filterWidth;
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight;
int c_im = (colHeightStart + colh) / filterWidth / filterHeight;
for (int colw = 0; colw < colWidthSize; colw++) {
int h = (colWidthStart + colw) / outputWidth;
int w = (colWidthStart + colw) % outputWidth;
int imRowIdx = h * strideHeight + hOffset;
int imColIdx = w * strideWidth + wOffset;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[colh * colWidthSize + colw] = T(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[colh * colWidthSize + colw] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
};
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册