diff --git a/paddle/function/MulOp.cpp b/paddle/function/MulOp.cpp index b911ccd13b607a6ea97d9844bfbc58166adc3fd6..bd3bc5c087d669eca117234d3a14b83d2be488ba 100644 --- a/paddle/function/MulOp.cpp +++ b/paddle/function/MulOp.cpp @@ -26,22 +26,16 @@ limitations under the License. */ #endif namespace { -inline void vecAddTo(real* a, const real* b, size_t len) { - for (unsigned int i = 0; i < len; ++i) { - a[i] += b[i]; - } -} - inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) { for (unsigned int i = 0; i < len; ++i) { - a[i] += scaleB * b[i]; + a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i]; } } inline void colVecAddTo( real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) { for (unsigned int i = 0; i < len; ++i) { - a[i * aWidth] += b[i * bWidth] * c; + a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c; } } } // namespace @@ -53,15 +47,19 @@ void MulOp(CpuSparseMatrix& out, const CpuMatrix& b, real scaleAB, real scaleT) { - /// todo(tianbing), clean the code CHECK(!out.isTransposed()) << "Not supported"; CHECK_EQ(out.getValueType(), FLOAT_VALUE); CHECK(!a.isTransposed() || !b.isTransposed()) << "Not support both a and b are transpose matrices"; - if (!a.isTransposed() && b.isTransposed()) { - CHECK(out.getFormat() != SPARSE_CSC) - << "Not supported CSC format when a is not trans and b is trans"; - } + + size_t height = out.getHeight(); + size_t width = out.getWidth(); + size_t aRow = !a.isTransposed() ? a.getHeight() : a.getWidth(); + size_t aCol = !a.isTransposed() ? a.getWidth() : a.getHeight(); + size_t bRow = !b.isTransposed() ? b.getHeight() : b.getWidth(); + size_t bCol = !b.isTransposed() ? b.getWidth() : b.getHeight(); + /// C = A * B, for matrix format + CHECK(aCol == bRow && aRow == height && bCol == width); if (scaleT == 0) { out.zeroMem(); @@ -71,93 +69,46 @@ void MulOp(CpuSparseMatrix& out, real* C = out.getValue(); int* rows = out.getRows(); int* cols = out.getCols(); - size_t height = out.getHeight(); - size_t width = out.getWidth(); - if (!a.isTransposed() && !b.isTransposed()) { - CHECK(b.getHeight() == a.getWidth() && a.getHeight() == height && - b.getWidth() == width); - size_t m = a.getWidth(); - if (out.getFormat() == SPARSE_CSC) { - for (size_t i = 0; i < width; i++) { - size_t start = out.getColStartIdx(i); - size_t end = out.getColStartIdx(i + 1); - for (size_t j = start; j < end; j++) { - real sum = 0; - size_t rowIdx = rows[j]; - for (size_t k = 0; k < m; k++) { - sum += A[rowIdx * m + k] * B[k * width + i]; - } - C[j] = scaleAB * sum + scaleT * C[j]; - } - } - } else { /// out.getFormat() == SPARSE_CSR - for (size_t i = 0; i < height; i++) { - size_t start = out.getRowStartIdx(i); - size_t end = out.getRowStartIdx(i + 1); - for (size_t j = start; j < end; j++) { - real sum = 0; - size_t colIdx = cols[j]; - for (size_t k = 0; k < a.getWidth(); k++) { - sum += A[i * m + k] * B[k * width + colIdx]; - } - C[j] = scaleAB * sum + scaleT * C[j]; - } - } - } - return; - } - - if (a.isTransposed() && !b.isTransposed()) { - CHECK(a.getHeight() == b.getHeight() && b.getWidth() == width && - a.getWidth() == height); - size_t m = a.getHeight(); - if (out.getFormat() == SPARSE_CSC) { - for (size_t i = 0; i < width; i++) { - size_t start = out.getColStartIdx(i); - size_t end = out.getColStartIdx(i + 1); - for (size_t j = start; j < end; j++) { - real sum = 0; - size_t rowIdx = rows[j]; - for (size_t k = 0; k < m; k++) { - sum += A[k * height + rowIdx] * B[k * width + i]; - } - C[j] = scaleAB * sum + scaleT * C[j]; - } - } - } else { /// out.getFormat() == SPARSE_CSR - for (size_t i = 0; i < height; i++) { - int start = out.getRowStartIdx(i); - int end = out.getRowStartIdx(i + 1); - for (int j = start; j < end; j++) { - real sum = 0; - size_t colIdx = cols[j]; - for (size_t k = 0; k < a.getHeight(); k++) { - sum += A[k * height + i] * B[k * width + colIdx]; - } - C[j] = scaleAB * sum + scaleT * C[j]; + /// SPARSE_CSC, {a any, b not trans} + if (out.getFormat() == SPARSE_CSC) { + /// b not trans and a any + CHECK(!b.isTransposed()); + size_t m = !a.isTransposed() ? a.getWidth() : a.getHeight(); + for (size_t i = 0; i < width; i++) { + size_t start = out.getColStartIdx(i); + size_t end = out.getColStartIdx(i + 1); + for (size_t j = start; j < end; j++) { + real sum = 0; + size_t rowIdx = rows[j]; + for (size_t k = 0; k < m; k++) { + sum += + (!a.isTransposed() ? A[rowIdx * m + k] : A[k * height + rowIdx]) * + B[k * width + i]; } + C[j] = scaleAB * sum + scaleT * C[j]; } } return; } - if (!a.isTransposed() && b.isTransposed()) { - CHECK(b.getWidth() == a.getWidth() && a.getHeight() == height && - b.getHeight() == width); + /// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans} + if (out.getFormat() == SPARSE_CSR) { + /// a and b can not both transpose + CHECK(!(a.isTransposed() && b.isTransposed())); size_t m = a.getWidth(); - if (out.getFormat() == SPARSE_CSR) { - for (size_t i = 0; i < height; i++) { - size_t start = out.getRowStartIdx(i); - size_t end = out.getRowStartIdx(i + 1); - for (size_t j = start; j < end; j++) { - real sum = 0; - size_t colIdx = cols[j]; - for (size_t k = 0; k < m; k++) { - sum += A[i * m + k] * B[colIdx * m + k]; - } - C[j] = scaleAB * sum + scaleT * C[j]; + for (size_t i = 0; i < height; i++) { + size_t start = out.getRowStartIdx(i); + size_t end = out.getRowStartIdx(i + 1); + for (size_t j = start; j < end; j++) { + real sum = 0; + size_t colIdx = cols[j]; + for (size_t k = 0; k < m; k++) { + sum += + (!a.isTransposed() ? A[i * m + k] : A[k * height + i]) * + (!b.isTransposed() ? B[k * width + colIdx] : B[colIdx * m + k]); } + C[j] = scaleAB * sum + scaleT * C[j]; } } return; @@ -330,11 +281,11 @@ public: CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); CHECK_EQ(outputs[0].getArgType(), ADD_TO); - auto out_mat = outputs[0].matrix(); + auto outMat = outputs[0].matrix(); /// matrix = matrix * matrix if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && !outputs[0].isSparseArg()) { - MulOp(out_mat, + MulOp(outMat, inputs[0].matrix(), inputs[1].matrix(), alpha_, @@ -345,7 +296,7 @@ public: /// matrix = matrix * sparse matrix if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() && !outputs[0].isSparseArg()) { - MulOp(out_mat, + MulOp(outMat, inputs[0].matrix(), inputs[1].sparse().SparseMatrix(), alpha_, @@ -356,7 +307,7 @@ public: /// matrix = sparse matrix * matrix if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() && !outputs[0].isSparseArg()) { - MulOp(out_mat, + MulOp(outMat, inputs[0].sparse().SparseMatrix(), inputs[1].matrix(), alpha_, @@ -365,10 +316,10 @@ public: } /// sparse matrix = matrix * matrix - auto out_sparse_mat = outputs[0].sparse().SparseMatrix(); + auto outSparseMat = outputs[0].sparse().SparseMatrix(); if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() && outputs[0].isSparseArg()) { - MulOp(out_sparse_mat, + MulOp(outSparseMat, inputs[0].matrix(), inputs[1].matrix(), alpha_, diff --git a/paddle/function/MulOp.h b/paddle/function/MulOp.h index 23bfd0fa932178f17a4ddeb599fc94eedb7cf4be..b7b1f56af10375b95b7e1350c406fa68d422c1e1 100644 --- a/paddle/function/MulOp.h +++ b/paddle/function/MulOp.h @@ -15,8 +15,6 @@ limitations under the License. */ #pragma once #include "Function.h" -/// todo(tianbing), delete -#include #include "paddle/math/Matrix.h" #include "paddle/math/SparseMatrix.h"