提交 992ac8f9 编写于 作者: L Liang Zhao

Implement setDiag() with BaseMatrix::assign()

上级 8c40bfd0
...@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol, ...@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol,
trans_, useGpu_); trans_, useGpu_);
} }
void Matrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_);
diag.assign(value);
}
GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans) GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans)
: Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)), : Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)),
height, width, trans, true) {} height, width, trans, true) {}
...@@ -203,16 +212,6 @@ void GpuMatrix::resetOne() { ...@@ -203,16 +212,6 @@ void GpuMatrix::resetOne() {
one(); one();
} }
void GpuMatrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
for (size_t i = 0; i < height_; i++) {
hl_memcpy_host2device(&data_[i * stride_ + i], &value, sizeof(real));
}
}
void GpuMatrix::resize(size_t newHeight, size_t newWidth) { void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
size_t newSize = newHeight * newWidth; size_t newSize = newHeight * newWidth;
if (NULL == memoryHandle_.get() || if (NULL == memoryHandle_.get() ||
...@@ -1255,16 +1254,6 @@ void CpuMatrix::resetOne() { ...@@ -1255,16 +1254,6 @@ void CpuMatrix::resetOne() {
BaseMatrix::one(); BaseMatrix::one();
} }
void CpuMatrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
for (size_t i = 0; i < height_; i++) {
data_[i * stride_ + i] = value;
}
}
void CpuMatrix::copyFrom(const Matrix& src) { void CpuMatrix::copyFrom(const Matrix& src) {
CHECK(isContiguous()); CHECK(isContiguous());
if (typeid(src) == typeid(GpuMatrix)) { if (typeid(src) == typeid(GpuMatrix)) {
......
...@@ -195,7 +195,7 @@ public: ...@@ -195,7 +195,7 @@ public:
virtual void resetOne() { LOG(FATAL) << "Not implemented"; } virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
virtual void setDiag(real value) { LOG(FATAL) << "Not implemented"; } void setDiag(real value);
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; } virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册