diff --git a/paddle/gserver/gradientmachines/GradientMachine.h b/paddle/gserver/gradientmachines/GradientMachine.h
index 0829968d87c5dc7eeb2d1b70c758ff305d89496f..201b65bc45181b4f9a59cd742f91d36b68eadc7c 100644
--- a/paddle/gserver/gradientmachines/GradientMachine.h
+++ b/paddle/gserver/gradientmachines/GradientMachine.h
@@ -134,6 +134,8 @@ public:
     backward(callback);
   }
 
+  virtual MatrixPtr getLayerOutput(const std::string& layerName) = 0;
+
   // see comment in Layer.h for the function with the same name
   virtual void resetState() {}
 
diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp
index 80f223824d8dccfb0e9386f4c076b28f9332a958..a571b3d72f1d520b03c9cbc0a8469e44f1ff037c 100644
--- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp
+++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp
@@ -282,6 +282,44 @@ void MultiGradientMachine::forwardBackward(const std::vector<Argument>& inArgs,
   backwardImp(callback);
 }
 
+MatrixPtr MultiGradientMachine::getLayerOutput(const std::string& layerName) {
+  // the neural network is identical in each trainer thread, so the total
+  // output height is the per-thread output height times the thread count
+  auto nn = dynamic_cast<NeuralNetwork*>(threads_[0]->getGradientMachine());
+  auto height = nn->getLayerOutput(layerName)->getHeight() * threads_.size();
+  auto stream = HPPL_STREAM_DEFAULT;
+
+  auto copyLayerOutput = [height, stream](
+      MatrixPtr& dst, MatrixPtr src, int startRow, bool useGpu) {
+    size_t width = src->getWidth();
+    if (!dst) {
+      dst = src->clone(height, width, useGpu);
+    } else {
+      dst->resize(height, width);
+    }
+
+    MatrixPtr tmpMatrix = dst->subMatrix(startRow, src->getHeight());
+    tmpMatrix->copyFrom(*src, stream);
+  };
+
+  MatrixPtr mats;
+  size_t startRow = 0;
+
+  // copy the layer output from each trainer thread, one at a time
+  for (auto& thread : threads_) {
+    auto nn = dynamic_cast<NeuralNetwork*>(thread->getGradientMachine());
+    auto mat = nn->getLayerOutput(layerName);
+    copyLayerOutput(mats, mat, startRow, useGpu_);
+    startRow += mat->getHeight();
+  }
+
+  if (useGpu_) {
+    hl_stream_synchronize(HPPL_STREAM_DEFAULT);
+  }
+
+  return mats;
+}
+
 void MultiGradientMachine::backwardImp(const UpdateCallback& callback) {
   for (size_t i = 0; i < parameters_.size(); i++) {
     if (!parameters_[i]->useGpu() || parameters_[i]->isStatic()) continue;
diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h
index 9be15ef4bcf34f26b7eceb9047252e537f20a4a8..988d5098179806fc75aa2fae5dcc4330d7963257 100644
--- a/paddle/gserver/gradientmachines/MultiGradientMachine.h
+++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h
@@ -189,6 +189,8 @@ public:
                               PassType passType,
                               const UpdateCallback& callback);
 
+  virtual MatrixPtr getLayerOutput(const std::string& layerName);
+
   virtual void onPassEnd();
 
   virtual void finish();
diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp
index 22051e07ee0026bc3c44a8767e265a56b415b8e4..1f9ace4f67fdcab6af522277fa83bb0e6044360d 100644
--- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp
+++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp
@@ -298,6 +298,7 @@ MatrixPtr NeuralNetwork::getLayerOutput(const std::string& layerName) {
   CHECK(it != layerMap_.end()) << "Cannot find layer: " << layerName;
   return it->second->getOutputValue();
 }
+
 void NeuralNetwork::onPassEnd() {
   for (auto& layer : layers_) {
     layer->onPassEnd();
diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.h b/paddle/gserver/gradientmachines/NeuralNetwork.h
index 25af4abcf81700e200feea806fa3daed19df1275..bf9ed09327f2f13585a2b37993f3139fe6cb862b 100644
--- a/paddle/gserver/gradientmachines/NeuralNetwork.h
+++ b/paddle/gserver/gradientmachines/NeuralNetwork.h
@@ -87,7 +87,8 @@ public:
   virtual void backward(const UpdateCallback& callback = nullptr);
 
-  MatrixPtr getLayerOutput(const std::string& layerName);
+  virtual MatrixPtr getLayerOutput(const std::string& layerName);
+
   const LayerPtr& getLayer(const std::string& layerName) const {
     auto it = layerMap_.find(layerName);
     CHECK(it != layerMap_.end()) << "Unknown layer " << layerName;
     return it->second;
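
For context, a minimal usage sketch (not part of the patch) of how a caller might exercise the now-virtual `getLayerOutput` through the base `GradientMachine` interface. The helper name `dumpLayerOutput` and the layer name passed in are illustrative assumptions; `forward` and `PASS_TEST` are used as in the existing single-machine API.

```cpp
#include <string>
#include <vector>

// Hypothetical helper for illustration only; not introduced by this patch.
void dumpLayerOutput(paddle::GradientMachine& machine,
                     const std::vector<paddle::Argument>& inArgs,
                     const std::string& layerName) {
  std::vector<paddle::Argument> outArgs;
  machine.forward(inArgs, &outArgs, paddle::PASS_TEST);

  // Because getLayerOutput is now declared on GradientMachine, this call
  // dispatches virtually: for a MultiGradientMachine it returns the
  // per-thread outputs stacked row-wise into a single matrix.
  paddle::MatrixPtr out = machine.getLayerOutput(layerName);
  LOG(INFO) << layerName << ": " << out->getHeight() << "x" << out->getWidth();
}
```

Note the design choice in `MultiGradientMachine::getLayerOutput`: all per-thread copies are queued on `HPPL_STREAM_DEFAULT` and the stream is synchronized once after the loop, so in the GPU case the copies can be issued back to back rather than blocking on each trainer thread in turn.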