From 3842bc4d7c904b2d0bda4aa48429a20c317f1420 Mon Sep 17 00:00:00 2001 From: liaogang Date: Fri, 17 Feb 2017 13:42:33 +0800 Subject: [PATCH] refine code --- .../gradientmachines/MultiGradientMachine.cpp | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 56b1836e4..db13a8868 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -285,32 +285,34 @@ void MultiGradientMachine::forwardBackward(const std::vector& inArgs, MatrixPtr MultiGradientMachine::getLayerOutput(const std::string& layerName) { // each thread has the same neuro network auto nn = threads_[0]->getGradientMachine(); - size_t height = 0; size_t width = nn->getLayerOutput(layerName)->getWidth(); + std::vector mats; + mats.reserve(threads_.size()); for (auto& thread : threads_) { - auto out = thread->getGradientMachine()->getLayerOutput(layerName); + MatrixPtr out = thread->getGradientMachine()->getLayerOutput(layerName); + mats.push_back(out); height += out->getHeight(); CHECK_EQ(width, out->getWidth()); } - MatrixPtr dst; - Matrix::resizeOrCreate(dst, height, width, false, useGpu_); + MatrixPtr layerOutput; + Matrix::resizeOrCreate(layerOutput, height, width, false, useGpu_); // copy one layer output from one trainer thread at each time size_t startRow = 0; - for (auto& thread : threads_) { - auto src = thread->getGradientMachine()->getLayerOutput(layerName); - auto tmpMatrix = dst->subMatrix(startRow, src->getHeight()); - tmpMatrix->copyFrom(*src, HPPL_STREAM_DEFAULT); - startRow += src->getHeight(); + + for (size_t i = 0; i < threads_.size(); i++) { + auto tmpMatrix = layerOutput->subMatrix(startRow, mats[i]->getHeight()); + tmpMatrix->copyFrom(*mats[i], HPPL_STREAM_DEFAULT); + startRow += mats[i]->getHeight(); } if (useGpu_) { hl_stream_synchronize(HPPL_STREAM_DEFAULT); } - return dst; + return layerOutput; } void MultiGradientMachine::backwardImp(const UpdateCallback& callback) { -- GitLab