Commit 171eaff2 authored by xutianbing

clean the code a little bit.

Parent 4751cc8f
......@@ -38,13 +38,6 @@ inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
}
}
inline void colVecAddTo(
real* a, const real* b, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) {
a[i * aWidth] += b[i * bWidth];
}
}
inline void colVecAddTo(
real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) {
......@@ -336,140 +329,59 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const CpuSparseMatrix& b,
real scaleAB,
real scaleT) {
/// todo(tianbing), clean the code
CHECK(!out.trans_) << "Not supported";
CHECK(!a.isTransposed()) << "Not supported";
CHECK(scaleT == 0 || scaleT == 1);
CHECK_EQ(scaleAB, static_cast<real>(1.0));
if (!b.isTransposed()) { /// b is not Transpose
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == out.getWidth());
} else {
CHECK(b.getHeight() == out.getWidth() && a.getHeight() == out.getHeight() &&
b.getWidth() == a.getWidth());
}
if (scaleT == 0) {
out.zeroMem();
}
real* A = const_cast<real*>(a.getData());
real* B = const_cast<real*>(b.getValue());
real* C = out.getData();
int* rows = b.getRows();
int* cols = b.getCols();
if (scaleT == 0) {
out.zeroMem();
}
/// todo(tianbing), clean the code
/// b.getFormat() == SPARSE_CSC
if (b.getFormat() == SPARSE_CSC) {
if (!b.isTransposed()) {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), out.width_);
if (b.getValueType() == NO_VALUE) {
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(
C + j, A + rows[i], out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(C + j,
A + rows[i],
B[i],
out.height_,
out.width_,
a.getWidth());
}
}
}
} else /*if (b.isTransposed())*/ {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), out.width_);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), m);
if (b.getValueType() == NO_VALUE) {
for (size_t i = 0; i < b.getWidth(); ++i) {
int start = b.getColStartIdx(i);
int end = b.getColStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(
C + rows[j], A + i, out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < b.getWidth(); ++i) {
int start = b.getColStartIdx(i);
int end = b.getColStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(C + rows[j],
A + i,
B[j],
out.height_,
out.width_,
a.getWidth());
}
}
for (size_t j = 0; j < b.getWidth(); ++j) {
int start = b.getColStartIdx(j);
int end = b.getColStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(!b.isTransposed() ? C + j : C + rows[i],
!b.isTransposed() ? A + rows[i] : A + j,
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(),
out.getWidth(),
a.getWidth());
}
}
} else {
if (!b.isTransposed()) {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), m);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), out.width_);
if (b.getValueType() == NO_VALUE) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(
C + cols[i], A + j, out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(C + cols[i],
A + j,
B[i],
out.height_,
out.width_,
a.getWidth());
}
}
}
} else /*if (b.isTransposed())*/ {
size_t m = a.getWidth();
CHECK_EQ(b.getHeight(), out.width_);
CHECK_EQ(a.getHeight(), out.height_);
CHECK_EQ(b.getWidth(), m);
if (b.getValueType() == NO_VALUE) {
for (size_t i = 0; i < b.getHeight(); ++i) {
int start = b.getRowStartIdx(i);
int end = b.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(
C + i, A + cols[j], out.height_, out.width_, a.getWidth());
}
}
} else if (b.getValueType() == FLOAT_VALUE) {
for (size_t i = 0; i < b.getHeight(); ++i) {
int start = b.getRowStartIdx(i);
int end = b.getRowStartIdx(i + 1);
for (int j = start; j < end; ++j) {
colVecAddTo(C + i,
A + cols[j],
B[j],
out.height_,
out.width_,
a.getWidth());
}
}
return;
}
/// b.getFormat() == SPARSE_CSR
if (b.getFormat() == SPARSE_CSR) {
for (size_t j = 0; j < b.getHeight(); ++j) {
int start = b.getRowStartIdx(j);
int end = b.getRowStartIdx(j + 1);
for (int i = start; i < end; ++i) {
colVecAddTo(!b.isTransposed() ? C + cols[i] : C + j,
!b.isTransposed() ? A + j : A + cols[i],
(b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
out.getHeight(),
out.getWidth(),
a.getWidth());
}
}
return;
}
}
......
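For readers following the unified CSC branch above, here is a minimal standalone CPU sketch of the same accumulation pattern for out += a * b with a dense and b column-compressed. It is an illustration only, not code from this commit, and all names in it are hypothetical.

// Hypothetical standalone sketch: every stored entry b(r, j) = v contributes
// v * (column r of a) to column j of out, mirroring the colVecAddTo calls above.
#include <cstddef>
#include <iostream>
#include <vector>

// Row-major dense matrix: element (r, c) lives at data[r * width + c].
struct Dense {
  size_t height, width;
  std::vector<float> data;
};

// CSC sparse matrix: column j owns the half-open range
// [colPtr[j], colPtr[j + 1]) of (rowIdx, value) pairs.
struct CscSparse {
  size_t height, width;
  std::vector<int> colPtr;
  std::vector<int> rowIdx;
  std::vector<float> values;
};

// out += a * b, accumulated one stored entry of b at a time.
void denseTimesCsc(Dense& out, const Dense& a, const CscSparse& b) {
  for (size_t j = 0; j < b.width; ++j) {
    for (int k = b.colPtr[j]; k < b.colPtr[j + 1]; ++k) {
      const int r = b.rowIdx[k];
      const float v = b.values[k];
      for (size_t m = 0; m < out.height; ++m) {
        out.data[m * out.width + j] += v * a.data[m * a.width + r];
      }
    }
  }
}

int main() {
  // a: 2 x 3 dense; b: 3 x 2 sparse with entries b(0, 0) = 1 and b(2, 1) = 2.
  Dense a{2, 3, {1, 2, 3, 4, 5, 6}};
  CscSparse b{3, 2, {0, 1, 2}, {0, 2}, {1.0f, 2.0f}};
  Dense out{2, 2, std::vector<float>(4, 0.0f)};
  denseTimesCsc(out, a, b);
  for (float x : out.data) std::cout << x << " ";  // expected: 1 6 4 12
  std::cout << "\n";
  return 0;
}

The transposed CSC branch and both CSR branches only swap which operand supplies the row index and which supplies the column offset; the per-entry accumulation is the same.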
......@@ -19,154 +19,147 @@ limitations under the License. */
namespace paddle {
/**
* out = scale_t * out + scale_ab * (a * b)
* out = scaleT * out + scaleAB * (a * b)
* out : output matrix, M * N
*/
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a,
const GpuMatrix& b,
real scale_ab,
real scale_t) {
CHECK(!out.isTransposed()) << "Not supported";
real scaleAB,
real scaleT) {
CHECK(!out.isTransposed()) << "Transpose not supported for out matrix";
if (!a.isTransposed() && !b.isTransposed()) {
/// a : M * K, b: K * N
CHECK_EQ(out.width_, b.width_);
CHECK_EQ(out.height_, a.height_);
CHECK_EQ(a.width_, b.height_);
/// a : M * K, b: K * N
CHECK(out.getWidth() == b.getWidth() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getHeight());
} else if (a.isTransposed() && !b.isTransposed()) {
/// a : K * M, b : K * N
CHECK_EQ(out.width_, b.width_);
CHECK_EQ(out.height_, a.width_);
CHECK_EQ(a.height_, b.height_);
/// a : K * M, b : K * N
CHECK(out.getWidth() == b.getWidth() &&
out.getHeight() == a.getWidth() &&
a.getHeight() == b.getHeight());
} else if (!a.isTransposed() && b.isTransposed()) {
/// a: M * K, b : N * K
CHECK_EQ(out.width_, b.height_);
CHECK_EQ(out.height_, a.height_);
CHECK_EQ(a.width_, b.width_);
/// a: M * K, b : N * K
CHECK(out.getWidth() == b.getHeight() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getWidth());
} else {
LOG(FATAL) << "Is not supported";
LOG(FATAL) << "Not supported when both a and b are transposed matrices";
}
real* a_data = a.data_;
real* b_data = b.data_;
real* out_data = out.data_;
int dim_m = out.getHeight();
int dim_n = out.getWidth();
int dim_k = !a.isTransposed() ? a.width_ : a.height_;
int lda = a.getStride();
int ldb = b.getStride();
int ldc = out.getStride();
hl_trans_op_t trans_a = !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T;
hl_trans_op_t trans_b = !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T;
hl_matrix_mul(a_data,
trans_a,
b_data,
trans_b,
out_data,
dim_m,
dim_n,
dim_k,
scale_ab,
scale_t,
lda,
ldb,
ldc);
real* aData = const_cast<real*>(a.getData());
real* bData = const_cast<real*>(b.getData());
real* outData = const_cast<real*>(out.getData());
hl_matrix_mul(aData,
!a.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
bData,
!b.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
outData,
out.getHeight(),
out.getWidth(),
!a.isTransposed() ? a.getWidth() : a.getHeight(),
scaleAB,
scaleT,
a.getStride(),
b.getStride(),
out.getStride());
}
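Since the whole call reduces to out = scaleT * out + scaleAB * (a * b), a plain CPU reference of that contract may help when reading the argument list above. This is a reading aid only, assuming row-major storage and the non-transposed case; hl_matrix_mul itself is the GPU BLAS wrapper and is not shown here.

#include <vector>

// Reference semantics of out = scaleT * out + scaleAB * (a * b),
// with a: M x K, b: K x N, out: M x N, all row-major.
void gemmReference(std::vector<float>& out, const std::vector<float>& a,
                   const std::vector<float>& b, int M, int N, int K,
                   float scaleAB, float scaleT) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += a[m * K + k] * b[k * N + n];  // (a * b)(m, n)
      }
      out[m * N + n] = scaleT * out[m * N + n] + scaleAB * acc;
    }
  }
}

The transposed cases change only which stride each operand is walked with (hence the HPPL_OP_N / HPPL_OP_T flags and the strides passed above); the scaling contract is unchanged.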
/**
* out = scale_t * out + scale_ab * (a * b)
* out = scaleT * out + scaleAB * (a * b)
* out : M * N
*/
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuSparseMatrix& a,
const GpuMatrix& b,
real scale_ab,
real scale_t) {
real scaleAB,
real scaleT) {
CHECK(out.isContiguous());
CHECK(b.isContiguous());
CHECK(b.useGpu_ == true) << "Matrix type are not equal";
CHECK(!out.trans_ && !b.trans_) << "not supported";
if (!a.trans_) {
CHECK(b.useGpu_) << "Matrix type are not equal";
CHECK(!out.isTransposed() && !b.isTransposed()) << "not supported";
if (!a.isTransposed()) {
/// a: M * K, b: K * N
CHECK(out.width_ == b.width_ && out.height_ == a.height_
&& a.width_ == b.height_) << "Matrix dimensions are not equal";
CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getHeight()
&& a.getWidth() == b.getHeight()) << "Matrix dimensions are not equal";
} else {
/// a: K * M, transpose, b: K * N
CHECK(out.width_ == b.width_ && out.height_ == a.width_
&& a.height_ == b.height_) << "Matrix dimensions are not equal";
CHECK(out.getWidth() == b.getWidth() && out.getHeight() == a.getWidth()
&& a.getHeight() == b.getHeight()) << "Matrix dimensions are not equal";
}
hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
hl_sparse_matrix_s a_data = a.sMatrix_.get();
real* b_data = b.data_;
real* out_data = out.data_;
hl_matrix_csr_mul_dense(a_data,
a_trans,
b_data,
hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
hl_sparse_matrix_s aData = a.sMatrix_.get();
real* bData = const_cast<real*>(b.getData());
real* outData = const_cast<real*>(out.getData());
hl_matrix_csr_mul_dense(aData,
aTrans,
bData,
HPPL_OP_N,
out_data,
out.height_,
out.width_,
b.height_,
scale_ab,
scale_t);
outData,
out.getHeight(),
out.getWidth(),
b.getHeight(),
scaleAB,
scaleT);
}
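The CSR-times-dense product handed to hl_matrix_csr_mul_dense can be read the same way. A minimal CPU sketch of the non-transposed case follows; it is illustrative only (the GPU kernel is not shown) and assumes row-major dense storage.

#include <vector>

// CSR sparse matrix: row r owns the half-open range
// [rowPtr[r], rowPtr[r + 1]) of (colIdx, value) pairs.
struct CsrSparse {
  int height, width;
  std::vector<int> rowPtr;
  std::vector<int> colIdx;
  std::vector<float> values;
};

// out = scaleT * out + scaleAB * (a * b); a is sparse CSR (M x K), b dense (K x N).
void csrMulDense(std::vector<float>& out, const CsrSparse& a,
                 const std::vector<float>& b, int N,
                 float scaleAB, float scaleT) {
  for (float& x : out) x *= scaleT;            // out = scaleT * out
  for (int r = 0; r < a.height; ++r) {
    for (int k = a.rowPtr[r]; k < a.rowPtr[r + 1]; ++k) {
      const int c = a.colIdx[k];
      const float v = scaleAB * a.values[k];   // scaled a(r, c)
      for (int n = 0; n < N; ++n) {
        out[r * N + n] += v * b[c * N + n];    // add row c of b into row r of out
      }
    }
  }
}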
/**
* out = scale_t * out + scale_ab * (a * b)
* out = scaleT * out + scaleAB * (a * b)
* out : M * N
*/
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
const GpuMatrix& a,
const GpuSparseMatrix& b,
real scale_ab,
real scale_t) {
real scaleAB,
real scaleT) {
CHECK(out.isContiguous());
CHECK(a.isContiguous());
CHECK(a.useGpu_ == true) << "Matrix type are not equal";
hl_sparse_matrix_s b_data = b.sMatrix_.get();
real* a_data = a.data_;
real* out_data = out.data_;
hl_trans_op_t trans_b = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
if (!b.trans_) {
/// a : M * K, b : K * N
CHECK(out.width_ == b.width_ &&
out.height_ == a.height_ && a.width_ == b.height_)
<< "Matrix dimensions are not equal";
CHECK(a.useGpu_) << "Matrix type are not equal";
if (!b.isTransposed()) {
/// a : M * K, b : K * N
CHECK(out.getWidth() == b.getWidth() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getHeight())
<< "Matrix dimensions are not equal";
} else {
/// a : M * K, b : N * K, transpose
CHECK(out.width_ == b.height_ &&
out.height_ == a.height_ && a.width_ == b.width_)
<< "Matrix dimensions are not equal";
/// a : M * K, b : N * K, transpose
CHECK(out.getWidth() == b.getHeight() &&
out.getHeight() == a.getHeight() &&
a.getWidth() == b.getWidth())
<< "Matrix dimensions are not equal";
}
hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
hl_sparse_matrix_s bData = b.sMatrix_.get();
real* aData = const_cast<real*>(a.getData());
real* outData = const_cast<real*>(out.getData());
if (b.format_ == SPARSE_CSC) {
hl_matrix_dense_mul_csc(a_data,
hl_matrix_dense_mul_csc(aData,
HPPL_OP_N,
b_data,
trans_b,
out_data,
out.height_,
out.width_,
a.width_,
scale_ab,
scale_t);
bData,
bTrans,
outData,
out.getHeight(),
out.getWidth(),
a.getWidth(),
scaleAB,
scaleT);
} else {
hl_matrix_dense_mul_csr(a_data,
hl_matrix_dense_mul_csr(aData,
HPPL_OP_N,
b_data,
trans_b,
out_data,
out.height_,
out.width_,
a.width_,
scale_ab,
scale_t);
bData,
bTrans,
outData,
out.getHeight(),
out.getWidth(),
a.getWidth(),
scaleAB,
scaleT);
}
}
......@@ -174,38 +167,36 @@ template <>
void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
const GpuMatrix& a,
const GpuMatrix& b,
real scale_ab,
real scale_t) {
/// todo(tianbing), clean the code
CHECK(a.useGpu_ && b.useGpu_) << "type not match";
CHECK(!out.trans_) << "trans not supported";
real* a_data = const_cast<real*>(a.getData());
real* b_data = const_cast<real*>(b.getData());
hl_sparse_matrix_s out_data = out.sMatrix_.get();
hl_trans_op_t a_trans = a.trans_ ? HPPL_OP_T : HPPL_OP_N;
hl_trans_op_t b_trans = b.trans_ ? HPPL_OP_T : HPPL_OP_N;
if (!a.trans_ && !b.trans_) {
CHECK(out.height_ == a.getHeight());
CHECK(out.width_ == b.getWidth());
CHECK(a.getWidth() == b.getHeight());
} else if (a.trans_ && !b.trans_) {
CHECK(out.height_ == a.getWidth());
CHECK(out.width_ == b.getWidth());
CHECK(a.getHeight() == b.getHeight());
} else if (!a.trans_ && b.trans_) {
CHECK(out.height_ == a.getHeight());
CHECK(out.width_ == b.getHeight());
CHECK(a.getWidth() == b.getWidth());
real scaleAB,
real scaleT) {
CHECK(a.useGpu_ && b.useGpu_) << "matrix device types do not match";
CHECK(!out.isTransposed()) << "Transpose is not supported for out matrix";
if (!a.isTransposed() && !b.isTransposed()) {
CHECK(out.getHeight() == a.getHeight() &&
out.getWidth() == b.getWidth() &&
a.getWidth() == b.getHeight());
} else if (a.isTransposed() && !b.isTransposed()) {
CHECK(out.getHeight() == a.getWidth() &&
out.getWidth() == b.getWidth() &&
a.getHeight() == b.getHeight());
} else if (!a.isTransposed() && b.isTransposed()) {
CHECK(out.getHeight() == a.getHeight() &&
out.getWidth() == b.getHeight() &&
a.getWidth() == b.getWidth());
} else {
LOG(INFO) << "Not support";
LOG(FATAL) << "Not supported when both a and b are transposed matrices";
}
int dim_m = out.height_;
int dim_n = out.width_;
int dim_k = !b.trans_ ? b.getHeight() : b.getWidth();
hl_sparse_matrix_mul(
a_data, a_trans, b_data, b_trans, out_data,
dim_m, dim_n, dim_k, scale_ab, scale_t);
hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
int dimK = !b.isTransposed() ? b.getHeight() : b.getWidth();
real* aData = const_cast<real*>(a.getData());
real* bData = const_cast<real*>(b.getData());
hl_sparse_matrix_s outData = out.sMatrix_.get();
hl_sparse_matrix_mul(aData, aTrans, bData, bTrans, outData,
out.getHeight(), out.getWidth(), dimK, scaleAB, scaleT);
}
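The sparse-output overload is different in kind: the result keeps a fixed sparsity pattern. Assuming hl_sparse_matrix_mul fills only the entries already present in out's pattern, which is the usual contract for such sampled products but is not shown here, a CPU equivalent for a CSR-patterned output would look like this sketch (all names hypothetical):

#include <vector>

// Sampled dense * dense product: compute only the stored positions of a
// CSR-patterned output, outValues[k] = scaleT * outValues[k] + scaleAB * dot.
void sampledMatmulCsr(std::vector<float>& outValues,      // one per stored entry
                      const std::vector<int>& outRowPtr,  // M + 1 entries
                      const std::vector<int>& outColIdx,  // one per stored entry
                      const std::vector<float>& a,        // M x K, row-major
                      const std::vector<float>& b,        // K x N, row-major
                      int M, int N, int K,
                      float scaleAB, float scaleT) {
  for (int r = 0; r < M; ++r) {
    for (int k = outRowPtr[r]; k < outRowPtr[r + 1]; ++k) {
      const int c = outColIdx[k];
      float dot = 0.0f;
      for (int i = 0; i < K; ++i) {
        dot += a[r * K + i] * b[i * N + c];  // (row r of a) . (column c of b)
      }
      outValues[k] = scaleT * outValues[k] + scaleAB * dot;
    }
  }
}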
} // namespace paddle
......@@ -76,12 +76,12 @@ void testDDDMatrix(bool transa, bool transb, int dimM, int dimN, int dimK) {
TEST(Matrix, DDDMul) {
LOG(INFO) << "test for dense = dense * dense matrix";
for (auto transa : {false, true}) {
for (auto transb : {false, true}) {
for (auto dimM : {1, 10, 100}) {
for (auto dimN : {1, 10}) {
for (auto dimK : {8}) {
if (true == transa && true == transb) {
for (const auto transa : {false, true}) {
for (const auto transb : {false, true}) {
for (const auto dimM : {1, 10, 100}) {
for (const auto dimN : {1, 10}) {
for (const auto dimK : {8}) {
if (transa && transb) {
continue;
}
VLOG(3) << setiosflags(std::ios::left) << std::setfill(' ')
......@@ -89,7 +89,6 @@ TEST(Matrix, DDDMul) {
<< " dimM=" << std::setw(5) << dimM
<< " dimN=" << std::setw(5) << dimN
<< " dimK=" << std::setw(5) << dimK;
testDDDMatrix(transa, transb, dimM, dimN, dimK);
}
}
......