提交 8c40bfd0 编写于 作者: L Liang Zhao

Make matrix well-conditioned when unittest inverse

上级 0c7ac3d9
...@@ -202,6 +202,17 @@ void GpuMatrix::resetOne() { ...@@ -202,6 +202,17 @@ void GpuMatrix::resetOne() {
CHECK(data_ != NULL); CHECK(data_ != NULL);
one(); one();
} }
void GpuMatrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
for (size_t i = 0; i < height_; i++) {
hl_memcpy_host2device(&data_[i * stride_ + i], &value, sizeof(real));
}
}
void GpuMatrix::resize(size_t newHeight, size_t newWidth) { void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
size_t newSize = newHeight * newWidth; size_t newSize = newHeight * newWidth;
if (NULL == memoryHandle_.get() || if (NULL == memoryHandle_.get() ||
...@@ -1244,6 +1255,16 @@ void CpuMatrix::resetOne() { ...@@ -1244,6 +1255,16 @@ void CpuMatrix::resetOne() {
BaseMatrix::one(); BaseMatrix::one();
} }
void CpuMatrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
for (size_t i = 0; i < height_; i++) {
data_[i * stride_ + i] = value;
}
}
void CpuMatrix::copyFrom(const Matrix& src) { void CpuMatrix::copyFrom(const Matrix& src) {
CHECK(isContiguous()); CHECK(isContiguous());
if (typeid(src) == typeid(GpuMatrix)) { if (typeid(src) == typeid(GpuMatrix)) {
......
...@@ -195,6 +195,8 @@ public: ...@@ -195,6 +195,8 @@ public:
virtual void resetOne() { LOG(FATAL) << "Not implemented"; } virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
virtual void setDiag(real value) { LOG(FATAL) << "Not implemented"; }
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; } virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
virtual void trimFrom(const CpuSparseMatrix& src) { virtual void trimFrom(const CpuSparseMatrix& src) {
...@@ -330,6 +332,7 @@ public: ...@@ -330,6 +332,7 @@ public:
virtual MatrixPtr getInverse() { virtual MatrixPtr getInverse() {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
return nullptr;
} }
/** /**
...@@ -1016,6 +1019,7 @@ public: ...@@ -1016,6 +1019,7 @@ public:
void zeroMem(); void zeroMem();
void resetOne(); void resetOne();
void setDiag(real value);
void resize(size_t newHeight, size_t newWidth); void resize(size_t newHeight, size_t newWidth);
void resize(size_t newHeight, size_t newWidth, void resize(size_t newHeight, size_t newWidth,
...@@ -1280,6 +1284,8 @@ public: ...@@ -1280,6 +1284,8 @@ public:
void zeroMem(); void zeroMem();
void resetOne(); void resetOne();
void setDiag(real value);
void resize(size_t newHeight, size_t newWidth); void resize(size_t newHeight, size_t newWidth);
void resize(size_t newHeight, size_t newWidth, void resize(size_t newHeight, size_t newWidth,
size_t newNnz, /* used to allocate space */ size_t newNnz, /* used to allocate space */
......
...@@ -647,20 +647,23 @@ void testMatrixInverse(int height) { ...@@ -647,20 +647,23 @@ void testMatrixInverse(int height) {
MatrixPtr cpuI = std::make_shared<CpuMatrix>(height, height); MatrixPtr cpuI = std::make_shared<CpuMatrix>(height, height);
MatrixPtr gpuI = std::make_shared<GpuMatrix>(height, height); MatrixPtr gpuI = std::make_shared<GpuMatrix>(height, height);
/* Make matrix well conditioned: cpu * cpuT + Identity */
cpu->randomizeUniform(); cpu->randomizeUniform();
MatrixPtr cpuT = cpu->getTranspose();
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
outputCheck->mul(cpu, cpuT);
cpu->setDiag(1.0);
cpu->add(*outputCheck);
gpu->copyFrom(*cpu); gpu->copyFrom(*cpu);
cpu->inverse(cpuI, false); cpu->inverse(cpuI, false);
gpu->inverse(gpuI, false); gpu->inverse(gpuI, false);
MatrixPtr outputCheck = std::make_shared<CpuMatrix>(height, height);
outputCheck->copyFrom(*gpuI); outputCheck->copyFrom(*gpuI);
MatrixCheckErr(*cpuI, *outputCheck); MatrixCheckErr(*cpuI, *outputCheck);
outputCheck->mul(cpu, cpuI); outputCheck->mul(cpu, cpuI);
cpu->zeroMem(); cpu->setDiag(1.0);
for (int i = 0; i < height; i++) {
cpu->getRowBuf(i)[i] = 1.0;
}
MatrixCheckErr(*cpu, *outputCheck); MatrixCheckErr(*cpu, *outputCheck);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册