/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MulOp.h"
/// TODO(tianbing): delete it
#include <iostream>
#include "paddle/math/MathFunctions.h"
#include "paddle/math/SIMDFunctions.h"
#include "paddle/utils/ThreadLocal.h"

#ifndef PADDLE_TYPE_DOUBLE
#define GEMM paddle::gemm<float>
#else
#define GEMM paddle::gemm<double>
#endif

namespace {
/// a[i] += scaleB * b[i] for i in [0, len)
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
  }
}

/// a[i * aWidth] += c * b[i * bWidth] for i in [0, len); aWidth and bWidth
/// are row strides, so this adds a scaled column of b into a column of a
/// (with both strides equal to 1 it degenerates to vecAddTo)
inline void colVecAddTo(
    real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
  }
}
}  // namespace

namespace paddle {
/// sparse matrix (+)= dense matrix * dense matrix
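/// Only entries already present in out's sparsity pattern are computed:
/// the loops below walk out's CSC/CSR index arrays and update each stored
/// C[j] in place, i.e. a "sampled" product of a and b.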
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  CHECK_EQ(out.getValueType(), FLOAT_VALUE);
  if (scaleT == 0) {
    out.zeroMem();
  }
  const real* A = a.getData();
  const real* B = b.getData();
  real* C = out.getValue();
  int* rows = out.getRows();
  int* cols = out.getCols();
  size_t width = out.getWidth();
  size_t height = out.getHeight();

  /// SPARSE_CSC: a may be transposed; b must not be
  if (out.getFormat() == SPARSE_CSC) {
    CHECK(!bTrans);
    size_t m = !aTrans ? a.getWidth() : a.getHeight();
    for (size_t i = 0; i < width; i++) {
      size_t start = out.getColStartIdx(i);
      size_t end = out.getColStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t rowIdx = rows[j];
        for (size_t k = 0; k < m; k++) {
          sum += (!aTrans ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
                 B[k * width + i];
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }

  /// SPARSE_CSR: {a any, b not trans} or {a not trans, b trans};
  /// a and b cannot both be transposed
  if (out.getFormat() == SPARSE_CSR) {
    CHECK(!(aTrans && bTrans));
    size_t m = !aTrans ? a.getWidth() : a.getHeight();
    for (size_t i = 0; i < height; i++) {
      size_t start = out.getRowStartIdx(i);
      size_t end = out.getRowStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t colIdx = cols[j];
        for (size_t k = 0; k < m; k++) {
          sum += (!aTrans ? A[i * m + k] : A[k * height + i]) *
                 (!bTrans ? B[k * width + colIdx] : B[colIdx * m + k]);
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }
}

/// dense matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
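  /// In BLAS terms this computes
  ///   out = scaleT * out + scaleAB * op(a) * op(b),
  /// where op(x) is x or x^T according to aTrans / bTrans.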
  GEMM(aTrans ? CblasTrans : CblasNoTrans,
       bTrans ? CblasTrans : CblasNoTrans,
       out.getHeight(),
       out.getWidth(),
       !aTrans ? a.getWidth() : a.getHeight(),
       scaleAB,
       a.getData(),
       a.getStride(),
       b.getData(),
       b.getStride(),
       scaleT,
       out.getData(),
       out.getStride());
}
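
/// A minimal usage sketch for the dense case above (illustrative only;
/// it assumes the CpuMatrix(height, width) constructor and
/// Matrix::randomizeUniform() from paddle/math/Matrix.h):
///
///   CpuMatrix a(M, K), b(K, N), out(M, N);
///   a.randomizeUniform();
///   b.randomizeUniform();
///   MulOp<DEVICE_TYPE_CPU>(out, a, b,
///                          1.0,    // scaleAB
///                          0.0,    // scaleT: overwrite out
///                          false,  // aTrans
///                          false); // bTrans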

/// dense matrix (+)= sparse matrix * dense matrix
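/// Walks a's CSR structure: each stored nonzero a(i, k) contributes
/// value(i, k) * b.row(k) to out.row(i); when aTrans is set, a(i, k)
/// instead contributes value(i, k) * b.row(i) to out.row(k).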
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuSparseMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  if (scaleT == 0) {
    out.zeroMem();
  }
  const real* B = b.getData();
  real* C = out.getData();
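  /// sanity-check 32-byte alignment of b's and out's data buffers when the
  /// row width is a multiple of 32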
  if (out.getWidth() % 32 == 0) {
    CHECK_EQ((size_t)B % 32, 0UL);
    CHECK_EQ((size_t)C % 32, 0UL);
  }

  int* cols = a.getCols();
  real* values = a.getValue();
  for (size_t i = 0; i < a.getHeight(); ++i) {
    const int start = a.getRowStartIdx(i);
    const int end = a.getRowStartIdx(i + 1);
    for (int j = start; j < end; ++j) {
      vecAddTo(!aTrans ? out.getRow(i) : out.getRow(cols[j]),
               !aTrans ? const_cast<CpuMatrix&>(b).getRow(cols[j])
                       : const_cast<CpuMatrix&>(b).getRow(i),
               (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
               out.getWidth());
    }
  }
}

/// dense matrix (+)= dense matrix * sparse matrix
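/// Traverses b's sparse structure: each stored nonzero b(k, j) adds
/// value(k, j) * (column k of a) into column j of out through the strided
/// colVecAddTo helper; the column roles swap when bTrans is set.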
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuSparseMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  if (scaleT == 0) {
    out.zeroMem();
  }
  real* A = const_cast<real*>(a.getData());
  real* B = const_cast<real*>(b.getValue());
  real* C = out.getData();
  int* rows = b.getRows();
  int* cols = b.getCols();

  /// SPARSE_CSC format
  if (b.getFormat() == SPARSE_CSC) {
    for (size_t j = 0; j < b.getWidth(); ++j) {
      int start = b.getColStartIdx(j);
      int end = b.getColStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
        colVecAddTo(!bTrans ? C + j : C + rows[i],
                    !bTrans ? A + rows[i] : A + j,
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
      }
    }
    return;
  }

  /// SPARSE_CSR format
  if (b.getFormat() == SPARSE_CSR) {
    for (size_t j = 0; j < b.getHeight(); ++j) {
      int start = b.getRowStartIdx(j);
      int end = b.getRowStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
        colVecAddTo(!bTrans ? C + cols[i] : C + j,
                    !bTrans ? A + j : A + cols[i],
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
      }
    }
    return;
  }
}

/**
 * mul operator
 * out = scaleT * out + scaleAB * (A * B)
 * here, scaleT in {0, 1}, scaleAB == 1,
 * out = A * B, ASSIGN_TO
 * out += A * B, ADD_TO
 *
 * \param outputs[0]      output matrix (out), M * N,
 *                        could be either a sparse or a dense matrix;
 *                        M is the number of rows, N the number of columns
 * \param inputs[0]       first input matrix (A), M * K (if non-trans),
 *                        could be either a sparse or a dense matrix;
 *                        M is the number of rows, K the number of columns
 * \param inputs[1]       second input matrix (B), K * N (if non-trans),
 *                        could be either a sparse or a dense matrix;
 *                        K is the number of rows, N the number of columns
 *
 * Supports eight Mul operators in total, covering both GPU and CPU devices.
 * For each device, four Mul operators are supported:
 * 1. dense (out) = dense (A) * dense (B)
 * 2. dense (out) = sparse (A) * dense (B)
 *    the sparse matrix only supports the SPARSE_CSR format
 * 3. dense (out) = dense (A) * sparse (B)
 *    the sparse matrix supports both SPARSE_CSC and SPARSE_CSR formats
 * 4. sparse (out) = dense (A) * dense (B)
 *    the sparse matrix supports both SPARSE_CSC and SPARSE_CSR formats
 *
 */
template <DeviceType Device>
class MulFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    aTrans_ = config.get<bool>("aTrans");
    bTrans_ = config.get<bool>("bTrans");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK(!aTrans_ || !bTrans_)
        << "transposing both a and b is not supported";

    CHECK_EQ((size_t)2, inputs.size());
    CHECK_EQ((size_t)1, outputs.size());
    CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data());
    CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
    CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
    CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);

    size_t aRow = !aTrans_ ? inputs[0].shape()[0] : inputs[0].shape()[1];
    size_t aCol = !aTrans_ ? inputs[0].shape()[1] : inputs[0].shape()[0];
    size_t bRow = !bTrans_ ? inputs[1].shape()[0] : inputs[1].shape()[1];
    size_t bCol = !bTrans_ ? inputs[1].shape()[1] : inputs[1].shape()[0];
    /// shape compatibility check for C = A * B (or C += A * B)
    CHECK_EQ(aCol, bRow);
    CHECK_EQ(aRow, outputs[0].shape()[0]);
    CHECK_EQ(bCol, outputs[0].shape()[1]);

    /// only C = A * B (ASSIGN_TO) and C += A * B (ADD_TO) are supported
    real scaleT = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;

    /// supported cases: a dense output with at most one sparse input,
    /// or a sparse output with two dense inputs
    CHECK((!outputs[0].isSparseArg() &&
           !(inputs[0].isSparseArg() && inputs[1].isSparseArg())) ||
          (outputs[0].isSparseArg() && !inputs[0].isSparseArg() &&
           !inputs[1].isSparseArg()));

    auto outMat = outputs[0].matrix<Device>();
    /// dense matrix = dense matrix * dense matrix
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }

    /// dense matrix = dense matrix * sparse matrix
    if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      CHECK(!aTrans_) << "transpose of a is not supported";
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].sparse().SparseMatrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }

    /// dense matrix = sparse matrix * dense matrix
    if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      CHECK(!bTrans_) << "transpose of b is not supported";
      CHECK_EQ(inputs[0].sparse().dataFormat(), T_SPARSE_CSR)
          << "only the SPARSE_CSR format is supported for sparse matrix a";
      MulOp<Device>(outMat,
                    inputs[0].sparse().SparseMatrix<Device>(),
                    inputs[1].matrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }

    /// sparse matrix = dense matrix * dense matrix
    auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        outputs[0].isSparseArg()) {
      MulOp<Device>(outSparseMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }
  }

private:
  bool aTrans_;
  bool bTrans_;
};

REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif
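
/* Illustrative use through the Function registry (a sketch mirroring the
 * MulOp unit-test pattern; FuncConfig::set, BufferArgs::addArg and the
 * "MulOp-CPU" registry key are assumed from the legacy Function API):
 *
 *   auto func = FunctionBase::funcRegistrar_.createByType("MulOp-CPU");
 *   func->init(FuncConfig().set("aTrans", false).set("bTrans", false));
 *
 *   BufferArgs inputs, outputs;
 *   inputs.addArg(*a);             // A: M x K dense matrix
 *   inputs.addArg(*b);             // B: K x N dense matrix
 *   outputs.addArg(*out, ADD_TO);  // out += A * B  (scaleT == 1)
 *   func->calc(inputs, outputs);
 */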
}  // namespace paddle