Commit 3842bc4d authored by liaogang

refine code

Parent 258e5ec5
...
@@ -285,32 +285,34 @@ void MultiGradientMachine::forwardBackward(const std::vector<Argument>& inArgs,
 MatrixPtr MultiGradientMachine::getLayerOutput(const std::string& layerName) {
   // each thread has the same neuro network
   auto nn = threads_[0]->getGradientMachine();
 
   size_t height = 0;
   size_t width = nn->getLayerOutput(layerName)->getWidth();
+  std::vector<MatrixPtr> mats;
+  mats.reserve(threads_.size());
   for (auto& thread : threads_) {
-    auto out = thread->getGradientMachine()->getLayerOutput(layerName);
+    MatrixPtr out = thread->getGradientMachine()->getLayerOutput(layerName);
+    mats.push_back(out);
     height += out->getHeight();
     CHECK_EQ(width, out->getWidth());
   }
 
-  MatrixPtr dst;
-  Matrix::resizeOrCreate(dst, height, width, false, useGpu_);
+  MatrixPtr layerOutput;
+  Matrix::resizeOrCreate(layerOutput, height, width, false, useGpu_);
 
   // copy one layer output from one trainer thread at each time
   size_t startRow = 0;
-  for (auto& thread : threads_) {
-    auto src = thread->getGradientMachine()->getLayerOutput(layerName);
-    auto tmpMatrix = dst->subMatrix(startRow, src->getHeight());
-    tmpMatrix->copyFrom(*src, HPPL_STREAM_DEFAULT);
-    startRow += src->getHeight();
+  for (size_t i = 0; i < threads_.size(); i++) {
+    auto tmpMatrix = layerOutput->subMatrix(startRow, mats[i]->getHeight());
+    tmpMatrix->copyFrom(*mats[i], HPPL_STREAM_DEFAULT);
+    startRow += mats[i]->getHeight();
  }
 
   if (useGpu_) {
     hl_stream_synchronize(HPPL_STREAM_DEFAULT);
   }
 
-  return dst;
+  return layerOutput;
 }
 
 void MultiGradientMachine::backwardImp(const UpdateCallback& callback) {
...
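What the refinement does: the previous version walked threads_ twice and called getLayerOutput(layerName) on every thread in both passes; the refined version fetches each thread's output once, caches the pointer in the mats vector while accumulating the total height, then reuses the cached pointers when copying each block into layerOutput at its running row offset. Below is a minimal standalone sketch of that two-pass row-wise concatenation; Mat and concatRows are hypothetical stand-ins for illustration, with plain std::vector<float> buffers in place of Paddle's Matrix and a synchronous std::copy in place of the asynchronous copyFrom plus hl_stream_synchronize.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Row-major matrix stand-in: rows x width floats.
struct Mat {
  size_t rows;
  size_t width;
  std::vector<float> data;  // rows * width elements
};

// Two-pass row-wise concatenation, mirroring getLayerOutput():
// pass 1 validates widths and sums heights, pass 2 copies each
// per-thread block into the destination at its row offset.
Mat concatRows(const std::vector<Mat>& mats) {
  assert(!mats.empty());
  const size_t width = mats[0].width;
  size_t height = 0;
  for (const Mat& m : mats) {
    assert(m.width == width);  // mirrors CHECK_EQ(width, out->getWidth())
    height += m.rows;
  }
  Mat dst{height, width, std::vector<float>(height * width)};
  size_t startRow = 0;
  for (const Mat& m : mats) {
    // Stands in for layerOutput->subMatrix(startRow, ...)->copyFrom(...).
    std::copy(m.data.begin(), m.data.end(),
              dst.data.begin() + startRow * width);
    startRow += m.rows;
  }
  return dst;
}

int main() {
  // Two "trainer threads", each producing a 2x3 output block.
  std::vector<Mat> mats = {
      {2, 3, {0, 1, 2, 3, 4, 5}},
      {2, 3, {6, 7, 8, 9, 10, 11}},
  };
  Mat merged = concatRows(mats);
  std::cout << merged.rows << "x" << merged.width << "\n";  // prints 4x3
  return 0;
}

Caching the per-thread outputs avoids a second round of getLayerOutput calls, and because every copy in the real code is issued on HPPL_STREAM_DEFAULT, one hl_stream_synchronize at the end covers all of the blocks.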