/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MulOp.h"
/// TODO(tianbing): remove this temporary include
#include <iostream>
#include "paddle/math/MathFunctions.h"
#include "paddle/math/SIMDFunctions.h"
#include "paddle/utils/ThreadLocal.h"

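/// Select the BLAS gemm matching Paddle's `real` type: float by default,
/// double when PADDLE_TYPE_DOUBLE is defined.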
#ifndef PADDLE_TYPE_DOUBLE
#define GEMM paddle::gemm<float>
#else
#define GEMM paddle::gemm<double>
#endif

namespace {
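/// a[i] += scaleB * b[i] for i in [0, len); the scaleB == 1 case avoids the
/// multiply.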
inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i] += (1.0 == scaleB) ? b[i] : scaleB * b[i];
  }
}

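/// Strided variant: a[i * aWidth] += c * b[i * bWidth] for i in [0, len),
/// used to accumulate one matrix column into another.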
inline void colVecAddTo(
    real* a, real* b, real c, size_t len, size_t aWidth, size_t bWidth) {
  for (unsigned int i = 0; i < len; ++i) {
    a[i * aWidth] += (1.0 == c) ? b[i * bWidth] : b[i * bWidth] * c;
  }
}
}  // namespace

namespace paddle {
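/// Sparse output: out = scaleAB * a * b + scaleT * out, evaluated only at the
/// positions stored in the sparse pattern of out (CSC or CSR).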
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuSparseMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK_EQ(out.getValueType(), FLOAT_VALUE);
  if (scaleT == 0) {
    out.zeroMem();
  }
  const real* A = a.getData();
  const real* B = b.getData();
  real* C = out.getValue();
  int* rows = out.getRows();
  int* cols = out.getCols();
  size_t width = out.getWidth();
  size_t height = out.getHeight();

  /// SPARSE_CSC output: a may or may not be transposed; b must not be.
  if (out.getFormat() == SPARSE_CSC) {
    CHECK(!bTrans);
    size_t m = !aTrans ? a.getWidth() : a.getHeight();
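    /// For each stored nonzero (rows[j], i) of out, take the dot product of
    /// row rows[j] of op(a) with column i of b over the shared dimension m.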
    for (size_t i = 0; i < width; i++) {
      size_t start = out.getColStartIdx(i);
      size_t end = out.getColStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t rowIdx = rows[j];
        for (size_t k = 0; k < m; k++) {
          sum += (!aTrans ? A[rowIdx * m + k] : A[k * height + rowIdx]) *
                 B[k * width + i];
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }

  /// SPARSE_CSR output: {a any, b not transposed} or {a not transposed, b transposed}
  if (out.getFormat() == SPARSE_CSR) {
    /// a and b cannot both be transposed
    CHECK(!(aTrans && bTrans));
    size_t m = a.getWidth();
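    /// For each stored nonzero (i, cols[j]) of out, take the dot product of
    /// row i of op(a) with column cols[j] of op(b) over the shared dimension m.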
    for (size_t i = 0; i < height; i++) {
      size_t start = out.getRowStartIdx(i);
      size_t end = out.getRowStartIdx(i + 1);
      for (size_t j = start; j < end; j++) {
        real sum = 0;
        size_t colIdx = cols[j];
        for (size_t k = 0; k < m; k++) {
          sum += (!aTrans ? A[i * m + k] : A[k * height + i]) *
                 (!bTrans ? B[k * width + colIdx] : B[colIdx * m + k]);
        }
        C[j] = scaleAB * sum + scaleT * C[j];
      }
    }
    return;
  }
}

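/// Dense case: out = scaleAB * op(a) * op(b) + scaleT * out, delegated to the
/// BLAS gemm selected above.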
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  GEMM(aTrans ? CblasTrans : CblasNoTrans,
       bTrans ? CblasTrans : CblasNoTrans,
       out.getHeight(),
       out.getWidth(),
       !aTrans ? a.getWidth() : a.getHeight(),
       scaleAB,
       a.getData(),
       a.getStride(),
       b.getData(),
       b.getStride(),
       scaleT,
       out.getData(),
       out.getStride());
}

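/// CSR-sparse a times dense b: each stored nonzero a(i, cols[j]) scales a row
/// of b and accumulates it into the matching row of out (rows swap under
/// aTrans).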
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuSparseMatrix& a,
                            const CpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK_EQ(a.getFormat(), SPARSE_CSR)
      << "Only the SPARSE_CSR format is supported for a";
  if (scaleT == 0) {
    out.zeroMem();
  }
  const real* B = b.getData();
  real* C = out.getData();
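  /// When the row width permits vectorized processing, require the dense
  /// buffers to be 32-byte aligned.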
  if (out.getWidth() % 32 == 0) {
    CHECK_EQ((size_t)B % 32, 0UL);
    CHECK_EQ((size_t)C % 32, 0UL);
  }

  int* cols = a.getCols();
  real* values = a.getValue();
  for (size_t i = 0; i < a.getHeight(); ++i) {
    const int start = a.getRowStartIdx(i);
    const int end = a.getRowStartIdx(i + 1);
    for (int j = start; j < end; ++j) {
      vecAddTo(!aTrans ? out.getRow(i) : out.getRow(cols[j]),
               !aTrans ? const_cast<CpuMatrix&>(b).getRow(cols[j])
                       : const_cast<CpuMatrix&>(b).getRow(i),
               (a.getValueType() == FLOAT_VALUE) ? values[j] : (real)1.0,
               out.getWidth());
    }
  }
}

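/// Dense a times sparse b: each stored nonzero of b scales a column of a and
/// accumulates it into the matching column of out (columns swap under bTrans).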
template <>
void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
                            const CpuMatrix& a,
                            const CpuSparseMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  if (scaleT == 0) {
    out.zeroMem();
  }
  real* A = const_cast<real*>(a.getData());
  real* B = const_cast<real*>(b.getValue());
  real* C = out.getData();
  int* rows = b.getRows();
  int* cols = b.getCols();

  /// CSC: each stored nonzero b(rows[i], j) accumulates a column of a, scaled
  /// by its value, into column j of out (indices swap when bTrans).
  if (b.getFormat() == SPARSE_CSC) {
    for (size_t j = 0; j < b.getWidth(); ++j) {
      int start = b.getColStartIdx(j);
      int end = b.getColStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
        colVecAddTo(!bTrans ? C + j : C + rows[i],
                    !bTrans ? A + rows[i] : A + j,
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
      }
    }
    return;
  }

  /// CSR: each stored nonzero b(j, cols[i]) accumulates a column of a, scaled
  /// by its value, into column cols[i] of out (indices swap when bTrans).
  if (b.getFormat() == SPARSE_CSR) {
    for (size_t j = 0; j < b.getHeight(); ++j) {
      int start = b.getRowStartIdx(j);
      int end = b.getRowStartIdx(j + 1);
      for (int i = start; i < end; ++i) {
        colVecAddTo(!bTrans ? C + cols[i] : C + j,
                    !bTrans ? A + j : A + cols[i],
                    (b.getValueType() == NO_VALUE) ? (real)1.0 : B[i],
                    out.getHeight(),
                    out.getWidth(),
                    a.getWidth());
      }
    }
    return;
  }
}

/**
 * mul operator:
 * out = scaleT * out + scaleAB * (in1 * in2)
 *
 * \param outputs[0]      output matrix, M * N
 * \param inputs[0]       first input (possibly sparse) matrix, M * K (if not transposed)
 * \param inputs[1]       second input matrix, K * N (if not transposed)
 */
template <DeviceType Device>
class MulFunc : public FunctionBase {
public:
  void init(const FuncConfig& config) override {
    alpha_ = config.get<real>("scaleAB");
    beta_ = config.get<real>("scaleT");
    aTrans_ = config.get<bool>("aTrans");
    bTrans_ = config.get<bool>("bTrans");
    cTrans_ = config.get<bool>("cTrans");
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK(!cTrans_) << "The output matrix must not be transposed";
    CHECK(!aTrans_ || !bTrans_)
        << "a and b cannot both be transposed";

    CHECK_EQ((size_t)2, inputs.size());
    CHECK_EQ((size_t)1, outputs.size());
    CHECK(inputs[0].data() && inputs[1].data() && outputs[0].data());
    CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
    CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
    CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);

    size_t aRow = !aTrans_ ? inputs[0].shape()[0] : inputs[0].shape()[1];
    size_t aCol = !aTrans_ ? inputs[0].shape()[1] : inputs[0].shape()[0];
    size_t bRow = !bTrans_ ? inputs[1].shape()[0] : inputs[1].shape()[1];
    size_t bCol = !bTrans_ ? inputs[1].shape()[1] : inputs[1].shape()[0];
    /// C = A * B, or C += A * B, for matrix format
    CHECK_EQ(aCol, bRow);
    CHECK_EQ(aRow, outputs[0].shape()[0]);
    CHECK_EQ(bCol, outputs[0].shape()[1]);

    /// only support C = A * B or C += A * B
    CHECK_EQ(alpha_, static_cast<real>(1.0));
    CHECK((beta_ == 0 && outputs[0].getArgType() == ASSIGN_TO) ||
          (beta_ == 1 && outputs[0].getArgType() == ADD_TO));

    /// supported: dense output with at most one sparse input,
    /// or sparse output with two dense inputs
    CHECK((!outputs[0].isSparseArg() &&
           !(inputs[0].isSparseArg() && inputs[1].isSparseArg())) ||
          (outputs[0].isSparseArg() && !inputs[0].isSparseArg() &&
           !inputs[1].isSparseArg()));

    auto outMat = outputs[0].matrix<Device>();
    /// matrix = matrix * matrix
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }

    /// matrix = matrix * sparse matrix
    if (!inputs[0].isSparseArg() && inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      CHECK(!aTrans_) << "Transposed a is not supported here";
      MulOp<Device>(outMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].sparse().SparseMatrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }

    /// matrix = sparse matrix * matrix
    if (inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        !outputs[0].isSparseArg()) {
      CHECK(!bTrans_) << "Transposed b is not supported here";
      MulOp<Device>(outMat,
                    inputs[0].sparse().SparseMatrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }

    /// sparse matrix = matrix * matrix
    auto outSparseMat = outputs[0].sparse().SparseMatrix<Device>();
    if (!inputs[0].isSparseArg() && !inputs[1].isSparseArg() &&
        outputs[0].isSparseArg()) {
      MulOp<Device>(outSparseMat,
                    inputs[0].matrix<Device>(),
                    inputs[1].matrix<Device>(),
                    alpha_,
                    beta_,
                    aTrans_,
                    bTrans_,
                    cTrans_);
      return;
    }
  }

private:
  real alpha_;
  real beta_;
  bool aTrans_;
  bool bTrans_;
  bool cTrans_;
};

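/// A minimal usage sketch, assuming the Function registry API from
/// paddle/function/Function.h (FunctionBase::funcRegistrar_, FUNC_NAME,
/// FuncConfig::set, BufferArgs::addArg); shapes follow the doc comment above:
///
///   auto* func = FunctionBase::funcRegistrar_.createByType(FUNC_NAME(MulOp, CPU));
///   func->init(FuncConfig()
///                  .set("scaleAB", (real)1.0)
///                  .set("scaleT", (real)1.0)  // 1.0 with ADD_TO, 0.0 with ASSIGN_TO
///                  .set("aTrans", false)
///                  .set("bTrans", false)
///                  .set("cTrans", false));
///   BufferArgs ins, outs;
///   ins.addArg(aMatrix);             // M x K dense (or sparse) input
///   ins.addArg(bMatrix);             // K x N dense input
///   outs.addArg(outMatrix, ADD_TO);  // M x N output, accumulated into
///   func->calc(ins, outs);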
REGISTER_TYPED_FUNC(MulOp, CPU, MulFunc);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(MulOp, GPU, MulFunc);
#endif
}  // namespace paddle