Commit 6fed6f20 authored by wangmeng28

Add support for sparse_binary_vector as input for the fm layer

Parent 5392a503
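This commit lets the factorization machine layer accept sparse_binary_vector input, whose CSR storage keeps column indices but no value array (every stored entry is implicitly 1). Two changes make that work: the temporary buffer in the backward pass becomes the member tmpInput_, allocated sparse or dense to match the input, and CpuSparseMatrix::rowScale gains a branch for NO_VALUE matrices.

For reference, the backward pass below implements the usual factorization machine gradient of the latent factor matrix (standard FM notation after Rendle 2010; the symbols are not from this page):

\[
\frac{\partial \hat{y}}{\partial v_{i,f}} \;=\; x_i \sum_{j} v_{j,f}\, x_j \;-\; v_{i,f}\, x_i^2
\]

Reading the code, the first term is the mul() of the transposed, gradient-scaled input with inputMulFactor_ (which appears to hold the product of the input and the latent vectors), and the second term is accumulated through the rowScale on inputSquare_ followed by the column sum into tmpSum_ and the final addRowScale.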
...
@@ -96,15 +96,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
   /* Calculate the gradients of the latentVectors_ matrix */
   if (latentVectors_->getWGrad()) {
-    MatrixPtr tmpInput = inputV->clone(0, 0, useGpu_);
     if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
+      Matrix::resizeOrCreateSparseMatrix(tmpInput_,
+                                         inputV->getHeight(),
+                                         inputV->getWidth(),
+                                         inputV->getElementCnt());
       CpuSparseMatrix* sparseInputV =
           dynamic_cast<CpuSparseMatrix*>(inputV.get());
       CpuSparseMatrix* sparseInputSquare =
           dynamic_cast<CpuSparseMatrix*>(inputSquare_.get());
       CpuSparseMatrix* sparseTmpInput =
-          dynamic_cast<CpuSparseMatrix*>(tmpInput.get());
+          dynamic_cast<CpuSparseMatrix*>(tmpInput_.get());
       sparseTmpInput->copyFrom(*sparseInputV);
       sparseTmpInput->rowScale(0, *sparseInputV, *oGrad);
       latentVectors_->getWGrad()->mul(
           *sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1);
...
@@ -115,12 +120,15 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
       negOnes_->add(-1);
       tmpSum_->mul(*negOnes_, *sparseTmpInput, 1, 0);
     } else {
-      tmpInput->rowScale(0, *inputV, *oGrad);
+      Matrix::resizeOrCreate(
+          tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
+      tmpInput_->rowScale(0, *inputV, *oGrad);
       latentVectors_->getWGrad()->mul(
-          *tmpInput->getTranspose(), *inputMulFactor_, 1, 1);
-      tmpInput->rowScale(0, *inputSquare_, *oGrad);
-      tmpSum_->sumCols(*tmpInput, -1, 0);
+          *tmpInput_->getTranspose(), *inputMulFactor_, 1, 1);
+      tmpInput_->rowScale(0, *inputSquare_, *oGrad);
+      tmpSum_->sumCols(*tmpInput_, -1, 0);
     }
     latentVectors_->getWGrad()->addRowScale(
...
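Both branches above depend on rowScale(0, b, c): row i of the destination receives row i of b scaled by c.getElement(i, 0), here the output gradient of sample i. Both branches also end in the same reduction, tmpSum_(0, j) = -sum_i tmpInput_(i, j), computed as negOnes_ times tmpInput_ in the sparse branch and via sumCols(..., -1, 0) in the dense branch. A minimal dense analogue of the row scaling, with illustrative names and a flat std::vector layout rather than Paddle's Matrix API:

```cpp
#include <cstddef>
#include <vector>

// Dense analogue of the rowScale(cCol, b, c) calls above (a sketch, not
// Paddle's Matrix API): a(i, j) = b(i, j) * c(i, cCol) for every row i.
// Here c holds the per-sample output gradient, so each input row is
// scaled by its sample's gradient. Names and layout are illustrative.
void denseRowScale(std::vector<float>& a,
                   const std::vector<float>& b,
                   const std::vector<float>& c,
                   size_t height, size_t width,
                   size_t cWidth, size_t cCol) {
  for (size_t i = 0; i < height; ++i) {
    const float scale = c[i * cWidth + cCol];  // e.g. oGrad for sample i
    for (size_t j = 0; j < width; ++j) {
      a[i * width + j] = b[i * width + j] * scale;
    }
  }
}
```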
...
@@ -61,6 +61,7 @@ private:
   // Store temporary calculation result
   MatrixPtr tmpOut_;
   MatrixPtr tmpSum_;
+  MatrixPtr tmpInput_;
   // Negative identity matrix
   MatrixPtr negOnes_;
...
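The new tmpInput_ member replaces the old per-iteration MatrixPtr tmpInput = inputV->clone(0, 0, useGpu_), so the buffer survives across backward passes. A minimal sketch of that resize-or-create pattern, under the assumption that Paddle's resizeOrCreate helpers allocate when the pointer is null and otherwise reshape the existing buffer (DenseBuf is an illustrative stand-in, not Paddle's Matrix):

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Illustrative stand-in for a matrix type; not Paddle's Matrix.
struct DenseBuf {
  size_t h = 0, w = 0;
  std::vector<float> data;
  DenseBuf(size_t h_, size_t w_) : h(h_), w(w_), data(h_ * w_) {}
  void resize(size_t h_, size_t w_) { h = h_; w = w_; data.resize(h_ * w_); }
};

// Sketch of the pattern that makes tmpInput_ a reusable member: allocate
// on the first backward pass, then reuse (and reshape if the batch size
// changed) on later passes instead of cloning inputV every iteration.
void resizeOrCreate(std::shared_ptr<DenseBuf>& m, size_t h, size_t w) {
  if (!m) {
    m = std::make_shared<DenseBuf>(h, w);  // first call: allocate
  } else {
    m->resize(h, w);                       // later calls: reuse storage
  }
}
```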
...
@@ -266,13 +266,25 @@ void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) {
   CHECK_EQ(width_, b.getWidth());
   real* A = getValue();
   real* B = b.getValue();
-  for (size_t i = 0; i < height_; i++) {
-    size_t start = getRowStartIdx(i);
-    size_t end = getRowStartIdx(i + 1);
-    CHECK_EQ(start, b.getRowStartIdx(i));
-    CHECK_EQ(end, b.getRowStartIdx(i + 1));
-    for (size_t j = start; j < end; j++) {
-      A[j] = B[j] * c.getElement(i, cCol);
-    }
-  }
+  if (b.getValueType() == FLOAT_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = B[j] * c.getElement(i, cCol);
+      }
+    }
+  } else if (b.getValueType() == NO_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = c.getElement(i, cCol);
+      }
+    }
+  }
 }
...
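A sparse_binary_vector arrives as a NO_VALUE CpuSparseMatrix: CSR row offsets and column indices with no value array, every stored entry implicitly 1. That is why the new branch writes c.getElement(i, cCol) directly instead of multiplying by B[j]. A self-contained sketch of the same logic (types and names are illustrative, not Paddle's):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative CSR storage for a sparse *binary* matrix: row offsets and
// column indices only -- every stored entry is implicitly 1.
struct CsrBinary {
  std::vector<size_t> rowStart;  // height + 1 offsets into cols
  std::vector<size_t> cols;      // column index of each non-zero
};

// Analogue of the NO_VALUE branch above: scaling row i by factor[i]
// writes the factor itself, since 1 * factor[i] == factor[i].
void rowScaleBinary(const CsrBinary& in,
                    const std::vector<float>& factor,
                    std::vector<float>& outValues) {
  const size_t height = in.rowStart.size() - 1;
  assert(factor.size() == height);
  outValues.assign(in.cols.size(), 0.0f);
  for (size_t i = 0; i < height; ++i) {
    for (size_t j = in.rowStart[i]; j < in.rowStart[i + 1]; ++j) {
      outValues[j] = factor[i];
    }
  }
}

int main() {
  // 2 x 4 binary matrix: row 0 has non-zeros at columns {1, 3},
  // row 1 at column {0}.
  CsrBinary x{{0, 2, 3}, {1, 3, 0}};
  std::vector<float> oGrad = {0.5f, -2.0f};  // per-sample output gradient
  std::vector<float> scaled;
  rowScaleBinary(x, oGrad, scaled);
  for (float v : scaled) std::printf("%g ", v);  // prints: 0.5 0.5 -2
  std::printf("\n");
}
```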