提交 992ac8f9 编写于 作者: L Liang Zhao

Implement setDiag() with BaseMatrix::assign()

上级 8c40bfd0
......@@ -187,6 +187,15 @@ MatrixPtr Matrix::subMatrix(size_t startRow, size_t endRow, size_t startCol,
trans_, useGpu_);
}
void Matrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
BaseMatrix diag(height_, 1, stride_ + 1, data_, false, useGpu_);
diag.assign(value);
}
GpuMatrix::GpuMatrix(size_t height, size_t width, bool trans)
: Matrix(std::make_shared<GpuMemoryHandle>(height * width * sizeof(real)),
height, width, trans, true) {}
......@@ -203,16 +212,6 @@ void GpuMatrix::resetOne() {
one();
}
void GpuMatrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
for (size_t i = 0; i < height_; i++) {
hl_memcpy_host2device(&data_[i * stride_ + i], &value, sizeof(real));
}
}
void GpuMatrix::resize(size_t newHeight, size_t newWidth) {
size_t newSize = newHeight * newWidth;
if (NULL == memoryHandle_.get() ||
......@@ -1255,16 +1254,6 @@ void CpuMatrix::resetOne() {
BaseMatrix::one();
}
void CpuMatrix::setDiag(real value) {
CHECK(data_ != NULL);
CHECK_EQ(height_, width_);
zeroMem();
for (size_t i = 0; i < height_; i++) {
data_[i * stride_ + i] = value;
}
}
void CpuMatrix::copyFrom(const Matrix& src) {
CHECK(isContiguous());
if (typeid(src) == typeid(GpuMatrix)) {
......
......@@ -195,7 +195,7 @@ public:
virtual void resetOne() { LOG(FATAL) << "Not implemented"; }
virtual void setDiag(real value) { LOG(FATAL) << "Not implemented"; }
void setDiag(real value);
virtual void copyFrom(const Matrix& src) { LOG(FATAL) << "Not implemented"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册