提交 6fed6f20 编写于 作者: W wangmeng28

Add support of sparse_binary_vector as input for fm layer

上级 5392a503
......@@ -96,15 +96,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
/* Calculate the gradients of the latentVectors_ matrix */
if (latentVectors_->getWGrad()) {
MatrixPtr tmpInput = inputV->clone(0, 0, useGpu_);
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
Matrix::resizeOrCreateSparseMatrix(tmpInput_,
inputV->getHeight(),
inputV->getWidth(),
inputV->getElementCnt());
CpuSparseMatrix* sparseInputV =
dynamic_cast<CpuSparseMatrix*>(inputV.get());
CpuSparseMatrix* sparseInputSquare =
dynamic_cast<CpuSparseMatrix*>(inputSquare_.get());
CpuSparseMatrix* sparseTmpInput =
dynamic_cast<CpuSparseMatrix*>(tmpInput.get());
dynamic_cast<CpuSparseMatrix*>(tmpInput_.get());
sparseTmpInput->copyFrom(*sparseInputV);
sparseTmpInput->rowScale(0, *sparseInputV, *oGrad);
latentVectors_->getWGrad()->mul(
*sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1);
......@@ -115,12 +120,15 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
negOnes_->add(-1);
tmpSum_->mul(*negOnes_, *sparseTmpInput, 1, 0);
} else {
tmpInput->rowScale(0, *inputV, *oGrad);
Matrix::resizeOrCreate(
tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
tmpInput_->rowScale(0, *inputV, *oGrad);
latentVectors_->getWGrad()->mul(
*tmpInput->getTranspose(), *inputMulFactor_, 1, 1);
tmpInput->rowScale(0, *inputSquare_, *oGrad);
*tmpInput_->getTranspose(), *inputMulFactor_, 1, 1);
tmpInput_->rowScale(0, *inputSquare_, *oGrad);
tmpSum_->sumCols(*tmpInput, -1, 0);
tmpSum_->sumCols(*tmpInput_, -1, 0);
}
latentVectors_->getWGrad()->addRowScale(
......
......@@ -61,6 +61,7 @@ private:
// Store temporary calculation result
MatrixPtr tmpOut_;
MatrixPtr tmpSum_;
MatrixPtr tmpInput_;
// Negative identity matrix
MatrixPtr negOnes_;
......
......@@ -266,13 +266,25 @@ void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) {
CHECK_EQ(width_, b.getWidth());
real* A = getValue();
real* B = b.getValue();
for (size_t i = 0; i < height_; i++) {
size_t start = getRowStartIdx(i);
size_t end = getRowStartIdx(i + 1);
CHECK_EQ(start, b.getRowStartIdx(i));
CHECK_EQ(end, b.getRowStartIdx(i + 1));
for (size_t j = start; j < end; j++) {
A[j] = B[j] * c.getElement(i, cCol);
if (b.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < height_; i++) {
size_t start = getRowStartIdx(i);
size_t end = getRowStartIdx(i + 1);
CHECK_EQ(start, b.getRowStartIdx(i));
CHECK_EQ(end, b.getRowStartIdx(i + 1));
for (size_t j = start; j < end; j++) {
A[j] = B[j] * c.getElement(i, cCol);
}
}
} else if (b.getValueType() == NO_VALUE) {
for (size_t i = 0; i < height_; i++) {
size_t start = getRowStartIdx(i);
size_t end = getRowStartIdx(i + 1);
CHECK_EQ(start, b.getRowStartIdx(i));
CHECK_EQ(end, b.getRowStartIdx(i + 1));
for (size_t j = start; j < end; j++) {
A[j] = c.getElement(i, cCol);
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册