diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 4fc9b2d0893665c3d478fbec55f810c5bc99e236..a5b0d959536fa85d46c9cc0b4783027c471da895 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol, trans_, useGpu_); } +void Matrix::setDiag(real value) { + CHECK(data_ != NULL); + CHECK_EQ(height_, width_); + + zeroMem(); + BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_); + diag.assign(value); +} + GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans) : Matrix(std::make_shared(height * width * sizeof(real)), height, width, trans, true) {} @@ -202,6 +211,7 @@ void GpuMatrix::resetOne() { CHECK(data_ != NULL); one(); } + void GpuMatrix::resize(size_t newHeight, size_t newWidth) { size_t newSize = newHeight * newWidth; if (NULL == memoryHandle_.get() || diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 293d13f4d6d5af0883ea76fb64ca5d9173efd4e0..120957f45d0c93656a3f9e87ed59410513632e25 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -195,6 +195,8 @@ public: virtual void resetOne() { LOG(FATAL) << "Not implemented"; } + void setDiag(real value); + virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; } virtual void trimFrom(const CpuSparseMatrix& src) { @@ -330,6 +332,7 @@ public: virtual MatrixPtr getInverse() { LOG(FATAL) << "Not implemented"; + return nullptr; } /** @@ -1016,6 +1019,7 @@ public: void zeroMem(); void resetOne(); + void setDiag(real value); void resize(size_t newHeight, size_t newWidth); void resize(size_t newHeight, size_t newWidth, @@ -1280,6 +1284,8 @@ public: void zeroMem(); void resetOne(); + void setDiag(real value); + void resize(size_t newHeight, size_t newWidth); void resize(size_t newHeight, size_t newWidth, size_t newNnz, /* used to allocate space */ diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index b887cccaaa14e6c3761d151f31a859de66cf8fac..91a68006325f65516bc67a29c2ad4f0762b8f96d 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -647,20 +647,23 @@ void testMatrixInverse(int height) { MatrixPtr cpuI = std::make_shared(height, height); MatrixPtr gpuI = std::make_shared(height, height); + /* Make matrix well conditioned: cpu * cpuT + Identity */ cpu->randomizeUniform(); + MatrixPtr cpuT = cpu->getTranspose(); + MatrixPtr outputCheck = std::make_shared(height, height); + outputCheck->mul(cpu, cpuT); + cpu->setDiag(1.0); + cpu->add(*outputCheck); + gpu->copyFrom(*cpu); cpu->inverse(cpuI, false); gpu->inverse(gpuI, false); - MatrixPtr outputCheck = std::make_shared(height, height); outputCheck->copyFrom(*gpuI); MatrixCheckErr(*cpuI, *outputCheck); outputCheck->mul(cpu, cpuI); - cpu->zeroMem(); - for (int i = 0; i < height; i++) { - cpu->getRowBuf(i)[i] = 1.0; - } + cpu->setDiag(1.0); MatrixCheckErr(*cpu, *outputCheck); }