提交 0fce0fe6 编写于 作者: D dangqingqing

Reduce memory usage in conv layer and RoI layer for mobile inference.

上级 d5cab4f0
...@@ -233,6 +233,11 @@ public: ...@@ -233,6 +233,11 @@ public:
inputGrad += inputChannels * inputHeight * inputWidth; inputGrad += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth; outputGrad += outputChannels * outputHeight * outputWidth;
} }
#ifdef PADDLE_MOBILE_INFERENCE
if (Device == DEVICE_TYPE_CPU) {
delete memory_;
}
#endif
} }
}; };
......
...@@ -84,12 +84,15 @@ void ROIPoolLayer::forward(PassType passType) { ...@@ -84,12 +84,15 @@ void ROIPoolLayer::forward(PassType passType) {
size_t poolChannelOffset = pooledHeight_ * pooledWidth_; size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
real* outputData = outputValue->getData(); real* outputData = outputValue->getData();
Matrix::resizeOrCreate(maxIdxs_, real* argmaxData = nullptr;
numROIs, if (passType != PASS_TEST) {
channels_ * pooledHeight_ * pooledWidth_, Matrix::resizeOrCreate(maxIdxs_,
false, numROIs,
false); channels_ * pooledHeight_ * pooledWidth_,
real* argmaxData = maxIdxs_->getData(); false,
false);
argmaxData = maxIdxs_->getData();
}
for (size_t n = 0; n < numROIs; ++n) { for (size_t n = 0; n < numROIs; ++n) {
// the first five elememts of each RoI should be: // the first five elememts of each RoI should be:
...@@ -128,14 +131,18 @@ void ROIPoolLayer::forward(PassType passType) { ...@@ -128,14 +131,18 @@ void ROIPoolLayer::forward(PassType passType) {
bool isEmpty = (hend <= hstart) || (wend <= wstart); bool isEmpty = (hend <= hstart) || (wend <= wstart);
size_t poolIndex = ph * pooledWidth_ + pw; size_t poolIndex = ph * pooledWidth_ + pw;
outputData[poolIndex] = isEmpty ? 0 : -FLT_MAX; outputData[poolIndex] = isEmpty ? 0 : -FLT_MAX;
argmaxData[poolIndex] = -1; if (argmaxData) {
argmaxData[poolIndex] = -1;
}
for (size_t h = hstart; h < hend; ++h) { for (size_t h = hstart; h < hend; ++h) {
for (size_t w = wstart; w < wend; ++w) { for (size_t w = wstart; w < wend; ++w) {
size_t index = h * width_ + w; size_t index = h * width_ + w;
if (batchData[index] > outputData[poolIndex]) { if (batchData[index] > outputData[poolIndex]) {
outputData[poolIndex] = batchData[index]; outputData[poolIndex] = batchData[index];
argmaxData[poolIndex] = index; if (argmaxData) {
argmaxData[poolIndex] = index;
}
} }
} }
} }
...@@ -143,7 +150,9 @@ void ROIPoolLayer::forward(PassType passType) { ...@@ -143,7 +150,9 @@ void ROIPoolLayer::forward(PassType passType) {
} }
batchData += channelOffset; batchData += channelOffset;
outputData += poolChannelOffset; outputData += poolChannelOffset;
argmaxData += poolChannelOffset; if (argmaxData) {
argmaxData += poolChannelOffset;
}
} }
bottomROIs += roiOffset; bottomROIs += roiOffset;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册