From 992ac8f9a1c54080bc273f7748510e2b85c7f8cc Mon Sep 17 00:00:00 2001 From: Liang Zhao Date: Tue, 8 Nov 2016 10:36:22 -0800 Subject: [PATCH] Implement setDiag() with BaseMatrix::assign() --- paddle/math/Matrix.cpp | 29 +++++++++-------------------- paddle/math/Matrix.h | 2 +- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index f0f5ebe3b..a5b0d9595 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol, trans_, useGpu_); } +void Matrix::setDiag(real value) { + CHECK(data_ != NULL); + CHECK_EQ(height_, width_); + + zeroMem(); + BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_); + diag.assign(value); +} + GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans) : Matrix(std::make_shared(height * width * sizeof(real)), height, width, trans, true) {} @@ -203,16 +212,6 @@ void GpuMatrix::resetOne() { one(); } -void GpuMatrix::setDiag(real value) { - CHECK(data_ != NULL); - CHECK_EQ(height_, width_); - - zeroMem(); - for (size_t i = 0; i < height_; i++) { - hl_memcpy_host2device(&data_[i * stride_ + i], &value, sizeof(real)); - } -} - void GpuMatrix::resize(size_t newHeight, size_t newWidth) { size_t newSize = newHeight * newWidth; if (NULL == memoryHandle_.get() || @@ -1255,16 +1254,6 @@ void CpuMatrix::resetOne() { BaseMatrix::one(); } -void CpuMatrix::setDiag(real value) { - CHECK(data_ != NULL); - CHECK_EQ(height_, width_); - - zeroMem(); - for (size_t i = 0; i < height_; i++) { - data_[i * stride_ + i] = value; - } -} - void CpuMatrix::copyFrom(const Matrix& src) { CHECK(isContiguous()); if (typeid(src) == typeid(GpuMatrix)) { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 9e15055c0..120957f45 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -195,7 +195,7 @@ public: virtual void resetOne() { LOG(FATAL) << "Not implemented"; } - virtual void setDiag(real value) { LOG(FATAL) << "Not implemented"; } + void setDiag(real value); virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; } -- GitLab