提交 316bf75a 编写于 作者: X xutianbing

clean code in function/MulOp.cpp

上级 9ade63e6
......@@ -26,22 +26,16 @@ limitations under the License. */
#endif
namespace {
inline void vecAddTo(real* a, const real* b, size_t len) {
for (unsigned int i = 0; i < len; ++i) {
a[i] += b[i];
}
}
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
for (unsigned int i = 0; i < len; ++i) {
a[i] += scaleB * b[i];
a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
}
}
inline void colVecAddTo(
real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
for (unsigned int i = 0; i < len; ++i) {
a[i * aWidth] += b[i * bWidth] * c;
a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
}
}
} // namespace
......@@ -53,15 +47,19 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
const CpuMatrix& b,
real scaleAB,
real scaleT) {
/// todo(tianbing), clean the code
CHECK(!out.isTransposed()) << "Not supported";
CHECK_EQ(out.getValueType(), FLOAT_VALUE);
CHECK(!a.isTransposed() || !b.isTransposed())
<< "Not support both a and b are transpose matrices";
if (!a.isTransposed() && b.isTransposed()) {
CHECK(out.getFormat() != SPARSE_CSC)
<< "Not supported CSC format when a is not trans and b is trans";
}
size_t height = out.getHeight();
size_t width = out.getWidth();
size_t aRow = !a.isTransposed() ? a.getHeight() : a.getWidth();
size_t aCol = !a.isTransposed() ? a.getWidth() : a.getHeight();
size_t bRow = !b.isTransposed() ? b.getHeight() : b.getWidth();
size_t bCol = !b.isTransposed() ? b.getWidth() : b.getHeight();
/// C = A * B, for matrix format
CHECK(aCol == bRow && aRow == height && bCol == width);
if (scaleT == 0) {
out.zeroMem();
......@@ -71,93 +69,46 @@ void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
real* C = out.getValue();
int* rows = out.getRows();
int* cols = out.getCols();
size_t height = out.getHeight();
size_t width = out.getWidth();
if (!a.isTransposed() && !b.isTransposed()) {
CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height &&
b.getWidth() == width);
size_t m = a.getWidth();
if (out.getFormat() == SPARSE_CSC) {
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) {
sum += A[rowIdx * m + k] * B[k * width + i];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
} else { /// out.getFormat() == SPARSE_CSR
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < a.getWidth(); k++) {
sum += A[i * m + k] * B[k * width + colIdx];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
}
return;
}
if (a.isTransposed() && !b.isTransposed()) {
CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width &&
a.getWidth() == height);
size_t m = a.getHeight();
if (out.getFormat() == SPARSE_CSC) {
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) {
sum += A[k * height + rowIdx] * B[k * width + i];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
} else { /// out.getFormat() == SPARSE_CSR
for (size_t i = 0; i < height; i++) {
int start = out.getRowStartIdx(i);
int end = out.getRowStartIdx(i + 1);
for (int j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < a.getHeight(); k++) {
sum += A[k * height + i] * B[k * width + colIdx];
}
C[j] = scaleAB * sum + scaleT * C[j];
/// SPARSE_CSC, {a any, b not trans}
if (out.getFormat() == SPARSE_CSC) {
/// b not trans and a any
CHECK(!b.isTransposed());
size_t m = !a.isTransposed() ? a.getWidth() : a.getHeight();
for (size_t i = 0; i < width; i++) {
size_t start = out.getColStartIdx(i);
size_t end = out.getColStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t rowIdx = rows[j];
for (size_t k = 0; k < m; k++) {
sum +=
(!a.isTransposed() ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
B[k * width + i];
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
return;
}
if (!a.isTransposed() && b.isTransposed()) {
CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height &&
b.getHeight() == width);
/// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
if (out.getFormat() == SPARSE_CSR) {
/// a and b can not both transpose
CHECK(!(a.isTransposed() && b.isTransposed()));
size_t m = a.getWidth();
if (out.getFormat() == SPARSE_CSR) {
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) {
sum += A[i * m + k] * B[colIdx * m + k];
}
C[j] = scaleAB * sum + scaleT * C[j];
for (size_t i = 0; i < height; i++) {
size_t start = out.getRowStartIdx(i);
size_t end = out.getRowStartIdx(i + 1);
for (size_t j = start; j < end; j++) {
real sum = 0;
size_t colIdx = cols[j];
for (size_t k = 0; k < m; k++) {
sum +=
(!a.isTransposed() ? A[i * m + k] : A[k * height + i]) *
(!b.isTransposed() ? B[k * width + colIdx] : B[colIdx * m + k]);
}
C[j] = scaleAB * sum + scaleT * C[j];
}
}
return;
......@@ -330,11 +281,11 @@ public:
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
auto out_mat = outputs[0].matrix<Device>();
auto outMat = outputs[0].matrix<Device>();
/// matrix = matrix * matrix
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
MulOp<Device>(outMat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
......@@ -345,7 +296,7 @@ public:
/// matrix = matrix * sparse matrix
if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
MulOp<Device>(outMat,
inputs[0].matrix<Device>(),
inputs[1].sparse().SparseMatrix<Device>(),
alpha_,
......@@ -356,7 +307,7 @@ public:
/// matrix = sparse matrix * matrix
if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
!outputs[0].isSparseArg()) {
MulOp<Device>(out_mat,
MulOp<Device>(outMat,
inputs[0].sparse().SparseMatrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
......@@ -365,10 +316,10 @@ public:
}
/// sparse matrix = matrix * matrix
auto out_sparse_mat = outputs[0].sparse().SparseMatrix<Device>();
auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
outputs[0].isSparseArg()) {
MulOp<Device>(out_sparse_mat,
MulOp<Device>(outSparseMat,
inputs[0].matrix<Device>(),
inputs[1].matrix<Device>(),
alpha_,
......
......@@ -15,8 +15,6 @@ limitations under the License. */
#pragma once
#include "Function.h"
/// todo(tianbing), delete
#include <iostream>
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册