/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "MulOp.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"

namespace paddle {
/**
 * GPU dense * dense matrix multiplication:
 *   out = scaleT * out + scaleAB * (op(a) * op(b))
 *
 * \param out     output matrix, M * N
 * \param a       left dense operand
 * \param b       right dense operand
 * \param scaleAB scale applied to the product a * b
 * \param scaleT  scale applied to the existing contents of out
 * \param aTrans  whether a is used transposed
 * \param bTrans  whether b is used transposed
 * \param cTrans  part of the common MulOp interface; not consumed on
 *                this dense-dense path.
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";

  // hl_matrix_mul takes non-const buffers; the input matrices are only read.
  real* aBuf = const_cast<real*>(a.getData());
  real* bBuf = const_cast<real*>(b.getData());
  real* outBuf = const_cast<real*>(out.getData());

  // Shared dimension K: columns of op(a).
  size_t dimK = aTrans ? a.getHeight() : a.getWidth();

  hl_matrix_mul(aBuf,
                aTrans ? HPPL_OP_T : HPPL_OP_N,
                bBuf,
                bTrans ? HPPL_OP_T : HPPL_OP_N,
                outBuf,
                out.getHeight(),
                out.getWidth(),
                dimK,
                scaleAB,
                scaleT,
                a.getStride(),
                b.getStride(),
                out.getStride());
}

/**
 * GPU sparse (CSR) * dense matrix multiplication:
 *   out = scaleT * out + scaleAB * (op(a) * b)
 *
 * \param out     output dense matrix, M * N (must be contiguous)
 * \param a       left sparse (CSR) operand
 * \param b       right dense operand (must be contiguous)
 * \param scaleAB scale applied to the product
 * \param scaleT  scale applied to the existing contents of out
 * \param aTrans  whether the sparse operand is used transposed
 * \param bTrans  must be false: the underlying kernel hard-codes
 *                HPPL_OP_N for b, so a transposed b cannot be honored
 * \param cTrans  must be false: transposed output is not supported here
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuSparseMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK(out.isContiguous());
  CHECK(b.isContiguous());
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
  // hl_matrix_csr_mul_dense always consumes b non-transposed and writes out
  // untransposed; fail loudly instead of silently computing a wrong result.
  CHECK(!bTrans) << "sparse(a) * transposed dense(b) is not supported";
  CHECK(!cTrans) << "transposed output is not supported";

  hl_sparse_matrix_s aData = a.sMatrix_.get();
  real* bData = const_cast<real*>(b.getData());
  real* outData = const_cast<real*>(out.getData());
  hl_matrix_csr_mul_dense(aData,
                          aTrans ? HPPL_OP_T : HPPL_OP_N,
                          bData,
                          HPPL_OP_N,
                          outData,
                          out.getHeight(),
                          out.getWidth(),
                          b.getHeight(),
                          scaleAB,
                          scaleT);
}

/**
 * GPU dense * sparse matrix multiplication (sparse operand in CSC or CSR):
 *   out = scaleT * out + scaleAB * (a * op(b))
 *
 * \param out     output dense matrix, M * N (must be contiguous)
 * \param a       left dense operand (must be contiguous)
 * \param b       right sparse operand; b.format_ selects the CSC or CSR kernel
 * \param scaleAB scale applied to the product
 * \param scaleT  scale applied to the existing contents of out
 * \param aTrans  must be false: both kernels hard-code HPPL_OP_N for a,
 *                so a transposed a cannot be honored
 * \param bTrans  whether the sparse operand is used transposed
 * \param cTrans  must be false: transposed output is not supported here
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuMatrix& a,
                            const GpuSparseMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK(out.isContiguous());
  CHECK(a.isContiguous());
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
  // Both kernels below always consume a non-transposed and write out
  // untransposed; fail loudly instead of silently computing a wrong result.
  CHECK(!aTrans) << "transposed dense(a) * sparse(b) is not supported";
  CHECK(!cTrans) << "transposed output is not supported";

  hl_sparse_matrix_s bData = b.sMatrix_.get();
  real* aData = const_cast<real*>(a.getData());
  real* outData = const_cast<real*>(out.getData());

  // Dispatch on the storage format of the sparse operand.
  if (b.format_ == SPARSE_CSC) {
    hl_matrix_dense_mul_csc(aData,
                            HPPL_OP_N,
                            bData,
                            bTrans ? HPPL_OP_T : HPPL_OP_N,
                            outData,
                            out.getHeight(),
                            out.getWidth(),
                            a.getWidth(),
                            scaleAB,
                            scaleT);
  } else {
    hl_matrix_dense_mul_csr(aData,
                            HPPL_OP_N,
                            bData,
                            bTrans ? HPPL_OP_T : HPPL_OP_N,
                            outData,
                            out.getHeight(),
                            out.getWidth(),
                            a.getWidth(),
                            scaleAB,
                            scaleT);
  }
}

/**
 * GPU dense * dense matrix multiplication with a sparse result:
 *   out = scaleT * out + scaleAB * (op(a) * op(b))
 *
 * \param out     sparse output matrix, M * N
 * \param a       left dense operand
 * \param b       right dense operand
 * \param scaleAB scale applied to the product
 * \param scaleT  scale applied to the existing contents of out
 * \param aTrans  whether a is used transposed
 * \param bTrans  whether b is used transposed
 * \param cTrans  part of the common MulOp interface; not consumed on
 *                this path.
 */
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
                            const GpuMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans,
                            bool cTrans) {
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";

  // Raw device buffers; hl_sparse_matrix_mul takes non-const pointers.
  real* aBuf = const_cast<real*>(a.getData());
  real* bBuf = const_cast<real*>(b.getData());
  hl_sparse_matrix_s outHandle = out.sMatrix_.get();

  // Shared dimension K: rows of op(b).
  size_t dimK = bTrans ? b.getWidth() : b.getHeight();

  hl_sparse_matrix_mul(aBuf,
                       aTrans ? HPPL_OP_T : HPPL_OP_N,
                       bBuf,
                       bTrans ? HPPL_OP_T : HPPL_OP_N,
                       outHandle,
                       out.getHeight(),
                       out.getWidth(),
                       dimK,
                       scaleAB,
                       scaleT);
}

}  // namespace paddle