Commit 7c5fd231 authored by liaogang

Update MultiGradientMachine::getLayerOutput

Parent: 393d8354
```diff
@@ -283,41 +283,34 @@ void MultiGradientMachine::forwardBackward(const std::vector<Argument>& inArgs,
 }
 
 MatrixPtr MultiGradientMachine::getLayerOutput(const std::string& layerName) {
-  // neural networks are same in each trainer thread
-  // layer output height = height of layer output * thread nums
-  auto nn = dynamic_cast<NeuralNetwork*>(threads_[0]->getGradientMachine());
-  auto height = nn->getLayerOutput(layerName)->getHeight() * threads_.size();
-  auto stream = HPPL_STREAM_DEFAULT;
-
-  auto copyLayerOutput = [height, stream](
-      MatrixPtr& dst, MatrixPtr src, int startRow, bool useGpu) {
-    size_t width = src->getWidth();
-    if (!dst) {
-      dst = src->clone(height, width, useGpu);
-    } else {
-      dst->resize(height, width);
-    }
-
-    MatrixPtr tmpMatrix = dst->subMatrix(startRow, src->getHeight());
-    tmpMatrix->copyFrom(*src, stream);
-  };
-
-  MatrixPtr mats;
-  size_t startRow = 0;
+  // each thread has the same neural network
+  auto nn = threads_[0]->getGradientMachine();
+
+  size_t height = 0;
+  size_t width = nn->getLayerOutput(layerName)->getWidth();
+  for (auto& thread : threads_) {
+    auto out = thread->getGradientMachine()->getLayerOutput(layerName);
+    height += out->getHeight();
+    CHECK_EQ(width, out->getWidth());
+  }
+
+  MatrixPtr dst;
+  Matrix::resizeOrCreate(dst, height, width, false, useGpu_);
 
   // copy one layer output from one trainer thread at each time
+  size_t startRow = 0;
   for (auto& thread : threads_) {
-    auto nn = dynamic_cast<NeuralNetwork*>(thread->getGradientMachine());
-    auto mat = nn->getLayerOutput(layerName);
-    copyLayerOutput(mats, mat, startRow, useGpu_);
-    startRow += mat->getHeight();
+    auto src = thread->getGradientMachine()->getLayerOutput(layerName);
+    auto tmpMatrix = dst->subMatrix(startRow, src->getHeight());
+    tmpMatrix->copyFrom(*src, HPPL_STREAM_DEFAULT);
+    startRow += src->getHeight();
   }
 
   if (useGpu_) {
     hl_stream_synchronize(HPPL_STREAM_DEFAULT);
   }
 
-  return mats;
+  return dst;
 }
 
 void MultiGradientMachine::backwardImp(const UpdateCallback& callback) {
```
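The new version makes two passes: it first sums each thread's output height (checking that all widths agree), then allocates the destination once with Matrix::resizeOrCreate and copies every thread's rows into the matching sub-matrix. Summing the actual heights, rather than multiplying one thread's height by the thread count as before, stays correct even when the batch does not split evenly across threads. Below is a minimal sketch of that row-stacking pattern; Mat and stackRows are hypothetical stand-ins for Paddle's MatrixPtr/subMatrix/copyFrom, not the real API.

```cpp
// Minimal sketch of the row-stacking idea behind the new getLayerOutput.
// Mat and stackRows are hypothetical stand-ins for Paddle's MatrixPtr,
// subMatrix, and copyFrom; the real code also handles the GPU stream.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct Mat {
  size_t height, width;
  std::vector<float> data;  // row-major, height * width elements
};

// Stack the rows of several equally wide matrices into one matrix.
Mat stackRows(const std::vector<Mat>& parts) {
  // Pass 1: total height; all widths must agree (CHECK_EQ in the real code).
  size_t height = 0;
  size_t width = parts.front().width;
  for (const auto& p : parts) {
    assert(p.width == width);
    height += p.height;
  }

  // Allocate the destination once (Matrix::resizeOrCreate in the real code).
  Mat dst{height, width, std::vector<float>(height * width)};

  // Pass 2: copy each part into its row range (subMatrix + copyFrom).
  size_t startRow = 0;
  for (const auto& p : parts) {
    std::copy(p.data.begin(), p.data.end(),
              dst.data.begin() + startRow * width);
    startRow += p.height;
  }
  return dst;
}

int main() {
  Mat a{1, 2, {1, 2}};
  Mat b{2, 2, {3, 4, 5, 6}};
  Mat out = stackRows({a, b});
  assert(out.height == 3 && out.data[4] == 5);  // rows of b follow row of a
  return 0;
}
```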
```diff
@@ -42,7 +42,7 @@ void CosSimLayer::forward(PassType passType) {
   /* malloc memory for the output_ if necessary */
   int batchSize = getInputValue(0)->getHeight();
   int size = getSize();
-  CHECK_EQ(forward_.size(), 1) << "Only one forward function needed";
+  CHECK_EQ(forward_.size(), 1UL) << "Only one forward function needed";
 
   {
     REGISTER_TIMER_INFO("CosFwResetTimer", getName().c_str());
```
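The added UL suffix (here and in the test changes below) fixes a signed/unsigned mismatch: forward_.size() returns size_t, and CHECK_EQ forwards both operands into a template comparison helper where a plain int literal meets an unsigned type, so -Wsign-compare fires and -Werror turns it into a build failure. A self-contained sketch of the effect, with checkEq as a hypothetical stand-in for that helper:

```cpp
// Why the UL suffix matters: CHECK_EQ/ASSERT_EQ pass both operands through
// a template comparison, where a plain int literal meets a size_t and
// -Wsign-compare fires (a hard error under -Werror). Try compiling with
// g++ -Wsign-compare -Werror to see the difference.
#include <cstddef>
#include <vector>

template <typename T1, typename T2>
bool checkEq(const T1& a, const T2& b) {
  return a == b;  // warns here when T1 is size_t and T2 is int
}

int main() {
  std::vector<int> forward(1);
  // checkEq(forward.size(), 1);                // size_t vs int: warning
  return checkEq(forward.size(), 1UL) ? 0 : 1;  // both unsigned: clean
}
```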
```diff
@@ -17,10 +17,10 @@ limitations under the License. */
 
 TEST(RowBuffer, testAutoGrow) {
   paddle::RowBuffer buf(128);
-  ASSERT_EQ(128, buf.getWidth());
+  ASSERT_EQ(128UL, buf.getWidth());
   ASSERT_TRUE(buf.isAutoGrowth());
   buf.resize(2);
-  ASSERT_EQ(2, buf.getRowCount());
+  ASSERT_EQ(2UL, buf.getRowCount());
   for (size_t i = 0; i < buf.getWidth() * 2; ++i) {
     buf.data()[i] = i;
   }
@@ -35,7 +35,7 @@ TEST(RowBuffer, testAutoGrow) {
     data[i] = i;
   }
-  ASSERT_EQ(3, buf.getRowCount());
+  ASSERT_EQ(3UL, buf.getRowCount());
   for (size_t i = 0; i < buf.getRowCount() - 1; ++i) {
     for (size_t j = 0; j < buf.getWidth(); ++j) {
       ASSERT_NEAR(i * buf.getWidth() + j, buf.get(i)[j], 1e-5);
@@ -51,7 +51,7 @@ TEST(RowBuffer, testWithMemBuf) {
       std::make_shared<paddle::CpuMemoryHandle>(128 * 2 * sizeof(real));
   paddle::RowBuffer buf(mem, 128);
   ASSERT_TRUE(!buf.isAutoGrowth());
-  ASSERT_EQ(2, buf.getRowCount());
+  ASSERT_EQ(2UL, buf.getRowCount());
   for (size_t i = 0; i < buf.getWidth() * 2; ++i) {
     buf.data()[i] = i;
   }
```
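For context, these assertions exercise paddle::RowBuffer's auto-growth: accessing a row past the current end is expected to extend the storage, so getRowCount() reflects the highest row touched. A minimal sketch of that behaviour under those assumptions, with TinyRowBuffer as a hypothetical stand-in rather than the real implementation:

```cpp
// A simplified stand-in for the auto-growth behaviour the test relies on:
// a row-major buffer of fixed width whose row count grows on demand.
#include <cassert>
#include <cstddef>
#include <vector>

class TinyRowBuffer {
 public:
  explicit TinyRowBuffer(size_t width) : width_(width) {}

  // Accessing row i grows the underlying storage if needed.
  float* get(size_t i) {
    if ((i + 1) * width_ > data_.size()) data_.resize((i + 1) * width_);
    return data_.data() + i * width_;
  }

  size_t getWidth() const { return width_; }
  size_t getRowCount() const { return data_.size() / width_; }

 private:
  size_t width_;
  std::vector<float> data_;
};

int main() {
  TinyRowBuffer buf(128);
  assert(buf.getWidth() == 128UL);
  buf.get(2)[0] = 1.0f;              // touching row 2 grows the buffer
  assert(buf.getRowCount() == 3UL);  // rows 0..2 now exist
  return 0;
}
```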