Commit 171eaff2 authored by xutianbing

clean the code a little bit.

Parent 4751cc8f
...
@@ -38,13 +38,6 @@ inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
   }
 }
-inline void colVecAddTo(
-    real* a, const real* b, size_t len, size_t aWidth, size_t bWidth) {
-  for (unsigned int i = 0; i < len; ++i) {
-    a[i * aWidth] += b[i * bWidth];
-  }
-}
 inline void colVecAddTo(
     real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
   for (unsigned int i = 0; i < len; ++i) {
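The overload deleted here is now redundant: the surviving colVecAddTo carries an explicit scale c, and the rewritten call sites below pass c = 1.0 whenever b stores no values. A minimal standalone sketch of the equivalence — the surviving loop body is cut off by this hunk, so the `* c` term is inferred from those call sites:

    #include <cstddef>
    typedef float real;  // Paddle's `real` is assumed to default to float here

    // Surviving overload (body inferred): scaled column-vector accumulation.
    inline void colVecAddTo(
        real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
      for (unsigned int i = 0; i < len; ++i) {
        a[i * aWidth] += b[i * bWidth] * c;
      }
    }

    // The removed overload is recovered by passing c = 1.0:
    //   colVecAddTo(a, b, (real)1.0, len, aWidth, bWidth);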
...
@@ -336,140 +329,59 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                             const CpuSparseMatrix& b,
                             real scaleAB,
                             real scaleT) {
-  /// todo(tianbing), clean the code
   CHECK(!out.trans_) << "Not supported";
   CHECK(!a.isTransposed()) << "Not supported";
   CHECK(scaleT == 0 || scaleT == 1);
   CHECK_EQ(scaleAB, static_cast<real>(1.0));
+  if (!b.isTransposed()) {  /// b is not Transpose
+    CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
+          b.getWidth() == out.getWidth());
+  } else {
+    CHECK(b.getHeight() == out.getWidth() && a.getHeight() == out.getHeight() &&
+          b.getWidth() == a.getWidth());
+  }
+  if (scaleT == 0) {
+    out.zeroMem();
+  }
   real* A = const_cast<real*>(a.getData());
   real* B = const_cast<real*>(b.getValue());
   real* C = out.getData();
   int* rows = b.getRows();
   int* cols = b.getCols();
-  if (scaleT == 0) {
-    out.zeroMem();
-  }
-  /// todo(tianbing), clean the code
+  /// b.getFormat() == SPARSE_CSC
   if (b.getFormat() == SPARSE_CSC) {
-    if (!b.isTransposed()) {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), m);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), out.width_);
-
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t j = 0; j < b.getWidth(); ++j) {
-          int start = b.getColStartIdx(j);
-          int end = b.getColStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(
-                C + j, A + rows[i], out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t j = 0; j < b.getWidth(); ++j) {
-          int start = b.getColStartIdx(j);
-          int end = b.getColStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(C + j,
-                        A + rows[i],
-                        B[i],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
-      }
-    } else /*if (b.isTransposed())*/ {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), out.width_);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), m);
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t i = 0; i < b.getWidth(); ++i) {
-          int start = b.getColStartIdx(i);
-          int end = b.getColStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(
-                C + rows[j], A + i, out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t i = 0; i < b.getWidth(); ++i) {
-          int start = b.getColStartIdx(i);
-          int end = b.getColStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(C + rows[j],
-                        A + i,
-                        B[j],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
-      }
-    }
-  } else {
-    if (!b.isTransposed()) {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), m);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), out.width_);
-
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t j = 0; j < b.getHeight(); ++j) {
-          int start = b.getRowStartIdx(j);
-          int end = b.getRowStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(
-                C + cols[i], A + j, out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t j = 0; j < b.getHeight(); ++j) {
-          int start = b.getRowStartIdx(j);
-          int end = b.getRowStartIdx(j + 1);
-          for (int i = start; i < end; ++i) {
-            colVecAddTo(C + cols[i],
-                        A + j,
-                        B[i],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
-      }
-    } else /*if (b.isTransposed())*/ {
-      size_t m = a.getWidth();
-      CHECK_EQ(b.getHeight(), out.width_);
-      CHECK_EQ(a.getHeight(), out.height_);
-      CHECK_EQ(b.getWidth(), m);
-      if (b.getValueType() == NO_VALUE) {
-        for (size_t i = 0; i < b.getHeight(); ++i) {
-          int start = b.getRowStartIdx(i);
-          int end = b.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(
-                C + i, A + cols[j], out.height_, out.width_, a.getWidth());
-          }
-        }
-      } else if (b.getValueType() == FLOAT_VALUE) {
-        for (size_t i = 0; i < b.getHeight(); ++i) {
-          int start = b.getRowStartIdx(i);
-          int end = b.getRowStartIdx(i + 1);
-          for (int j = start; j < end; ++j) {
-            colVecAddTo(C + i,
-                        A + cols[j],
-                        B[j],
-                        out.height_,
-                        out.width_,
-                        a.getWidth());
-          }
-        }
-      }
-    }
+    for (size_t j = 0; j < b.getWidth(); ++j) {
+      int start = b.getColStartIdx(j);
+      int end = b.getColStartIdx(j + 1);
+      for (int i = start; i < end; ++i) {
+        colVecAddTo(!b.isTransposed() ? C + j : C + rows[i],
+                    !b.isTransposed() ? A + rows[i] : A + j,
+                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
+                    out.getHeight(),
+                    out.getWidth(),
+                    a.getWidth());
+      }
+    }
+    return;
+  }
+
+  /// b.getFormat() == SPARSE_CSR
+  if (b.getFormat() == SPARSE_CSR) {
+    for (size_t j = 0; j < b.getHeight(); ++j) {
+      int start = b.getRowStartIdx(j);
+      int end = b.getRowStartIdx(j + 1);
+      for (int i = start; i < end; ++i) {
+        colVecAddTo(!b.isTransposed() ? C + cols[i] : C + j,
+                    !b.isTransposed() ? A + j : A + cols[i],
+                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
+                    out.getHeight(),
+                    out.getWidth(),
+                    a.getWidth());
+      }
+    }
+    return;
   }
 }
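The rewrite collapses eight nearly identical loops (CSC/CSR × transposed/plain × NO_VALUE/FLOAT_VALUE) into two by pushing the variation into ternary arguments. For readers unfamiliar with the layout the new CSC loop walks: column j's nonzeros occupy the index range [colStart[j], colStart[j+1]) of the value and row-index arrays. A standalone illustration with hypothetical data (not the CpuSparseMatrix internals):

    #include <cstdio>

    int main() {
      // B = [[1, 0, 2],
      //      [0, 3, 0],
      //      [0, 0, 4]]  in CSC (compressed sparse column) form:
      float values[]   = {1, 3, 2, 4};   // nonzeros, column by column
      int   rows[]     = {0, 1, 0, 2};   // row index of each nonzero
      int   colStart[] = {0, 1, 2, 4};   // column j spans [colStart[j], colStart[j+1])

      for (int j = 0; j < 3; ++j) {
        for (int i = colStart[j]; i < colStart[j + 1]; ++i) {
          std::printf("B(%d,%d) = %g\n", rows[i], j, values[i]);
        }
      }
      return 0;
    }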
...
@@ -19,154 +19,147 @@ limitations under the License. */
 namespace paddle {
 /**
- * out = scale_t * out + scale_ab * (a * b)
+ * out = scaleT * out + scaleAB * (a * b)
  * out : output matrix, M * N
  */
 template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                             const GpuMatrix& a,
                             const GpuMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
-  CHECK(!out.isTransposed()) << "Not supported";
+                            real scaleAB,
+                            real scaleT) {
+  CHECK(!out.isTransposed()) << "Transpose not supported for out matrix";
   if (!a.isTransposed() && !b.isTransposed()) {
     /// a : M * K, b: K * N
-    CHECK_EQ(out.width_, b.width_);
-    CHECK_EQ(out.height_, a.height_);
-    CHECK_EQ(a.width_, b.height_);
+    CHECK(out.getWidth() == b.getWidth() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getHeight());
   } else if (a.isTransposed() && !b.isTransposed()) {
     /// a : K * M, b : K * N
-    CHECK_EQ(out.width_, b.width_);
-    CHECK_EQ(out.height_, a.width_);
-    CHECK_EQ(a.height_, b.height_);
+    CHECK(out.getWidth() == b.getWidth() &&
+          out.getHeight() == a.getWidth() &&
+          a.getHeight() == b.getHeight());
   } else if (!a.isTransposed() && b.isTransposed()) {
     /// a: M * K, b : N * K
-    CHECK_EQ(out.width_, b.height_);
-    CHECK_EQ(out.height_, a.height_);
-    CHECK_EQ(a.width_, b.width_);
+    CHECK(out.getWidth() == b.getHeight() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getWidth());
   } else {
-    LOG(FATAL) << "Is not supported";
+    LOG(FATAL) << "Not support for both a and b are Transposed Matrices";
   }
-  real* a_data = a.data_;
-  real* b_data = b.data_;
-  real* out_data = out.data_;
-  int dim_m = out.getHeight();
-  int dim_n = out.getWidth();
-  int dim_k = !a.isTransposed() ? a.width_ : a.height_;
-  int lda = a.getStride();
-  int ldb = b.getStride();
-  int ldc = out.getStride();
-  hl_trans_op_t trans_a = !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T;
-  hl_trans_op_t trans_b = !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T;
-
-  hl_matrix_mul(a_data,
-                trans_a,
-                b_data,
-                trans_b,
-                out_data,
-                dim_m,
-                dim_n,
-                dim_k,
-                scale_ab,
-                scale_t,
-                lda,
-                ldb,
-                ldc);
+  real* aData = const_cast<real*>(a.getData());
+  real* bData = const_cast<real*>(b.getData());
+  real* outData = const_cast<real*>(out.getData());
+  hl_matrix_mul(aData,
+                !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
+                bData,
+                !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
+                outData,
+                out.getHeight(),
+                out.getWidth(),
+                !a.isTransposed() ? a.getWidth() : a.getHeight(),
+                scaleAB,
+                scaleT,
+                a.getStride(),
+                b.getStride(),
+                out.getStride());
 }
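Inlining the hl_matrix_mul arguments removes a screen of single-use locals, but the dimension bookkeeping is worth keeping in mind: M and N always come from out, while K is the shared inner dimension and flips with a's transpose flag. A small sketch of that mapping in plain C++ (independent of the hl_* API):

    #include <cstdio>

    struct Dims { int m, n, k; };

    // M and N come from the output; K is a's inner dimension:
    // a.width when a is untransposed, a.height when it is transposed.
    Dims gemmDims(int outH, int outW, int aH, int aW, bool aTrans) {
      Dims d = {outH, outW, aTrans ? aH : aW};
      return d;
    }

    int main() {
      // out is 4x5, a is 4x3 and untransposed, so b must be 3x5.
      Dims d = gemmDims(4, 5, 4, 3, false);
      std::printf("M=%d N=%d K=%d\n", d.m, d.n, d.k);  // M=4 N=5 K=3
      return 0;
    }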
 /**
- * out = scale_t * out + scale_ab * (a * b)
+ * out = scaleT * out + scaleAB * (a * b)
  * out : M * N
  */
 template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                             const GpuSparseMatrix& a,
                             const GpuMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
+                            real scaleAB,
+                            real scaleT) {
   CHECK(out.isContiguous());
   CHECK(b.isContiguous());
-  CHECK(b.useGpu_ == true) << "Matrix type are not equal";
-  CHECK(!out.trans_ && !b.trans_) << "not supported";
-  if (!a.trans_) {
+  CHECK(b.useGpu_) << "Matrix type are not equal";
+  CHECK(!out.isTransposed() && !b.isTransposed()) << "not supported";
+  if (!a.isTransposed()) {
     /// a: M * K, b: K * N
-    CHECK(out.width_ == b.width_ && out.height_ == a.height_
-          && a.width_ == b.height_) << "Matrix dimensions are not equal";
+    CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getHeight()
+          && a.getWidth() == b.getHeight()) << "Matrix dimensions are not equal";
   } else {
     /// a: K * M, transpose, b: K * N
-    CHECK(out.width_ == b.width_ && out.height_ == a.width_
-          && a.height_ == b.height_) << "Matrix dimensions are not equal";
+    CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getWidth()
+          && a.getHeight() == b.getHeight()) << "Matrix dimensions are not equal";
   }
-  hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
-  hl_sparse_matrix_s a_data = a.sMatrix_.get();
-  real* b_data = b.data_;
-  real* out_data = out.data_;
-  hl_matrix_csr_mul_dense(a_data,
-                          a_trans,
-                          b_data,
+  hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  hl_sparse_matrix_s aData = a.sMatrix_.get();
+  real* bData = const_cast<real*>(b.getData());
+  real* outData = const_cast<real*>(out.getData());
+  hl_matrix_csr_mul_dense(aData,
+                          aTrans,
+                          bData,
                           HPPL_OP_N,
-                          out_data,
-                          out.height_,
-                          out.width_,
-                          b.height_,
-                          scale_ab,
-                          scale_t);
+                          outData,
+                          out.getHeight(),
+                          out.getWidth(),
+                          b.getHeight(),
+                          scaleAB,
+                          scaleT);
 }
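hl_matrix_csr_mul_dense takes a in CSR form, the row-major counterpart of CSC: row i's nonzeros occupy the range [rowStart[i], rowStart[i+1]) of the value and column-index arrays. A standalone illustration with hypothetical data (not the hl_sparse_matrix_s internals):

    #include <cstdio>

    int main() {
      // A = [[1, 0, 2],
      //      [0, 3, 0]]  in CSR (compressed sparse row) form:
      float values[]   = {1, 2, 3};   // nonzeros, row by row
      int   cols[]     = {0, 2, 1};   // column index of each nonzero
      int   rowStart[] = {0, 2, 3};   // row i spans [rowStart[i], rowStart[i+1])

      for (int i = 0; i < 2; ++i) {
        for (int r = rowStart[i]; r < rowStart[i + 1]; ++r) {
          std::printf("A(%d,%d) = %g\n", i, cols[r], values[r]);
        }
      }
      return 0;
    }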
 /**
- * out = scale_t * out + scale_ab * (a * b)
+ * out = scaleT * out + scaleAB * (a * b)
  * out : M * N
  */
 template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                             const GpuMatrix& a,
                             const GpuSparseMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
+                            real scaleAB,
+                            real scaleT) {
   CHECK(out.isContiguous());
   CHECK(a.isContiguous());
-  CHECK(a.useGpu_ == true) << "Matrix type are not equal";
-
-  hl_sparse_matrix_s b_data = b.sMatrix_.get();
-  real* a_data = a.data_;
-  real* out_data = out.data_;
-  hl_trans_op_t trans_b = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
-  if (!b.trans_) {
+  CHECK(a.useGpu_) << "Matrix type are not equal";
+  if (!b.isTransposed()) {
     /// a : M * K, b : K * N
-    CHECK(out.width_ == b.width_ &&
-          out.height_ == a.height_ && a.width_ == b.height_)
-        << "Matrix dimensions are not equal";
+    CHECK(out.getWidth() == b.getWidth() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getHeight())
+        << "Matrix dimensions are not equal";
   } else {
     /// a : M * K, b : N * K, transpose
-    CHECK(out.width_ == b.height_ &&
-          out.height_ == a.height_ && a.width_ == b.width_)
-        << "Matrix dimensions are not equal";
+    CHECK(out.getWidth() == b.getHeight() &&
+          out.getHeight() == a.getHeight() &&
+          a.getWidth() == b.getWidth())
+        << "Matrix dimensions are not equal";
   }
+  hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  hl_sparse_matrix_s bData = b.sMatrix_.get();
+  real* aData = const_cast<real*>(a.getData());
+  real* outData = const_cast<real*>(out.getData());
   if (b.format_ == SPARSE_CSC) {
-    hl_matrix_dense_mul_csc(a_data,
+    hl_matrix_dense_mul_csc(aData,
                             HPPL_OP_N,
-                            b_data,
-                            trans_b,
-                            out_data,
-                            out.height_,
-                            out.width_,
-                            a.width_,
-                            scale_ab,
-                            scale_t);
+                            bData,
+                            bTrans,
+                            outData,
+                            out.getHeight(),
+                            out.getWidth(),
+                            a.getWidth(),
+                            scaleAB,
+                            scaleT);
   } else {
-    hl_matrix_dense_mul_csr(a_data,
+    hl_matrix_dense_mul_csr(aData,
                             HPPL_OP_N,
-                            b_data,
-                            trans_b,
-                            out_data,
-                            out.height_,
-                            out.width_,
-                            a.width_,
-                            scale_ab,
-                            scale_t);
+                            bData,
+                            bTrans,
+                            outData,
+                            out.getHeight(),
+                            out.getWidth(),
+                            a.getWidth(),
+                            scaleAB,
+                            scaleT);
   }
 }
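Every overload in this file implements the contract stated in the doc comments, out = scaleT * out + scaleAB * (a * b). A dense CPU reference of that contract — a sketch assuming row-major storage, useful as a mental model or test oracle, not Paddle code:

    #include <vector>

    // out = scaleT * out + scaleAB * (a * b) for row-major
    // MxK (a), KxN (b), and MxN (out) matrices.
    void mulRef(std::vector<float>& out, const std::vector<float>& a,
                const std::vector<float>& b, int M, int N, int K,
                float scaleAB, float scaleT) {
      for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
          float acc = 0;
          for (int k = 0; k < K; ++k) {
            acc += a[i * K + k] * b[k * N + j];
          }
          out[i * N + j] = scaleT * out[i * N + j] + scaleAB * acc;
        }
      }
    }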
...
@@ -174,38 +167,36 @@ template <>
 void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
                             const GpuMatrix& a,
                             const GpuMatrix& b,
-                            real scale_ab,
-                            real scale_t) {
-  /// todo(tianbing), clean the code
-  CHECK(a.useGpu_ && b.useGpu_) << "type not match";
-  CHECK(!out.trans_) << "trans not supported";
-  real* a_data = const_cast<real*>(a.getData());
-  real* b_data = const_cast<real*>(b.getData());
-  hl_sparse_matrix_s out_data = out.sMatrix_.get();
-  hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
-  hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
-
-  if (!a.trans_ && !b.trans_) {
-    CHECK(out.height_ == a.getHeight());
-    CHECK(out.width_ == b.getWidth());
-    CHECK(a.getWidth() == b.getHeight());
-  } else if (a.trans_ && !b.trans_) {
-    CHECK(out.height_ == a.getWidth());
-    CHECK(out.width_ == b.getWidth());
-    CHECK(a.getHeight() == b.getHeight());
-  } else if (!a.trans_ && b.trans_) {
-    CHECK(out.height_ == a.getHeight());
-    CHECK(out.width_ == b.getHeight());
-    CHECK(a.getWidth() == b.getWidth());
+                            real scaleAB,
+                            real scaleT) {
+  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
+  CHECK(!out.isTransposed()) << "Transpose is not supported for out matrix";
+  if (!a.isTransposed() && !b.isTransposed()) {
+    CHECK(out.getHeight() == a.getHeight() &&
+          out.getWidth() == b.getWidth() &&
+          a.getWidth() == b.getHeight());
+  } else if (a.isTransposed() && !b.isTransposed()) {
+    CHECK(out.getHeight() == a.getWidth() &&
+          out.getWidth() == b.getWidth() &&
+          a.getHeight() == b.getHeight());
+  } else if (!a.isTransposed() && b.isTransposed()) {
+    CHECK(out.getHeight() == a.getHeight() &&
+          out.getWidth() == b.getHeight() &&
+          a.getWidth() == b.getWidth());
   } else {
-    LOG(INFO) << "Not support";
+    LOG(FATAL) << "Not support for both a and b are Transposed Matrices";
   }
-  int dim_m = out.height_;
-  int dim_n = out.width_;
-  int dim_k = !b.trans_ ? b.getHeight() : b.getWidth();
-  hl_sparse_matrix_mul(
-      a_data, a_trans, b_data, b_trans, out_data,
-      dim_m, dim_n, dim_k, scale_ab, scale_t);
+
+  hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
+  int dimK = !b.isTransposed() ? b.getHeight() : b.getWidth();
+  real* aData = const_cast<real*>(a.getData());
+  real* bData = const_cast<real*>(b.getData());
+  hl_sparse_matrix_s outData = out.sMatrix_.get();
+  hl_sparse_matrix_mul(aData, aTrans, bData, bTrans, outData,
+                       out.getHeight(), out.getWidth(), dimK, scaleAB, scaleT);
 }
 } // namespace paddle
...
@@ -76,12 +76,12 @@ void testDDDMatrix(bool transa, bool transb, int dimM, int dimN, int dimK) {
 TEST(Matrix, DDDMul) {
   LOG(INFO) << "test for dense = dense * dense matrix";
-  for (auto transa : {false, true}) {
-    for (auto transb : {false, true}) {
-      for (auto dimM : {1, 10, 100}) {
-        for (auto dimN : {1, 10}) {
-          for (auto dimK : {8}) {
-            if (true == transa && true == transb) {
+  for (const auto transa : {false, true}) {
+    for (const auto transb : {false, true}) {
+      for (const auto dimM : {1, 10, 100}) {
+        for (const auto dimN : {1, 10}) {
+          for (const auto dimK : {8}) {
+            if (transa && transb) {
               continue;
             }
             VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
@@ -89,7 +89,6 @@ TEST(Matrix, DDDMul) {
                     << " dimM=" << std::setw(5) << dimM
                     << " dimN=" << std::setw(5) << dimN
                     << " dimK=" << std::setw(5) << dimK;
-
             testDDDMatrix(transa, transb, dimM, dimN, dimK);
           }
         }
...