未验证 提交 292c1951 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #7441 from hedaoyuan/inference

Some optimization of CNN model computation.
...@@ -178,19 +178,22 @@ public: ...@@ -178,19 +178,22 @@ public:
real* inputData = inputs[0].data<real>(); real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>(); real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>(); real* outputData = outputs[0].data<real>();
real* colData = NULL;
bool needIm2col = isNeedIm2col(filter); bool needIm2col = isNeedIm2col(filter);
TensorShape imShape = TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth}); TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape; TensorShape colShape;
real* colData = NULL;
size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; // Max col matrix width 4096, Max col matrix size 4M.
size_t colWidth = outputHeight * outputWidth; size_t outputHeightSteps =
// Max col matrix height 256, Max col matrix width 1024 std::min(std::max(4096 / outputWidth, (size_t)1), outputHeight);
size_t stepColHeight = std::min(colHeight, static_cast<size_t>(256)); size_t maxColWidth = outputHeightSteps * outputWidth;
size_t stepColWidth = std::min(colWidth, static_cast<size_t>(2048)); size_t channelSteps =
std::min(std::max((1048576 / maxColWidth) / filterHeight * filterWidth,
(size_t)1),
inputChannels / groups_);
size_t maxColHeight = channelSteps * filterHeight * filterWidth;
if (needIm2col) { if (needIm2col) {
colShape = TensorShape({inputChannels / groups_, colShape = TensorShape({inputChannels / groups_,
...@@ -199,7 +202,7 @@ public: ...@@ -199,7 +202,7 @@ public:
outputHeight, outputHeight,
outputWidth}); outputWidth});
resizeBuffer<Device>(stepColHeight * stepColWidth * sizeof(real)); resizeBuffer<Device>(maxColHeight * maxColWidth * sizeof(real));
colData = reinterpret_cast<real*>(memory_->getBuf()); colData = reinterpret_cast<real*>(memory_->getBuf());
} }
...@@ -209,20 +212,24 @@ public: ...@@ -209,20 +212,24 @@ public:
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_; size_t filterOffset = filter.getElements() / groups_;
int nStride = colWidth; int nStride = outputHeight * outputWidth;
int kStride = colHeight; int kStride = inputChannels / groups_ * filterHeight * filterWidth;
for (size_t i = 0; i < batchSize; i++) { for (size_t i = 0; i < batchSize; i++) {
filterData = inputs[1].data<real>();
for (size_t g = 0; g < groups_; g++) { for (size_t g = 0; g < groups_; g++) {
if (needIm2col) { if (needIm2col) {
real beta_ = beta; real beta_ = beta;
for (size_t colHeightStart = 0; colHeightStart < colHeight; for (size_t ic = 0; ic < inputChannels / groups_;
colHeightStart += stepColHeight) { ic += channelSteps) {
for (size_t colWidthStart = 0; colWidthStart < colWidth; int channels = std::min(inputChannels / groups_ - ic, channelSteps);
colWidthStart += stepColWidth) { for (size_t oh = 0; oh < outputHeight; oh += outputHeightSteps) {
int N = std::min(colWidth - colWidthStart, stepColWidth); int height = std::min(outputHeight - oh, outputHeightSteps);
int K = std::min(colHeight - colHeightStart, stepColHeight);
int M = outputChannels / groups_;
int N = height * outputWidth;
int K = channels * filterHeight * filterWidth;
// im2col // im2col
im2col(inputData + g * inputOffset, im2col(inputData,
imShape, imShape,
colData, colData,
colShape, colShape,
...@@ -232,13 +239,12 @@ public: ...@@ -232,13 +239,12 @@ public:
paddingW(), paddingW(),
dilationH(), dilationH(),
dilationW(), dilationW(),
colHeightStart, channels,
K, oh,
colWidthStart, height,
N); N);
// gemm // gemm
int M = outputChannels / groups_;
BlasGemm<Device, real>::compute( BlasGemm<Device, real>::compute(
false, false,
false, false,
...@@ -246,12 +252,12 @@ public: ...@@ -246,12 +252,12 @@ public:
N, N,
K, K,
1.0f, 1.0f,
filterData + g * filterOffset + colHeightStart, filterData + ic * filterHeight * filterWidth,
kStride, kStride,
colData, colData,
N, N,
beta_, beta_,
outputData + g * outputOffset + colWidthStart, outputData + oh * outputWidth,
nStride); nStride);
} }
beta_ = 1.0; beta_ = 1.0;
...@@ -266,17 +272,18 @@ public: ...@@ -266,17 +272,18 @@ public:
N, N,
K, K,
1.0f, 1.0f,
filterData + g * filterOffset, filterData,
K, K,
inputData + g * inputOffset, inputData,
N, N,
beta, beta,
outputData + g * outputOffset, outputData,
N); N);
} }
inputData += inputOffset;
outputData += outputOffset;
filterData += filterOffset;
} }
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
} }
memory_.reset(); memory_.reset();
......
...@@ -111,39 +111,42 @@ public: ...@@ -111,39 +111,42 @@ public:
int paddingWidth, int paddingWidth,
int dilationHeight, int dilationHeight,
int dilationWidth, int dilationWidth,
int colHeightStart, int inputChannels,
int colHeightSize, int colOffset,
int colWidthStart, int colOutputHeight,
int colWidthSize) { int colWidth) {
int inputHeight = imShape[1]; int inputHeight = imShape[1];
int inputWidth = imShape[2]; int inputWidth = imShape[2];
int filterHeight = colShape[1]; int filterHeight = colShape[1];
int filterWidth = colShape[2]; int filterWidth = colShape[2];
int outputWidth = colShape[4]; int outputWidth = colShape[4];
for (int colh = 0; colh < colHeightSize; colh++) { for (int ic = 0; ic < inputChannels; ic++) {
int wOffset = (colHeightStart + colh) % filterWidth; for (int oh = 0; oh < colOutputHeight; oh++) {
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; T* dstData = colData + oh * outputWidth;
int c_im = (colHeightStart + colh) / filterWidth / filterHeight; for (int fh = 0; fh < filterHeight; fh++) {
for (int fw = 0; fw < filterWidth; fw++) {
for (int colw = 0; colw < colWidthSize; colw++) { int imRowIdx = (oh + colOffset) * strideHeight +
int h = (colWidthStart + colw) / outputWidth; fh * dilationHeight - paddingHeight;
int w = (colWidthStart + colw) % outputWidth; if (imRowIdx < 0 || imRowIdx >= inputHeight) {
memset(dstData, 0, outputWidth * sizeof(T));
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[colh * colWidthSize + colw] = static_cast<T>(0);
} else { } else {
imRowIdx += c_im * inputHeight - paddingHeight; for (int ow = 0; ow < outputWidth; ow++) {
imColIdx -= paddingWidth; int imColIdx =
colData[colh * colWidthSize + colw] = ow * strideWidth + fw * dilationWidth - paddingWidth;
imData[imRowIdx * inputWidth + imColIdx]; if (imColIdx < 0 || imColIdx >= inputWidth) {
dstData[ow] = T(0);
} else {
dstData[ow] = imData[imRowIdx * inputWidth + imColIdx];
}
}
}
dstData += colWidth;
}
} }
} }
colData += filterHeight * filterWidth * colWidth;
imData += inputHeight * inputWidth;
} }
} }
}; };
......
...@@ -202,10 +202,10 @@ void TestIm2ColMobileFunctor() { ...@@ -202,10 +202,10 @@ void TestIm2ColMobileFunctor() {
padding, padding,
dilation, dilation,
dilation, dilation,
channels,
0, 0,
height, outputHeight,
0, outputHeight * outputWidth);
width);
autotest::TensorCheckEqual(*output1, *output2); autotest::TensorCheckEqual(*output1, *output2);
} }
......
...@@ -2015,13 +2015,6 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -2015,13 +2015,6 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
CHECK_EQ(channels * outLength, maskMatP->getWidth()); CHECK_EQ(channels * outLength, maskMatP->getWidth());
} }
/* initialize the data_ */
for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) {
outData[i * outStride + j] = -(real)FLT_MAX;
}
}
/* pool max one by one */ /* pool max one by one */
for (size_t n = 0; n < num; ++n) { // frame by frame for (size_t n = 0; n < num; ++n) { // frame by frame
if (!isContiguous()) { if (!isContiguous()) {
...@@ -2030,19 +2023,24 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, ...@@ -2030,19 +2023,24 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t c = 0; c < channels; ++c) { // channel by channel
for (size_t ph = 0; ph < outputH; ++ph) { for (size_t ph = 0; ph < outputH; ++ph) {
int hstart = ph * strideH - paddingH; int hstart = ph * strideH - paddingH;
int hend = std::min(hstart + sizeY, imgSizeH); int hend = hstart + sizeY;
hstart = std::max(hstart, 0); hstart = hstart < 0 ? 0 : hstart;
hend = hend < (int)imgSizeH ? hend : (int)imgSizeH;
for (size_t pw = 0; pw < outputW; ++pw) { for (size_t pw = 0; pw < outputW; ++pw) {
int wstart = pw * strideW - paddingW; int wstart = pw * strideW - paddingW;
int wend = std::min(wstart + sizeX, imgSizeW); int wend = wstart + sizeX;
wstart = std::max(wstart, 0); wstart = wstart < 0 ? 0 : wstart;
wend = wend < (int)imgSizeW ? wend : (int)imgSizeW;
if (maskData == NULL) { if (maskData == NULL) {
real tmp = -(real)FLT_MAX;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
outData[ph * outputW + pw] = std::max( tmp = tmp < inputData[h * imgSizeW + w]
outData[ph * outputW + pw], inputData[h * imgSizeW + w]); ? inputData[h * imgSizeW + w]
: tmp;
} }
} }
outData[ph * outputW + pw] = tmp;
} else { } else {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册