From 74a699a72ef9046a7f302e339c8e20a8152ae9d8 Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Mon, 20 Nov 2017 22:14:24 +0800 Subject: [PATCH] change clone to resizeOrCreate in fm layer --- .../gserver/layers/FactorizationMachineLayer.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp index b665fb6dfc4..be26b9ba88c 100644 --- a/paddle/gserver/layers/FactorizationMachineLayer.cpp +++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp @@ -58,16 +58,22 @@ void FactorizationMachineLayer::forward(PassType passType) { inputMulFactor_, batchSize, factorSize_, false, useGpu_); Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_); - REGISTER_TIMER_INFO("InputMulFactorTimer", getName().c_str()); + REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str()); inputMulFactor_->mul(*inputV, *latentVectors_->getW()); inputMulFactor_->square2(*tmpOut_); outV->sumRows(*tmpOut_, 0.5, 0); - inputSquare_ = inputV->clone(0, 0, useGpu_); - if (dynamic_cast(inputSquare_.get())) { + if (dynamic_cast(inputV.get())) { + Matrix::resizeOrCreateSparseMatrix(inputSquare_, + inputV->getHeight(), + inputV->getWidth(), + inputV->getElementCnt(), + inputV->getValueType()); inputSquare_->copyFrom(*inputV); (dynamic_cast(inputSquare_.get()))->square2(); } else { + Matrix::resizeOrCreate( + inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_); inputV->square2(*inputSquare_); } latentVectors_->getW()->square2(*latentVectorsSquare_); @@ -75,7 +81,7 @@ void FactorizationMachineLayer::forward(PassType passType) { outV->sumRows(*tmpOut_, -0.5, 1.0); /* activation */ { - REGISTER_TIMER_INFO("FmAtvTimer", getName().c_str()); + REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str()); forwardActivation(); } } -- GitLab