提交 316bf75a 编写于 作者: X xutianbing

clean code in function/MulOp.cpp

上级 9ade63e6
...@@ -26,22 +26,16 @@ limitations under the License. */ ...@@ -26,22 +26,16 @@ limitations under the License. */
#endif #endif
namespace { namespace {
inline void vecAddTo(real* a, const real* b, size_t len) {
for (unsigned int i = 0; i < len; ++i) {
a[i] += b[i];
}
}
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) { inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
for (unsigned int i = 0; i < len; ++i) { for (unsigned int i = 0; i < len; ++i) {
a[i] += scaleB * b[i]; a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
} }
} }
inline void colVecAddTo( inline void colVecAddTo(
real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) { real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) { for (unsigned int i = 0; i < len; ++i) {
a[i * aWidth] += b[i * bWidth] * c; a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
} }
} }
} // namespace } // namespace
...@@ -53,15 +47,19 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -53,15 +47,19 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
const CpuMatrix& b, const CpuMatrix& b,
real scaleAB, real scaleAB,
real scaleT) { real scaleT) {
/// todo(tianbing), clean the code
CHECK(!out.isTransposed()) << "Not supported"; CHECK(!out.isTransposed()) << "Not supported";
CHECK_EQ(out.getValueType(), FLOAT_VALUE); CHECK_EQ(out.getValueType(), FLOAT_VALUE);
CHECK(!a.isTransposed() || !b.isTransposed()) CHECK(!a.isTransposed() || !b.isTransposed())
<< "Not support both a and b are transpose matrices"; << "Not support both a and b are transpose matrices";
if (!a.isTransposed() && b.isTransposed()) {
CHECK(out.getFormat() != SPARSE_CSC) size_t height = out.getHeight();
<< "Not supported CSC format when a is not trans and b is trans"; size_t width = out.getWidth();
} size_t aRow = !a.isTransposed() ? a.getHeight() : a.getWidth();
size_t aCol = !a.isTransposed() ? a.getWidth() : a.getHeight();
size_t bRow = !b.isTransposed() ? b.getHeight() : b.getWidth();
size_t bCol = !b.isTransposed() ? b.getWidth() : b.getHeight();
/// C = A * B, for matrix format
CHECK(aCol == bRow && aRow == height && bCol == width);
if (scaleT == 0) { if (scaleT == 0) {
out.zeroMem(); out.zeroMem();
...@@ -71,14 +69,12 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -71,14 +69,12 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real* C = out.getValue(); real* C = out.getValue();
int* rows = out.getRows(); int* rows = out.getRows();
int* cols = out.getCols(); int* cols = out.getCols();
size_t height = out.getHeight();
size_t width = out.getWidth();
if (!a.isTransposed() && !b.isTransposed()) { /// SPARSE_CSC, {a any, b not trans}
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height &&
b.getWidth() == width);
size_t m = a.getWidth();
if (out.getFormat() == SPARSE_CSC) { if (out.getFormat() == SPARSE_CSC) {
/// b not trans and a any
CHECK(!b.isTransposed());
size_t m = !a.isTransposed() ? a.getWidth() : a.getHeight();
for (size_t i = 0; i < width; i++) { for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i); size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1); size_t end = out.getColStartIdx(i + 1);
...@@ -86,67 +82,21 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -86,67 +82,21 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real sum = 0; real sum = 0;
size_t rowIdx = rows[j]; size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) { for (size_t k = 0; k < m; k++) {
sum += A[rowIdx * m + k] * B[k * width + i]; sum +=
(!a.isTransposed() ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
B[k * width + i];
} }
C[j] = scaleAB * sum + scaleT * C[j]; C[j] = scaleAB * sum + scaleT * C[j];
} }
} }
} else { /// out.getFormat() == SPARSE_CSR
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < a.getWidth(); k++) {
sum += A[i * m + k] * B[k * width + colIdx];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
}
return; return;
} }
if (a.isTransposed() && !b.isTransposed()) { /// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width &&
a.getWidth() == height);
size_t m = a.getHeight();
if (out.getFormat() == SPARSE_CSC) {
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) {
sum += A[k * height + rowIdx] * B[k * width + i];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
} else { /// out.getFormat() == SPARSE_CSR
for (size_t i = 0; i < height; i++) {
int start = out.getRowStartIdx(i);
int end = out.getRowStartIdx(i + 1);
for (int j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < a.getHeight(); k++) {
sum += A[k * height + i] * B[k * width + colIdx];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
}
return;
}
if (!a.isTransposed() && b.isTransposed()) {
CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height &&
b.getHeight() == width);
size_t m = a.getWidth();
if (out.getFormat() == SPARSE_CSR) { if (out.getFormat() == SPARSE_CSR) {
/// a and b can not both transpose
CHECK(!(a.isTransposed() && b.isTransposed()));
size_t m = a.getWidth();
for (size_t i = 0; i < height; i++) { for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i); size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1); size_t end = out.getRowStartIdx(i + 1);
...@@ -154,12 +104,13 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out, ...@@ -154,12 +104,13 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real sum = 0; real sum = 0;
size_t colIdx = cols[j]; size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) { for (size_t k = 0; k < m; k++) {
sum += A[i * m + k] * B[colIdx * m + k]; sum +=
(!a.isTransposed() ? A[i * m + k] : A[k * height + i]) *
(!b.isTransposed() ? B[k * width + colIdx] : B[colIdx * m + k]);
} }
C[j] = scaleAB * sum + scaleT * C[j]; C[j] = scaleAB * sum + scaleT * C[j];
} }
} }
}
return; return;
} }
} }
...@@ -330,11 +281,11 @@ public: ...@@ -330,11 +281,11 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].getArgType(), ADD_TO); CHECK_EQ(outputs[0].getArgType(), ADD_TO);
auto out_mat = outputs[0].matrix<Device>(); auto outMat = outputs[0].matrix<Device>();
/// matrix = matrix * matrix /// matrix = matrix * matrix
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
MulOp<Device>(out_mat, MulOp<Device>(outMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, alpha_,
...@@ -345,7 +296,7 @@ public: ...@@ -345,7 +296,7 @@ public:
/// matrix = matrix * sparse matrix /// matrix = matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
MulOp<Device>(out_mat, MulOp<Device>(outMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(), inputs[1].sparse().SparseMatrix<Device>(),
alpha_, alpha_,
...@@ -356,7 +307,7 @@ public: ...@@ -356,7 +307,7 @@ public:
/// matrix = sparse matrix * matrix /// matrix = sparse matrix * matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) { !outputs[0].isSparseArg()) {
MulOp<Device>(out_mat, MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(), inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, alpha_,
...@@ -365,10 +316,10 @@ public: ...@@ -365,10 +316,10 @@ public:
} }
/// sparse matrix = matrix * matrix /// sparse matrix = matrix * matrix
auto out_sparse_mat = outputs[0].sparse().SparseMatrix<Device>(); auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) { outputs[0].isSparseArg()) {
MulOp<Device>(out_sparse_mat, MulOp<Device>(outSparseMat,
inputs[0].matrix<Device>(), inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(), inputs[1].matrix<Device>(),
alpha_, alpha_,
......
...@@ -15,8 +15,6 @@ limitations under the License. */ ...@@ -15,8 +15,6 @@ limitations under the License. */
#pragma once #pragma once
#include "Function.h" #include "Function.h"
/// todo(tianbing), delete
#include <iostream>
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册