From 0fce0fe6983f4f167b873465fc90cff08fc31bd9 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Fri, 15 Dec 2017 15:48:58 +0800 Subject: [PATCH] Reduce memory usage in conv layer and RoI layer for mobile inference. --- paddle/function/GemmConvOp.cpp | 5 +++++ paddle/gserver/layers/ROIPoolLayer.cpp | 27 +++++++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 8d34eee886a..ffbf366fa9f 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -233,6 +233,11 @@ public: inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; } +#ifdef PADDLE_MOBILE_INFERENCE + if (Device == DEVICE_TYPE_CPU) { + delete memory_; + } +#endif } }; diff --git a/paddle/gserver/layers/ROIPoolLayer.cpp b/paddle/gserver/layers/ROIPoolLayer.cpp index 2c8256b91c9..7d7c30b4d89 100644 --- a/paddle/gserver/layers/ROIPoolLayer.cpp +++ b/paddle/gserver/layers/ROIPoolLayer.cpp @@ -84,12 +84,15 @@ void ROIPoolLayer::forward(PassType passType) { size_t poolChannelOffset = pooledHeight_ * pooledWidth_; real* outputData = outputValue->getData(); - Matrix::resizeOrCreate(maxIdxs_, - numROIs, - channels_ * pooledHeight_ * pooledWidth_, - false, - false); - real* argmaxData = maxIdxs_->getData(); + real* argmaxData = nullptr; + if (passType != PASS_TEST) { + Matrix::resizeOrCreate(maxIdxs_, + numROIs, + channels_ * pooledHeight_ * pooledWidth_, + false, + false); + argmaxData = maxIdxs_->getData(); + } for (size_t n = 0; n < numROIs; ++n) { // the first five elememts of each RoI should be: @@ -128,14 +131,18 @@ void ROIPoolLayer::forward(PassType passType) { bool isEmpty = (hend <= hstart) || (wend <= wstart); size_t poolIndex = ph * pooledWidth_ + pw; outputData[poolIndex] = isEmpty ? 0 : -FLT_MAX; - argmaxData[poolIndex] = -1; + if (argmaxData) { + argmaxData[poolIndex] = -1; + } for (size_t h = hstart; h < hend; ++h) { for (size_t w = wstart; w < wend; ++w) { size_t index = h * width_ + w; if (batchData[index] > outputData[poolIndex]) { outputData[poolIndex] = batchData[index]; - argmaxData[poolIndex] = index; + if (argmaxData) { + argmaxData[poolIndex] = index; + } } } } @@ -143,7 +150,9 @@ void ROIPoolLayer::forward(PassType passType) { } batchData += channelOffset; outputData += poolChannelOffset; - argmaxData += poolChannelOffset; + if (argmaxData) { + argmaxData += poolChannelOffset; + } } bottomROIs += roiOffset; } -- GitLab