MulOp.cpp 11.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MulOp.h"
16
#include "GemmFunctor.h"
17 18 19 20 21 22
#include "paddle/math/SIMDFunctions.h"
#include "paddle/utils/ThreadLocal.h"

namespace {
/**
 * Scaled vector accumulation: a[i] += scaleB * b[i] for i in [0, len).
 *
 * When scaleB is exactly 1 the multiplication is skipped and b[i] is added
 * unchanged.
 *
 * \param a       destination, updated in place (len elements).
 * \param b       source, read only (len elements).
 * \param scaleB  factor applied to every element of b.
 * \param len     number of elements to process.
 */
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
  // A size_t index is required: an unsigned int counter would wrap before
  // reaching a len larger than UINT_MAX and the loop would never end.
  if (1.0 == scaleB) {
    for (size_t i = 0; i < len; ++i) {
      a[i] += b[i];
    }
  } else {
    for (size_t i = 0; i < len; ++i) {
      a[i] += scaleB * b[i];
    }
  }
}

/**
 * Strided accumulation: a[i * aWidth] += c * b[i * bWidth] for i in [0, len).
 *
 * With aWidth / bWidth set to the row strides of two row-major buffers this
 * adds a scaled column of one buffer to a column of the other. When c is
 * exactly 1 the multiplication is skipped.
 *
 * \param a       destination, updated in place at stride aWidth.
 * \param b       source, read only, walked at stride bWidth.
 * \param c       factor applied to every source element.
 * \param len     number of elements (rows) to process.
 * \param aWidth  stride between consecutive elements of a.
 * \param bWidth  stride between consecutive elements of b.
 */
inline void colVecAddTo(real* a,
                        const real* b,
                        real c,
                        size_t len,
                        size_t aWidth,
                        size_t bWidth) {
  if (1.0 == c) {
    for (size_t i = 0; i < len; ++i) {
      a[i * aWidth] += b[i * bWidth];
    }
  } else {
    for (size_t i = 0; i < len; ++i) {
      a[i * aWidth] += c * b[i * bWidth];
    }
  }
}
}  // namespace
34 35

namespace paddle {
36
/// sparse matrix (+)= dense matrix * dense matrix
37 38 39 40 41
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
X
xutianbing 已提交
42 43
                            real scaleT,
                            bool aTrans,
X
xutianbing 已提交
44
                            bool bTrans) {
45
  CHECK_EQ(out.getValueType(), FLOAT_VALUE);
X
xutianbing 已提交
46 47 48
  if (scaleT == 0) {
    out.zeroMem();
  }
49 50 51 52 53
  const real* A = a.getData();
  const real* B = b.getData();
  real* C = out.getValue();
  int* rows = out.getRows();
  int* cols = out.getCols();
X
xutianbing 已提交
54 55
  size_t width = out.getWidth();
  size_t height = out.getHeight();
56

X
xutianbing 已提交
57 58 59
  /// SPARSE_CSC, {a any, b not trans}
  if (out.getFormat() == SPARSE_CSC) {
    /// b not trans and a any
X
xutianbing 已提交
60 61
    CHECK(!bTrans);
    size_t m = !aTrans ? a.getWidth() : a.getHeight();
X
xutianbing 已提交
62 63 64 65 66 67 68
    for (size_t i = 0; i < width; i++) {
      size_t start = out.getColStartIdx(i);
      size_t end = out.getColStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t rowIdx = rows[j];
        for (size_t k = 0; k < m; k++) {
X
xutianbing 已提交
69 70
          sum += (!aTrans ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
                 B[k * width + i];
71
        }
X
xutianbing 已提交
72
        C[j] = scaleAB * sum + scaleT * C[j];
73 74
      }
    }
X
xutianbing 已提交
75 76 77
    return;
  }

X
xutianbing 已提交
78 79 80
  /// SPARSE_CSR, {a any, b not trans} or {a not trans, b trans}
  if (out.getFormat() == SPARSE_CSR) {
    /// a and b can not both transpose
X
xutianbing 已提交
81
    CHECK(!(aTrans && bTrans));
82
    size_t m = a.getWidth();
X
xutianbing 已提交
83 84 85 86 87 88 89
    for (size_t i = 0; i < height; i++) {
      size_t start = out.getRowStartIdx(i);
      size_t end = out.getRowStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t colIdx = cols[j];
        for (size_t k = 0; k < m; k++) {
X
xutianbing 已提交
90 91
          sum += (!aTrans ? A[i * m + k] : A[k * height + i]) *
                 (!bTrans ? B[k * width + colIdx] : B[colIdx * m + k]);
92
        }
X
xutianbing 已提交
93
        C[j] = scaleAB * sum + scaleT * C[j];
94 95
      }
    }
X
xutianbing 已提交
96
    return;
97 98 99
  }
}

100
/// dense matrix (+)= dense matrix * dense matrix
101 102 103 104 105
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
X
xutianbing 已提交
106 107
                            real scaleT,
                            bool aTrans,
X
xutianbing 已提交
108
                            bool bTrans) {
109 110 111 112 113 114 115 116 117 118 119 120 121 122
  BlasGemm<DEVICE_TYPE_CPU, real>::compute(
      aTrans,
      bTrans,
      out.getHeight(),
      out.getWidth(),
      !aTrans ? a.getWidth() : a.getHeight(),
      scaleAB,
      a.getData(),
      a.getStride(),
      b.getData(),
      b.getStride(),
      scaleT,
      out.getData(),
      out.getStride());
123 124
}

125
/// dense matrix (+)= sparse matrix * dense matrix
126 127 128 129 130
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuSparseMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
X
xutianbing 已提交
131 132
                            real scaleT,
                            bool aTrans,
X
xutianbing 已提交
133
                            bool bTrans) {
134 135 136
  if (scaleT == 0) {
    out.zeroMem();
  }
X
xutianbing 已提交
137 138 139 140 141 142
  const real* B = b.getData();
  real* C = out.getData();
  if (out.getWidth() % 32 == 0) {
    CHECK_EQ((size_t)B % 32, 0UL);
    CHECK_EQ((size_t)C % 32, 0UL);
  }
143

X
xutianbing 已提交
144 145 146 147 148 149
  int* cols = a.getCols();
  real* values = a.getValue();
  for (size_t i = 0; i < a.getHeight(); ++i) {
    const int start = a.getRowStartIdx(i);
    const int end = a.getRowStartIdx(i + 1);
    for (int j = start; j < end; ++j) {
X
xutianbing 已提交
150 151 152
      vecAddTo(!aTrans ? out.getRow(i) : out.getRow(cols[j]),
               !aTrans ? const_cast<CpuMatrix&>(b).getRow(cols[j])
                       : const_cast<CpuMatrix&>(b).getRow(i),
X
xutianbing 已提交
153 154
               (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
               out.getWidth());
155 156 157 158
    }
  }
}

159
/// dense matrix (+)= dense matrix * sparse matrix
160 161 162 163 164
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuSparseMatrix& b,
                            real scaleAB,
X
xutianbing 已提交
165 166
                            real scaleT,
                            bool aTrans,
X
xutianbing 已提交
167
                            bool bTrans) {
X
xutianbing 已提交
168 169 170
  if (scaleT == 0) {
    out.zeroMem();
  }
171 172 173 174 175 176
  real* A = const_cast<real*>(a.getData());
  real* B = const_cast<real*>(b.getValue());
  real* C = out.getData();
  int* rows = b.getRows();
  int* cols = b.getCols();

177
  /// SPARSE_CSC format
178
  if (b.getFormat() == SPARSE_CSC) {
X
xutianbing 已提交
179 180 181 182
    for (size_t j = 0; j < b.getWidth(); ++j) {
      int start = b.getColStartIdx(j);
      int end = b.getColStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
X
xutianbing 已提交
183 184
        colVecAddTo(!bTrans ? C + j : C + rows[i],
                    !bTrans ? A + rows[i] : A + j,
X
xutianbing 已提交
185 186 187 188
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
189 190
      }
    }
X
xutianbing 已提交
191 192 193
    return;
  }

194
  /// SPARSE_CSR format
X
xutianbing 已提交
195 196 197 198 199
  if (b.getFormat() == SPARSE_CSR) {
    for (size_t j = 0; j < b.getHeight(); ++j) {
      int start = b.getRowStartIdx(j);
      int end = b.getRowStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
X
xutianbing 已提交
200 201
        colVecAddTo(!bTrans ? C + cols[i] : C + j,
                    !bTrans ? A + j : A + cols[i],
X
xutianbing 已提交
202 203 204 205
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
206 207
      }
    }
X
xutianbing 已提交
208
    return;
209 210
  }
}
211 212 213

/**
 * Mul operator.
 *
 * out = scaleT * out + scaleAB * (op(A) * op(B)), where op(X) is X or its
 * transpose, as selected by the "aTrans" / "bTrans" config entries.
 * This function always uses scaleAB == 1 and scaleT in {0, 1}:
 *   out = A * B,  when outputs[0] is ASSIGN_TO (scaleT = 0)
 *   out += A * B, when outputs[0] is ADD_TO    (scaleT = 1)
 *
 * \param outputs[0]      output matrix (out), M * N,
 *                        either a sparse or a dense matrix;
 *                        M is the number of rows, N the number of columns
 * \param inputs[0]       first input matrix (A), M * K (if not transposed),
 *                        either a sparse or a dense matrix;
 *                        M is the number of rows, K the number of columns
 * \param inputs[1]       second input matrix (B), K * N (if not transposed),
 *                        either a sparse or a dense matrix;
 *                        K is the number of rows, N the number of columns
 *
 * Eight Mul operators are supported, four on each device (CPU and GPU):
 * 1. dense (out) = dense (A) * dense (B)
 * 2. dense (out) = sparse (A) * dense (B)
 *    the sparse matrix must be SPARSE_CSR; B must not be transposed
 * 3. dense (out) = dense (A) * sparse (B)
 *    the sparse matrix may be SPARSE_CSC or SPARSE_CSR; A must not be
 *    transposed
 * 4. sparse (out) = dense (A) * dense (B)
 *    the sparse matrix may be SPARSE_CSC or SPARSE_CSR
 * A and B may never both be transposed, and sparse * sparse is rejected.
 */
template <DeviceType Device>
class MulFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    aTrans_ = config.get<bool>("aTrans");
    bTrans_ = config.get<bool>("bTrans");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK(!aTrans_ || !bTrans_)
        << "Not support both a and b are transpose matrices";

    CHECK_EQ(static_cast<size_t>(2), inputs.size());
    CHECK_EQ(static_cast<size_t>(1), outputs.size());
    CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data());
    CHECK_EQ(inputs[0].shape().ndims(), static_cast<size_t>(2));
    CHECK_EQ(inputs[1].shape().ndims(), static_cast<size_t>(2));
    CHECK_EQ(outputs[0].shape().ndims(), static_cast<size_t>(2));

    // Logical shapes of op(A) and op(B), i.e. after the transposes.
    size_t aRow = !aTrans_ ? inputs[0].shape()[0] : inputs[0].shape()[1];
    size_t aCol = !aTrans_ ? inputs[0].shape()[1] : inputs[0].shape()[0];
    size_t bRow = !bTrans_ ? inputs[1].shape()[0] : inputs[1].shape()[1];
    size_t bCol = !bTrans_ ? inputs[1].shape()[1] : inputs[1].shape()[0];
    /// C = A * B, or C += A * B, for matrix format
    CHECK_EQ(aCol, bRow);
    CHECK_EQ(aRow, outputs[0].shape()[0]);
    CHECK_EQ(bCol, outputs[0].shape()[1]);

    /// only support C = A * B (ASSIGN_TO) or C += A * B (ADD_TO)
    real scaleT = (outputs[0].getArgType() == ADD_TO) ? 1.0 : 0.0;

    const bool aSparse = inputs[0].isSparseArg();
    const bool bSparse = inputs[1].isSparseArg();
    const bool outSparse = outputs[0].isSparseArg();

    /// support dense = not both sparse * sparse
    /// or sparse = dense * dense
    CHECK((!outSparse && !(aSparse && bSparse)) ||
          (outSparse && !aSparse && !bSparse));

    /// sparse matrix = dense matrix * dense matrix
    if (outSparse) {
      // MulOp takes the output by non-const reference, so it needs a named
      // local.
      auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
      MulOp<Device>(outSparseMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }

    // From here on the output is dense. The dense view is built only now,
    // so a sparse output's value buffer is never wrapped as an M x N matrix.
    auto outMat = outputs[0].matrix<Device>();

    /// dense matrix = dense matrix * dense matrix
    if (!aSparse && !bSparse) {
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }

    /// dense matrix = dense matrix * sparse matrix
    if (!aSparse && bSparse) {
      CHECK(!aTrans_) << "Not supported a transpose";
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].sparse().SparseMatrix<Device>(),
                    1.0,  // scaleAB
                    scaleT,
                    aTrans_,
                    bTrans_);
      return;
    }

    /// dense matrix = sparse matrix * dense matrix (the only case left)
    CHECK(!bTrans_) << "Not supported b transpose";
    CHECK_EQ(inputs[0].sparse().dataFormat(), T_SPARSE_CSR)
        << "Only supported SPARSE_CSR format for sparse matrix a";
    MulOp<Device>(outMat,
                  inputs[0].sparse().SparseMatrix<Device>(),
                  inputs[1].matrix<Device>(),
                  1.0,  // scaleAB
                  scaleT,
                  aTrans_,
                  bTrans_);
  }

private:
  // Set from the "aTrans" / "bTrans" config entries in init().
  bool aTrans_ = false;
  bool bTrans_ = false;
};

343
// The CPU implementation is always registered; the GPU one only when the
// build does not define PADDLE_ONLY_CPU.
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif
}  // namespace paddle