提交 5ccf84ab 编写于 作者: E emailweixu 提交者: GitHub

Merge pull request #383 from lzhao4ever/fix_matrix_inverse

Fix matrix inverse  unittest to be more robust
...@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol, ...@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol,
trans_, useGpu_); trans_, useGpu_);
} }
void Matrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_);
diag.assign(value);
}
GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans) GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans)
: Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)), : Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)),
height, width, trans, true) {} height, width, trans, true) {}
...@@ -202,6 +211,7 @@ void GpuMatrix::resetOne() { ...@@ -202,6 +211,7 @@ void GpuMatrix::resetOne() {
CHECK(data_ != NULL); CHECK(data_ != NULL);
one(); one();
} }
void GpuMatrix::resize(size_t newHeight, size_t newWidth) { void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
size_t newSize = newHeight * newWidth; size_t newSize = newHeight * newWidth;
if (NULL == memoryHandle_.get() || if (NULL == memoryHandle_.get() ||
......
...@@ -195,6 +195,8 @@ public: ...@@ -195,6 +195,8 @@ public:
virtual void resetOne() { LOG(FATAL) << "Not implemented"; } virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
void setDiag(real value);
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; } virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
virtual void trimFrom(const CpuSparseMatrix& src) { virtual void trimFrom(const CpuSparseMatrix& src) {
...@@ -330,6 +332,7 @@ public: ...@@ -330,6 +332,7 @@ public:
virtual MatrixPtr getInverse() { virtual MatrixPtr getInverse() {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
return nullptr;
} }
/** /**
...@@ -1016,6 +1019,7 @@ public: ...@@ -1016,6 +1019,7 @@ public:
void zeroMem(); void zeroMem();
void resetOne(); void resetOne();
void setDiag(real value);
void resize(size_t newHeight, size_t newWidth); void resize(size_t newHeight, size_t newWidth);
void resize(size_t newHeight, size_t newWidth, void resize(size_t newHeight, size_t newWidth,
...@@ -1280,6 +1284,8 @@ public: ...@@ -1280,6 +1284,8 @@ public:
void zeroMem(); void zeroMem();
void resetOne(); void resetOne();
void setDiag(real value);
void resize(size_t newHeight, size_t newWidth); void resize(size_t newHeight, size_t newWidth);
void resize(size_t newHeight, size_t newWidth, void resize(size_t newHeight, size_t newWidth,
size_t newNnz, /* used to allocate space */ size_t newNnz, /* used to allocate space */
......
...@@ -647,20 +647,23 @@ void testMatrixInverse(int height) { ...@@ -647,20 +647,23 @@ void testMatrixInverse(int height) {
MatrixPtr cpuI = std::make_shared<CpuMatrix>(height, height); MatrixPtr cpuI = std::make_shared<CpuMatrix>(height, height);
MatrixPtr gpuI = std::make_shared<GpuMatrix>(height, height); MatrixPtr gpuI = std::make_shared<GpuMatrix>(height, height);
/* Make matrix well conditioned: cpu * cpuT + Identity */
cpu->randomizeUniform(); cpu->randomizeUniform();
MatrixPtr cpuT = cpu->getTranspose();
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
outputCheck->mul(cpu, cpuT);
cpu->setDiag(1.0);
cpu->add(*outputCheck);
gpu->copyFrom(*cpu); gpu->copyFrom(*cpu);
cpu->inverse(cpuI, false); cpu->inverse(cpuI, false);
gpu->inverse(gpuI, false); gpu->inverse(gpuI, false);
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
outputCheck->copyFrom(*gpuI); outputCheck->copyFrom(*gpuI);
MatrixCheckErr(*cpuI, *outputCheck); MatrixCheckErr(*cpuI, *outputCheck);
outputCheck->mul(cpu, cpuI); outputCheck->mul(cpu, cpuI);
cpu->zeroMem(); cpu->setDiag(1.0);
for (int i = 0; i < height; i++) {
cpu->getRowBuf(i)[i] = 1.0;
}
MatrixCheckErr(*cpu, *outputCheck); MatrixCheckErr(*cpu, *outputCheck);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册