diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index f0f5ebe3bd05ca63c7d55f7576ba192111832956..a5b0d959536fa85d46c9cc0b4783027c471da895 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol, trans_, useGpu_); } +void Matrix::setDiag(real value) { + CHECK(data_ != NULL); + CHECK_EQ(height_, width_); + + zeroMem(); + BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_); + diag.assign(value); +} + GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans) : Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)), height, width, trans, true) {} @@ -203,16 +212,6 @@ void GpuMatrix::resetOne() { one(); } -void GpuMatrix::setDiag(real value) { - CHECK(data_ != NULL); - CHECK_EQ(height_, width_); - - zeroMem(); - for (size_t i = 0; i < height_; i++) { - hl_memcpy_host2device(&data_[i * stride_ + i], &value, sizeof(real)); - } -} - void GpuMatrix::resize(size_t newHeight, size_t newWidth) { size_t newSize = newHeight * newWidth; if (NULL == memoryHandle_.get() || @@ -1255,16 +1254,6 @@ void CpuMatrix::resetOne() { BaseMatrix::one(); } -void CpuMatrix::setDiag(real value) { - CHECK(data_ != NULL); - CHECK_EQ(height_, width_); - - zeroMem(); - for (size_t i = 0; i < height_; i++) { - data_[i * stride_ + i] = value; - } -} - void CpuMatrix::copyFrom(const Matrix& src) { CHECK(isContiguous()); if (typeid(src) == typeid(GpuMatrix)) { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 9e15055c056a1e36f4553657a9345e2e36995cf4..120957f45d0c93656a3f9e87ed59410513632e25 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -195,7 +195,7 @@ public: virtual void resetOne() { LOG(FATAL) << "Not implemented"; } - virtual void setDiag(real value) { LOG(FATAL) << "Not implemented"; } + void setDiag(real value); virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }