/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MulOp.h"
/// TODO(tianbing): delete it
#include <iostream>
#include "paddle/math/MathFunctions.h"
#include "paddle/math/SIMDFunctions.h"
#include "paddle/utils/ThreadLocal.h"

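/// Select the single- or double-precision gemm wrapper matching the build's
/// `real` type.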
#ifndef PADDLE_TYPE_DOUBLE
#define GEMM paddle::gemm<float>
#else
#define GEMM paddle::gemm<double>
#endif

namespace {
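/// Element-wise accumulate: a[i] += scaleB * b[i] for i in [0, len).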
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
  }
}

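/// Strided accumulate between matrix columns:
/// a[i * aWidth] += c * b[i * bWidth] for i in [0, len).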
inline void colVecAddTo(
    real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
  }
}
}  // namespace

namespace paddle {
/// sparse matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK_EQ(out.getValueType(), FLOAT_VALUE);
  if (scaleT == 0) {
    out.zeroMem();
  }
  const real* A = a.getData();
  const real* B = b.getData();
  real* C = out.getValue();
  int* rows = out.getRows();
  int* cols = out.getCols();
  size_t width = out.getWidth();
  size_t height = out.getHeight();

  /// SPARSE_CSC, {a any, b not trans}
  if (out.getFormat() == SPARSE_CSC) {
    /// b not trans and a any
    CHECK(!bTrans);
    size_t m = !aTrans ? a.getWidth() : a.getHeight();
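    /// Only positions stored in out's CSC structure are updated: for each
    /// non-zero (rows[j], column i) of out, accumulate the dot product of
    /// row rows[j] of A (or its transpose) with column i of B.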
    for (size_t i = 0; i < width; i++) {
      size_t start = out.getColStartIdx(i);
      size_t end = out.getColStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t rowIdx = rows[j];
        for (size_t k = 0; k < m; k++) {
          sum += (!aTrans ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
                 B[k * width + i];
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }

  /// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
  if (out.getFormat() == SPARSE_CSR) {
    /// a and b cannot both be transposed
    CHECK(!(aTrans && bTrans));
    size_t m = a.getWidth();
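    /// Only positions stored in out's CSR structure are updated: for each
    /// non-zero (row i, cols[j]) of out, accumulate the dot product of row i
    /// of A (or its transpose) with column cols[j] of B (or row cols[j] of B,
    /// if bTrans).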
    for (size_t i = 0; i < height; i++) {
      size_t start = out.getRowStartIdx(i);
      size_t end = out.getRowStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t colIdx = cols[j];
        for (size_t k = 0; k < m; k++) {
          sum += (!aTrans ? A[i * m + k] : A[k * height + i]) *
                 (!bTrans ? B[k * width + colIdx] : B[colIdx * m + k]);
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }
}

/// dense matrix (+)= dense matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
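  /// Delegate to BLAS gemm: out = scaleAB * op(a) * op(b) + scaleT * out.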
  GEMM(aTrans ? CblasTrans : CblasNoTrans,
       bTrans ? CblasTrans : CblasNoTrans,
       out.getHeight(),
       out.getWidth(),
       !aTrans ? a.getWidth() : a.getHeight(),
       scaleAB,
       a.getData(),
       a.getStride(),
       b.getData(),
       b.getStride(),
       scaleT,
       out.getData(),
       out.getStride());
}

/// dense matrix (+)= sparse matrix * dense matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuSparseMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  if (scaleT == 0) {
    out.zeroMem();
  }
  const real* B = b.getData();
  real* C = out.getData();
  if (out.getWidth() % 32 == 0) {
    CHECK_EQ((size_t)B % 32, 0UL);
    CHECK_EQ((size_t)C % 32, 0UL);
  }

  int* cols = a.getCols();
  real* values = a.getValue();
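  /// a is CSR: each non-zero value a(i, cols[j]) adds values[j] * (row cols[j]
  /// of b) to row i of out; the two row indices swap roles when aTrans is set.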
  for (size_t i = 0; i < a.getHeight(); ++i) {
    const int start = a.getRowStartIdx(i);
    const int end = a.getRowStartIdx(i + 1);
    for (int j = start; j < end; ++j) {
      vecAddTo(!aTrans ? out.getRow(i) : out.getRow(cols[j]),
               !aTrans ? const_cast<CpuMatrix&>(b).getRow(cols[j])
                       : const_cast<CpuMatrix&>(b).getRow(i),
               (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
               out.getWidth());
    }
  }
}

/// dense matrix (+)= dense matrix * sparse matrix
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuSparseMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  if (scaleT == 0) {
    out.zeroMem();
  }
  real* A = const_cast<real*>(a.getData());
  real* B = const_cast<real*>(b.getValue());
  real* C = out.getData();
  int* rows = b.getRows();
  int* cols = b.getCols();
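  /// Each non-zero entry of b scales one column of a and accumulates it into
  /// one column of out; only the traversal order differs between the CSC and
  /// CSR branches below.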

  /// SPARSE_CSC format
  if (b.getFormat() == SPARSE_CSC) {
    for (size_t j = 0; j < b.getWidth(); ++j) {
      int start = b.getColStartIdx(j);
      int end = b.getColStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
        colVecAddTo(!bTrans ? C + j : C + rows[i],
                    !bTrans ? A + rows[i] : A + j,
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
      }
    }
    return;
  }

  /// SPARSE_CSR format
  if (b.getFormat() == SPARSE_CSR) {
    for (size_t j = 0; j < b.getHeight(); ++j) {
      int start = b.getRowStartIdx(j);
      int end = b.getRowStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
        colVecAddTo(!bTrans ? C + cols[i] : C + j,
                    !bTrans ? A + j : A + cols[i],
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
      }
    }
    return;
  }
}

/**
 * mul operator
 * out = scaleT * out + scaleAB * (in1 * in2)
 * where scaleT is in {0, 1} and scaleAB == 1, i.e.
 * out = in1 (A) * in2 (B), for ASSIGN_TO
 * out += in1 (A) * in2 (B), for ADD_TO
 *
 * \param outputs[0]      output matrix (out), M * N,
 *                        could be either Sparse or Dense Matrix
 *                        M is num of rows, N is num of columns
 * \param inputs[0]       first input matrix (A),  M * K (if non-trans)
 *                        could be either Sparse or Dense Matrix
 *                        M is num of rows, K is num of columns
 * \param inputs[1]       second input matrix (B), K * N (if non-trans)
 *                        could be either Sparse or Dense Matrix
 *                        K is num of rows, N is num of columns
 *
 * Eight Mul operators are supported in total, covering both GPU and CPU devices.
 * For each device, four Mul operators are supported:
 * 1. dense (out) = dense (A) * dense (B)
 * 2. dense (out) = sparse (A) * dense (B)
 *    the sparse matrix only supports the SPARSE_CSR format
 * 3. dense (out) = dense (A) * sparse (B)
 *    the sparse matrix supports both SPARSE_CSC and SPARSE_CSR formats
 * 4. sparse (out) = dense (A) * dense (B)
 *    the sparse matrix supports both SPARSE_CSC and SPARSE_CSR formats
 *
 */
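/*
 * A minimal CPU usage sketch (illustrative only; the registry/config helpers
 * such as FunctionBase::funcRegistrar_, FuncConfig::set, and
 * BufferArgs::addArg come from the surrounding Function framework, and the
 * matrix variables below are assumed to exist with the stated shapes):
 *
 *   auto mul = FunctionBase::funcRegistrar_.createByType("MulOp-CPU");
 *   mul->init(FuncConfig()
 *                 .set("scaleAB", (real)1.0)
 *                 .set("scaleT", (real)1.0)
 *                 .set("aTrans", false)
 *                 .set("bTrans", false)
 *                 .set("cTrans", false));
 *   BufferArgs inputs, outputs;
 *   inputs.addArg(aMatrix);             // dense CpuMatrix, M x K
 *   inputs.addArg(bMatrix);             // dense CpuMatrix, K x N
 *   outputs.addArg(outMatrix, ADD_TO);  // dense CpuMatrix, M x N; out += A * B
 *   mul->calc(inputs, outputs);
 */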
template <DeviceType Device>
class MulFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    alpha_ = config.get<real>("scaleAB");
    beta_ = config.get<real>("scaleT");
    aTrans_ = config.get<bool>("aTrans");
    bTrans_ = config.get<bool>("bTrans");
    cTrans_ = config.get<bool>("cTrans");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK(!cTrans_) << "output matrix should not be transposed";
    CHECK(!aTrans_ || !bTrans_)
        << "transposing both a and b is not supported";

    CHECK_EQ((size_t)2, inputs.size());
    CHECK_EQ((size_t)1, outputs.size());
    CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data());
    CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
    CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
    CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);

    size_t aRow = !aTrans_ ? inputs[0].shape()[0] : inputs[0].shape()[1];
    size_t aCol = !aTrans_ ? inputs[0].shape()[1] : inputs[0].shape()[0];
    size_t bRow = !bTrans_ ? inputs[1].shape()[0] : inputs[1].shape()[1];
    size_t bCol = !bTrans_ ? inputs[1].shape()[1] : inputs[1].shape()[0];
    /// shape consistency for C = A * B or C += A * B
    CHECK_EQ(aCol, bRow);
    CHECK_EQ(aRow, outputs[0].shape()[0]);
    CHECK_EQ(bCol, outputs[0].shape()[1]);

    /// only support C = A * B or C += A * B
    CHECK_EQ(alpha_, static_cast<real>(1.0));
    CHECK((beta_ == 0 && outputs[0].getArgType() == ASSIGN_TO) ||
          (beta_ == 1 && outputs[0].getArgType() == ADD_TO));

    /// supported combinations: a dense output with at most one sparse input,
    /// or a sparse output with two dense inputs
    CHECK((!outputs[0].isSparseArg() &&
           !(inputs[0].isSparseArg() && inputs[1].isSparseArg())) ||
          (outputs[0].isSparseArg() && !inputs[0].isSparseArg() &&
           !inputs[1].isSparseArg()));

    auto outMat = outputs[0].matrix<Device>();
    /// dense matrix = dense matrix * dense matrix
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }

    /// dense matrix = dense matrix * sparse matrix
    if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      CHECK(!aTrans_) << "transpose of a is not supported";
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].sparse().SparseMatrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }

    /// dense matrix = sparse matrix * dense matrix
    if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      CHECK(!bTrans_) << "transpose of b is not supported";
      CHECK_EQ(inputs[0].sparse().dataFormat(), T_SPARSE_CSR)
          << "only the SPARSE_CSR format is supported for sparse matrix a";
      MulOp<Device>(outMat,
                    inputs[0].sparse().SparseMatrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }

    /// sparse matrix = dense matrix * dense matrix
    auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        outputs[0].isSparseArg()) {
      MulOp<Device>(outSparseMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }
  }

private:
  real alpha_;
  real beta_;
  bool aTrans_;
  bool bTrans_;
  bool cTrans_;
};

REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif
}  // namespace paddle