/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "MulOp.h"
#include "hl_base.h"
#include "paddle/legacy/math/Matrix.h"
#include "paddle/legacy/math/SparseMatrix.h"

namespace paddle {

/// dense matrix (+)= dense matrix * dense matrix
///
/// Computes out = scaleT * out + scaleAB * op(a) * op(b) on the GPU, where
/// op(.) applies an optional transpose selected by aTrans / bTrans.
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
  // Translate the boolean transpose flags into the hl BLAS-style op codes.
  auto aOp = aTrans ? HPPL_OP_T : HPPL_OP_N;
  auto bOp = bTrans ? HPPL_OP_T : HPPL_OP_N;
  // The shared inner dimension K comes from a: its width normally,
  // its height when a is used transposed.
  size_t dimK = aTrans ? a.getHeight() : a.getWidth();
  hl_matrix_mul(const_cast<real*>(a.getData()),
                aOp,
                const_cast<real*>(b.getData()),
                bOp,
                const_cast<real*>(out.getData()),
                out.getHeight(),
                out.getWidth(),
                dimK,
                scaleAB,
                scaleT,
                a.getStride(),
                b.getStride(),
                out.getStride());
}
/// dense matrix (+)= sparse matrix * dense matrix
///
/// Computes out = scaleT * out + scaleAB * op(a) * b on the GPU, where a is
/// a CSR sparse matrix and op(.) optionally transposes a per aTrans.
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuSparseMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  CHECK(out.isContiguous());
  CHECK(b.isContiguous());
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
  // NOTE(review): bTrans is accepted but never forwarded — the call below
  // hard-codes HPPL_OP_N for b. Presumably callers always pass
  // bTrans == false here; verify against call sites.
  auto aOp = aTrans ? HPPL_OP_T : HPPL_OP_N;
  hl_matrix_csr_mul_dense(a.sMatrix_.get(),
                          aOp,
                          const_cast<real*>(b.getData()),
                          HPPL_OP_N,
                          const_cast<real*>(out.getData()),
                          out.getHeight(),
                          out.getWidth(),
                          b.getHeight(),
                          scaleAB,
                          scaleT);
}
/// dense matrix (+)= dense matrix * sparse matrix
///
/// Computes out = scaleT * out + scaleAB * a * op(b) on the GPU, dispatching
/// on b's storage format (CSC vs. CSR); op(.) optionally transposes b.
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuMatrix& out,
                            const GpuMatrix& a,
                            const GpuSparseMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  CHECK(out.isContiguous());
  CHECK(a.isContiguous());
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";

  // NOTE(review): aTrans is accepted but never forwarded — both calls below
  // hard-code HPPL_OP_N for a. Presumably callers always pass
  // aTrans == false here; verify against call sites.
  auto bOp = bTrans ? HPPL_OP_T : HPPL_OP_N;
  real* aData = const_cast<real*>(a.getData());
  real* outData = const_cast<real*>(out.getData());
  if (b.format_ != SPARSE_CSC) {
    hl_matrix_dense_mul_csr(aData,
                            HPPL_OP_N,
                            b.sMatrix_.get(),
                            bOp,
                            outData,
                            out.getHeight(),
                            out.getWidth(),
                            a.getWidth(),
                            scaleAB,
                            scaleT);
  } else {
    hl_matrix_dense_mul_csc(aData,
                            HPPL_OP_N,
                            b.sMatrix_.get(),
                            bOp,
                            outData,
                            out.getHeight(),
                            out.getWidth(),
                            a.getWidth(),
                            scaleAB,
                            scaleT);
  }
}
/// sparse matrix (+)= dense matrix * dense matrix
///
/// Computes out = scaleT * out + scaleAB * op(a) * op(b) on the GPU, writing
/// only the positions present in the sparse output matrix.
template <>
void MulOp<DEVICE_TYPE_GPU>(GpuSparseMatrix& out,
                            const GpuMatrix& a,
                            const GpuMatrix& b,
                            real scaleAB,
                            real scaleT,
                            bool aTrans,
                            bool bTrans) {
  CHECK(a.useGpu_ && b.useGpu_) << "matrix device type not match";
  // Translate the boolean transpose flags into the hl BLAS-style op codes.
  auto aOp = aTrans ? HPPL_OP_T : HPPL_OP_N;
  auto bOp = bTrans ? HPPL_OP_T : HPPL_OP_N;
  // The shared inner dimension K comes from b: its height normally,
  // its width when b is used transposed.
  size_t dimK = bTrans ? b.getWidth() : b.getHeight();
  hl_sparse_matrix_mul(const_cast<real*>(a.getData()),
                       aOp,
                       const_cast<real*>(b.getData()),
                       bOp,
                       out.sMatrix_.get(),
                       out.getHeight(),
                       out.getWidth(),
                       dimK,
                       scaleAB,
                       scaleT);
}

}  // namespace paddle