/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "MulOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"

namespace paddle {
/**
 * out = scaleT * out + scaleAB * (a * b)
 * out : output matrix, M * N
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT) {
  CHECK(!out.isTransposed()) << "Transpose not supported for out matrix";
  if (!a.isTransposed() && !b.isTransposed()) {
    /// a : M * K, b : K * N
    CHECK(out.getWidth() == b.getWidth() &&
          out.getHeight() == a.getHeight() &&
          a.getWidth() == b.getHeight());
  } else if (a.isTransposed() && !b.isTransposed()) {
    /// a : K * M, b : K * N
    CHECK(out.getWidth() == b.getWidth() &&
          out.getHeight() == a.getWidth() &&
          a.getHeight() == b.getHeight());
  } else if (!a.isTransposed() && b.isTransposed()) {
    /// a : M * K, b : N * K
    CHECK(out.getWidth() == b.getHeight() &&
          out.getHeight() == a.getHeight() &&
          a.getWidth() == b.getWidth());
  } else {
    LOG(FATAL) << "Not supported when both a and b are transposed";
  }

  real* aData = const_cast<real*>(a.getData());
  real* bData = const_cast<real*>(b.getData());
  real* outData = out.getData();
  hl_matrix_mul(aData,
                !a.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
                bData,
                !b.isTransposed() ? HPPL_OP_N : HPPL_OP_T,
                outData,
                out.getHeight(),
                out.getWidth(),
                !a.isTransposed() ? a.getWidth() : a.getHeight(),
                scaleAB,
                scaleT,
                a.getStride(),
                b.getStride(),
                out.getStride());
}
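
/*
 * A minimal usage sketch for the dense * dense overload above (illustrative
 * only, not part of the implementation). It assumes a CUDA device has been
 * initialized and that GpuMatrix(height, width) allocates device memory;
 * the sizes are hypothetical:
 *
 *   GpuMatrix a(32, 64);    // a : M * K
 *   GpuMatrix b(64, 16);    // b : K * N
 *   GpuMatrix out(32, 16);  // out : M * N
 *   MulOp<DEVICE_TYPE_GPU>(out, a, b, 1.0, 0.0);  // out = a * b
 */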

/**
 * out = scaleT * out + scaleAB * (a * b)
 * out : M * N
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuSparseMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT) {
  CHECK(out.isContiguous());
  CHECK(b.isContiguous());
  CHECK(b.useGpu_) << "Matrix types are not equal";
  CHECK(!out.isTransposed() && !b.isTransposed())
      << "Transpose not supported for out or b matrix";
  if (!a.isTransposed()) {
    /// a : M * K, b : K * N
    CHECK(out.getWidth() == b.getWidth() &&
          out.getHeight() == a.getHeight() &&
          a.getWidth() == b.getHeight())
        << "Matrix dimensions are not equal";
  } else {
    /// a : K * M, transposed, b : K * N
    CHECK(out.getWidth() == b.getWidth() &&
          out.getHeight() == a.getWidth() &&
          a.getHeight() == b.getHeight())
        << "Matrix dimensions are not equal";
  }

  hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
  hl_sparse_matrix_s aData = a.sMatrix_.get();
  real* bData = const_cast<real*>(b.getData());
  real* outData = out.getData();
  hl_matrix_csr_mul_dense(aData,
                          aTrans,
                          bData,
                          HPPL_OP_N,
                          outData,
                          out.getHeight(),
                          out.getWidth(),
                          b.getHeight(),
                          scaleAB,
                          scaleT);
}
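
/*
 * A minimal usage sketch for the sparse * dense overload above (illustrative
 * only). This overload dispatches to hl_matrix_csr_mul_dense, so a is built
 * in CSR format. The GpuSparseMatrix constructor arguments shown are an
 * assumption based on its (height, width, nnz, valueType, format) signature;
 * nnz is a hypothetical non-zero count:
 *
 *   GpuSparseMatrix a(32, 64, nnz, FLOAT_VALUE, SPARSE_CSR);  // a : M * K
 *   GpuMatrix b(64, 16);                                      // b : K * N
 *   GpuMatrix out(32, 16);                                    // out : M * N
 *   MulOp<DEVICE_TYPE_GPU>(out, a, b, 1.0, 0.0);              // out = a * b
 */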

/**
 * out = scaleT * out + scaleAB * (a * b)
 * out : M * N
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuMatrix& a,
                            const GpuSparseMatrix& b,
                            real scaleAB,
                            real scaleT) {
  CHECK(out.isContiguous());
  CHECK(a.isContiguous());
  CHECK(a.useGpu_) << "Matrix types are not equal";
  if (!b.isTransposed()) {
    /// a : M * K, b : K * N
    CHECK(out.getWidth() == b.getWidth() &&
          out.getHeight() == a.getHeight() &&
          a.getWidth() == b.getHeight())
        << "Matrix dimensions are not equal";
  } else {
    /// a : M * K, b : N * K, transposed
    CHECK(out.getWidth() == b.getHeight() &&
          out.getHeight() == a.getHeight() &&
          a.getWidth() == b.getWidth())
        << "Matrix dimensions are not equal";
  }

  hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
  hl_sparse_matrix_s bData = b.sMatrix_.get();
  real* aData = const_cast<real*>(a.getData());
  real* outData = out.getData();

  if (b.format_ == SPARSE_CSC) {
    hl_matrix_dense_mul_csc(aData,
                            HPPL_OP_N,
                            bData,
                            bTrans,
                            outData,
                            out.getHeight(),
                            out.getWidth(),
                            a.getWidth(),
                            scaleAB,
                            scaleT);
  } else {
    hl_matrix_dense_mul_csr(aData,
                            HPPL_OP_N,
                            bData,
                            bTrans,
                            outData,
                            out.getHeight(),
                            out.getWidth(),
                            a.getWidth(),
                            scaleAB,
                            scaleT);
  }
}
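
/*
 * A minimal usage sketch for the dense * sparse overload above (illustrative
 * only; the GpuSparseMatrix constructor arguments are an assumption, as in
 * the sketch above). Building b as SPARSE_CSC exercises the
 * hl_matrix_dense_mul_csc branch; SPARSE_CSR would take the other branch:
 *
 *   GpuMatrix a(32, 64);                                      // a : M * K
 *   GpuSparseMatrix b(64, 16, nnz, FLOAT_VALUE, SPARSE_CSC);  // b : K * N
 *   GpuMatrix out(32, 16);                                    // out : M * N
 *   MulOp<DEVICE_TYPE_GPU>(out, a, b, 1.0, 0.0);              // out = a * b
 */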

/**
 * out = scaleT * out + scaleAB * (a * b)
 * out : output matrix, M * N
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
                            const GpuMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT) {
  CHECK(a.useGpu_ && b.useGpu_) << "Matrix device types do not match";
  CHECK(!out.isTransposed()) << "Transpose is not supported for out matrix";

  if (!a.isTransposed() && !b.isTransposed()) {
    CHECK(out.getHeight() == a.getHeight() &&
          out.getWidth() == b.getWidth() &&
          a.getWidth() == b.getHeight());
  } else if (a.isTransposed() && !b.isTransposed()) {
    CHECK(out.getHeight() == a.getWidth() &&
          out.getWidth() == b.getWidth() &&
          a.getHeight() == b.getHeight());
  } else if (!a.isTransposed() && b.isTransposed()) {
    CHECK(out.getHeight() == a.getHeight() &&
          out.getWidth() == b.getHeight() &&
          a.getWidth() == b.getWidth());
  } else {
    LOG(FATAL) << "Not supported when both a and b are transposed";
  }

  hl_trans_op_t aTrans = a.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
  hl_trans_op_t bTrans = b.isTransposed() ? HPPL_OP_T : HPPL_OP_N;
  int dimK = !b.isTransposed() ? b.getHeight() : b.getWidth();
  real* aData = const_cast<real*>(a.getData());
  real* bData = const_cast<real*>(b.getData());
  hl_sparse_matrix_s outData = out.sMatrix_.get();

  hl_sparse_matrix_mul(aData,
                       aTrans,
                       bData,
                       bTrans,
                       outData,
                       out.getHeight(),
                       out.getWidth(),
                       dimK,
                       scaleAB,
                       scaleT);
}
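
/*
 * A minimal usage sketch for the overload above, which accumulates into a
 * sparse output (illustrative only). It assumes out's non-zero structure has
 * already been allocated and that only those positions are computed; the
 * GpuSparseMatrix constructor arguments are an assumption as before:
 *
 *   GpuMatrix a(32, 64);                                        // a : M * K
 *   GpuMatrix b(64, 16);                                        // b : K * N
 *   GpuSparseMatrix out(32, 16, nnz, FLOAT_VALUE, SPARSE_CSR);  // out : M * N
 *   MulOp<DEVICE_TYPE_GPU>(out, a, b, 1.0, 1.0);  // out += a * b, evaluated
 *                                                 // at out's non-zeros only
 */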

}  // namespace paddle