/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_cuda.h"
L
liaogang 已提交
16 17 18
#include "hl_cuda_sparse.cuh"
#include "hl_matrix_apply.cuh"
#include "hl_matrix_ops.cuh"
Z
zhangjinchao01 已提交
19 20
#include "hl_sparse.h"
#include "hl_sparse.ph"
X
Xin Pan 已提交
21
#include "paddle/legacy/utils/Logging.h"
Z
zhangjinchao01 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35

// Element-wise op with one parameter p: a = a * p.
// Used below as unary::mul_scalar<real>(beta) to scale a dense matrix.
DEFINE_MATRIX_UNARY_PARAMETER_OP(mul_scalar, ONE_PARAMETER, a = a * p);
// Element-wise op: a = 0.
// Used below as unary::Zero<real>() to clear a dense matrix.
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);

/**
 * Expand a CSR sparse matrix into the dense buffer C_d.
 *
 * @param A_d   sparse matrix descriptor; must be HL_SPARSE_CSR with
 *              rows == dimM and cols == dimN.
 * @param C_d   dense destination buffer (must not be NULL).
 * @param dimM  number of rows, > 0.
 * @param dimN  number of columns, > 0.
 *
 * When A_d has no non-zero entries, C_d is simply zero-filled. Otherwise the
 * KeSMatrixCsr2Dense kernel is launched on STREAM_DEFAULT with square thread
 * blocks of CU_CSR2DENSE_THREAD_X x CU_CSR2DENSE_THREAD_X. Any failed
 * precondition, including an unsupported value type, is fatal.
 */
void hl_matrix_csr2dense(hl_sparse_matrix_s A_d,
                         real *C_d,
                         int dimM,
                         int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);
  CHECK(dimM > 0 && dimN > 0 && A_d->rows == dimM && A_d->cols == dimN);
  CHECK(A_d->format == HL_SPARSE_CSR) << "matrix format error!";

  if (A_d->nnz == 0) {
    hl_gpu_apply_unary_op(unary::Zero<real>(), C_d, dimM, dimN, dimN);
    return;
  }

  /* nnz != 0 */
  hl_csr_matrix A_d2 = (hl_csr_matrix)(A_d->matrix);
  /* the value array may only be missing for HL_NO_VALUE matrices */
  CHECK((A_d2->csr_val || A_d->type == HL_NO_VALUE) && A_d2->csr_row &&
        A_d2->csr_col)
      << "parameter transa error!";

  /* ceil-div so that the grid covers every (row, column) position */
  int blocksX = (dimN + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
  int blocksY = (dimM + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
  dim3 threads(CU_CSR2DENSE_THREAD_X, CU_CSR2DENSE_THREAD_X);
  dim3 grid(blocksX, blocksY);

  if (A_d->type == HL_NO_VALUE) {
    KeSMatrixCsr2Dense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
        A_d2->csr_val, A_d2->csr_row, A_d2->csr_col, C_d, dimM, dimN);
  } else if (A_d->type == HL_FLOAT_VALUE) {
    KeSMatrixCsr2Dense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
        A_d2->csr_val, A_d2->csr_row, A_d2->csr_col, C_d, dimM, dimN);
  } else {
    /* previously a silent no-op that left C_d untouched */
    LOG(FATAL) << "sparse matrix value type error!";
  }
  CHECK_SYNC("hl_matrix_csr2dense failed");
}

/**
 * Expand a CSC sparse matrix into the dense buffer C_d.
 *
 * @param A_d   sparse matrix descriptor; must be HL_SPARSE_CSC with
 *              rows == dimM and cols == dimN.
 * @param C_d   dense destination buffer (must not be NULL).
 * @param dimM  number of rows, > 0.
 * @param dimN  number of columns, > 0.
 *
 * When A_d has no non-zero entries, C_d is simply zero-filled. Otherwise the
 * KeSMatrixCsc2Dense kernel is launched on STREAM_DEFAULT with square thread
 * blocks of CU_CSR2DENSE_THREAD_X x CU_CSR2DENSE_THREAD_X. Any failed
 * precondition, including an unsupported value type, is fatal.
 */
void hl_matrix_csc2dense(hl_sparse_matrix_s A_d,
                         real *C_d,
                         int dimM,
                         int dimN) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);
  CHECK(dimM > 0 && dimN > 0 && A_d->rows == dimM && A_d->cols == dimN);
  CHECK(A_d->format == HL_SPARSE_CSC) << "matrix format error!";

  if (A_d->nnz == 0) {
    hl_gpu_apply_unary_op(unary::Zero<real>(), C_d, dimM, dimN, dimN);
    return;
  }

  /* nnz != 0 */
  hl_csc_matrix A_d2 = (hl_csc_matrix)(A_d->matrix);
  /* the value array may only be missing for HL_NO_VALUE matrices */
  CHECK((A_d2->csc_val || A_d->type == HL_NO_VALUE) && A_d2->csc_row &&
        A_d2->csc_col)
      << "parameter transa error!";

  /* ceil-div so that the grid covers every (row, column) position */
  int blocksX = (dimN + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
  int blocksY = (dimM + CU_CSR2DENSE_THREAD_X - 1) / CU_CSR2DENSE_THREAD_X;
  dim3 threads(CU_CSR2DENSE_THREAD_X, CU_CSR2DENSE_THREAD_X);
  dim3 grid(blocksX, blocksY);

  if (A_d->type == HL_NO_VALUE) {
    KeSMatrixCsc2Dense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
        A_d2->csc_val, A_d2->csc_row, A_d2->csc_col, C_d, dimM, dimN);
  } else if (A_d->type == HL_FLOAT_VALUE) {
    KeSMatrixCsc2Dense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
        A_d2->csc_val, A_d2->csc_row, A_d2->csc_col, C_d, dimM, dimN);
  } else {
    /* previously a silent no-op that left C_d untouched */
    LOG(FATAL) << "sparse matrix value type error!";
  }
  CHECK_SYNC("hl_matrix_csc2dense failed");
}

/**
 * Allocate a sparse matrix: a host descriptor plus its device arrays.
 *
 * The generic descriptor and the format-specific struct share one host
 * malloc() block (the format struct sits right after the generic one). The
 * index arrays, and the value array for HL_FLOAT_VALUE, are obtained from
 * hl_malloc_device(). HL_NO_VALUE matrices get a NULL value array.
 *
 * @param A_d         receives the new descriptor (must not be NULL).
 * @param format      HL_SPARSE_CSR or HL_SPARSE_CSC.
 * @param value_type  HL_FLOAT_VALUE or HL_NO_VALUE.
 * @param dimM        number of rows, > 0.
 * @param dimN        number of columns; must be >= 0 for CSC because the
 *                    column-pointer array holds dimN + 1 entries.
 * @param nnz         number of non-zero entries, >= 0.
 *
 * Any failed precondition is fatal.
 */
void hl_malloc_sparse_matrix(hl_sparse_matrix_s *A_d,
                             hl_matrix_format_t format,
                             hl_matrix_value_t value_type,
                             int dimM,
                             int dimN,
                             int nnz) {
  CHECK_NOTNULL(A_d);
  CHECK(format == HL_SPARSE_CSR || format == HL_SPARSE_CSC)
      << "sparse matrix format error!";
  CHECK(value_type == HL_FLOAT_VALUE || value_type == HL_NO_VALUE)
      << "sparse matrix value type error!";
  /* avoid malloc 0 bytes */
  int nnz_s = (nnz == 0 ? 1 : nnz);
  /* value_type was validated above, so this covers both legal cases */
  const bool has_value = (value_type == HL_FLOAT_VALUE);

  if (format == HL_SPARSE_CSR) {
    CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";

    char *tmp =
        (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csr_matrix));
    CHECK_NOTNULL(tmp);

    hl_csr_matrix csr = (hl_csr_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
    csr->sparsity = -1.0;
    csr->nnz_s = nnz_s;
    csr->row_s = dimM + 1;
    csr->csr_val = NULL;
    if (has_value) {
      csr->csr_val = (real *)hl_malloc_device((nnz_s) * sizeof(real));
    }
    csr->csr_row = (int *)hl_malloc_device((dimM + 1) * sizeof(int));
    csr->csr_col = (int *)hl_malloc_device((nnz_s) * sizeof(int));

    *A_d = (hl_sparse_matrix_s)tmp;
    (*A_d)->matrix = (hl_matrix_s)csr;
  } else if (format == HL_SPARSE_CSC) {
    /* dimN sizes the column-pointer array below, so reject negatives */
    CHECK(dimM > 0 && dimN >= 0 && nnz >= 0) << "sparse matrix size error!";

    char *tmp =
        (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csc_matrix));
    CHECK_NOTNULL(tmp);

    hl_csc_matrix csc = (hl_csc_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
    csc->sparsity = -1.0f;
    csc->nnz_s = nnz_s;
    csc->col_s = dimN + 1;
    csc->csc_val = NULL;
    if (has_value) {
      csc->csc_val = (real *)hl_malloc_device((nnz_s) * sizeof(real));
    }
    csc->csc_row = (int *)hl_malloc_device((nnz_s) * sizeof(int));
    csc->csc_col = (int *)hl_malloc_device((dimN + 1) * sizeof(int));

    *A_d = (hl_sparse_matrix_s)tmp;
    (*A_d)->matrix = (hl_matrix_s)csc;
  }

  (*A_d)->format = format;
  (*A_d)->type = value_type;
  (*A_d)->rows = dimM;
  (*A_d)->cols = dimN;
  (*A_d)->nnz = nnz;
}

/* Free one device array (if present) and clear the pointer that held it. */
template <typename T>
static void release_device_buffer(T *&buffer) {
  if (buffer != NULL) {
    hl_free_mem_device(buffer);
    buffer = NULL;
  }
}

/**
 * Release a sparse matrix: its device arrays (when the descriptor still
 * references a format struct) and then the host descriptor itself.
 *
 * @param A_d  descriptor to free; must not be NULL and must be CSR or CSC.
 */
void hl_free_sparse_matrix(hl_sparse_matrix_s A_d) {
  CHECK_NOTNULL(A_d);
  CHECK(A_d->format == HL_SPARSE_CSR || A_d->format == HL_SPARSE_CSC)
      << "sparse matrix format error!";

  if (A_d->matrix != NULL) {
    if (A_d->format == HL_SPARSE_CSR) {
      hl_csr_matrix body = (hl_csr_matrix)A_d->matrix;
      release_device_buffer(body->csr_val);
      release_device_buffer(body->csr_row);
      release_device_buffer(body->csr_col);
    } else if (A_d->format == HL_SPARSE_CSC) {
      hl_csc_matrix body = (hl_csc_matrix)A_d->matrix;
      release_device_buffer(body->csc_val);
      release_device_buffer(body->csc_row);
      release_device_buffer(body->csc_col);
    }
    A_d->matrix = NULL;
  }

  /* the format check above guarantees one of the two branches ran */
  free(A_d);
}

/**
 * Build a sparse matrix descriptor on top of one caller-owned buffer.
 *
 * No array memory is allocated here: dest_d is partitioned into the value
 * array (only when value_type != HL_NO_VALUE), followed by the pointer array
 * (dimM + 1 ints for CSR, dimN + 1 ints for CSC), followed by the index
 * array (nnz ints). Only the host descriptor is malloc()'ed.
 *
 * @param A_d         receives the new descriptor (must not be NULL).
 * @param dest_d      buffer to partition (must not be NULL).
 * @param size        size of dest_d in bytes; must cover the layout above.
 * @param format      HL_SPARSE_CSR or HL_SPARSE_CSC.
 * @param value_type  HL_NO_VALUE omits the value array.
 * @param dimM        number of rows, > 0.
 * @param dimN        number of columns; must be >= 0 for CSC.
 * @param nnz         number of non-zero entries, >= 0.
 *
 * Any failed precondition is fatal.
 */
void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
                                void *dest_d,
                                size_t size,
                                hl_matrix_format_t format,
                                hl_matrix_value_t value_type,
                                int dimM,
                                int dimN,
                                int nnz) {
  CHECK_NOTNULL(A_d);
  /* every array pointer below is derived from dest_d */
  CHECK_NOTNULL(dest_d);
  CHECK(format == HL_SPARSE_CSR || format == HL_SPARSE_CSC)
      << "sparse matrix format error!";

  const bool has_value = (value_type != HL_NO_VALUE);
  char *base = (char *)dest_d;

  if (format == HL_SPARSE_CSR) {
    CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";

    size_t size_ = (dimM + 1) * sizeof(int) + nnz * sizeof(int);
    if (has_value) {
      size_ += nnz * sizeof(real);
    }
    CHECK_LE(size_, size) << "dest_d size(" << size
                          << ") too small, should bigger than(" << size_
                          << ")!";

    char *tmp =
        (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csr_matrix));
    CHECK_NOTNULL(tmp);

    hl_csr_matrix csr = (hl_csr_matrix)(tmp + sizeof(_hl_sparse_matrix_s));

    if (!has_value) {
      csr->csr_val = NULL;
      csr->csr_row = (int *)dest_d;
      csr->csr_col = (int *)(base + (dimM + 1) * sizeof(int));
    } else {
      csr->csr_val = (real *)dest_d;
      csr->csr_row = (int *)(base + nnz * sizeof(real));
      csr->csr_col =
          (int *)(base + nnz * sizeof(real) + (dimM + 1) * sizeof(int));
    }
    csr->nnz_s = nnz;
    csr->row_s = dimM + 1;
    csr->sparsity = -1.0;
    *A_d = (hl_sparse_matrix_s)tmp;
    (*A_d)->matrix = (hl_matrix_s)csr;
  } else if (format == HL_SPARSE_CSC) {
    /* dimN drives the buffer layout below, so reject negatives */
    CHECK(dimM > 0 && dimN >= 0 && nnz >= 0) << "sparse matrix size error!";

    size_t size_ = (dimN + 1) * sizeof(int) + nnz * sizeof(int);
    if (has_value) {
      size_ += nnz * sizeof(real);
    }
    CHECK_LE(size_, size) << "dest_d size(" << size
                          << ") too small, should bigger than(" << size_
                          << ")!";

    char *tmp =
        (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csc_matrix));
    CHECK_NOTNULL(tmp);

    hl_csc_matrix csc = (hl_csc_matrix)(tmp + sizeof(_hl_sparse_matrix_s));
    if (!has_value) {
      csc->csc_val = NULL;
      csc->csc_col = (int *)dest_d;
      csc->csc_row = (int *)(base + (dimN + 1) * sizeof(int));
    } else {
      csc->csc_val = (real *)dest_d;
      csc->csc_col = (int *)(base + nnz * sizeof(real));
      csc->csc_row =
          (int *)(base + nnz * sizeof(real) + (dimN + 1) * sizeof(int));
    }
    csc->nnz_s = nnz;
    csc->col_s = dimN + 1;
    csc->sparsity = -1.0f;
    *A_d = (hl_sparse_matrix_s)tmp;
    (*A_d)->matrix = (hl_matrix_s)csc;
  }

  (*A_d)->format = format;
  (*A_d)->type = value_type;
  (*A_d)->rows = dimM;
  (*A_d)->cols = dimN;
  (*A_d)->nnz = nnz;
}

/**
 * Build a sparse matrix descriptor around three caller-owned arrays.
 *
 * Only the host descriptor is malloc()'ed; value_d, rows_d and cols_d are
 * stored as given and are not copied.
 *
 * @param A_d         receives the new descriptor (must not be NULL).
 * @param value_d     value array, stored as-is.
 * @param rows_d      row array (row pointers for CSR, row indices for CSC).
 * @param cols_d      column array (column indices for CSR, column pointers
 *                    for CSC).
 * @param format      HL_SPARSE_CSR or HL_SPARSE_CSC.
 * @param value_type  recorded in the descriptor.
 * @param dimM        number of rows, > 0.
 * @param dimN        number of columns.
 * @param nnz         number of non-zero entries, >= 0.
 */
void hl_construct_sparse_matrix(hl_sparse_matrix_s *A_d,
                                real *value_d,
                                int *rows_d,
                                int *cols_d,
                                hl_matrix_format_t format,
                                hl_matrix_value_t value_type,
                                int dimM,
                                int dimN,
                                int nnz) {
  CHECK_NOTNULL(A_d);
  CHECK(dimM > 0 && nnz >= 0) << "sparse matrix size error!";

  CHECK(format == HL_SPARSE_CSR || format == HL_SPARSE_CSC)
      << "sparse matrix format error!";

  if (format == HL_SPARSE_CSR) {
    /* one host block: generic descriptor followed by the CSR struct */
    char *storage =
        (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csr_matrix));
    CHECK_NOTNULL(storage);

    hl_csr_matrix body =
        (hl_csr_matrix)(storage + sizeof(_hl_sparse_matrix_s));
    body->csr_row = rows_d;
    body->csr_col = cols_d;
    body->csr_val = value_d;
    body->nnz_s = nnz;
    body->row_s = dimM + 1;
    body->sparsity = -1.0;

    *A_d = (hl_sparse_matrix_s)storage;
    (*A_d)->matrix = (hl_matrix_s)body;
  } else if (format == HL_SPARSE_CSC) {
    /* one host block: generic descriptor followed by the CSC struct */
    char *storage =
        (char *)malloc(sizeof(_hl_sparse_matrix_s) + sizeof(_hl_csc_matrix));
    CHECK_NOTNULL(storage);

    hl_csc_matrix body =
        (hl_csc_matrix)(storage + sizeof(_hl_sparse_matrix_s));
    body->csc_row = rows_d;
    body->csc_col = cols_d;
    body->csc_val = value_d;
    body->nnz_s = nnz;
    body->col_s = dimN + 1;
    body->sparsity = -1.0f;

    *A_d = (hl_sparse_matrix_s)storage;
    (*A_d)->matrix = (hl_matrix_s)body;
  }

  hl_sparse_matrix_s header = *A_d;
  header->format = format;
  header->type = value_type;
  header->rows = dimM;
  header->cols = dimN;
  header->nnz = nnz;
}

/**
 * Release a sparse matrix descriptor with free().
 *
 * Only the host descriptor block is released; the arrays it points to are
 * not touched by this function.
 */
void hl_destruct_sparse_matrix(hl_sparse_matrix_s A_d) {
  CHECK_NOTNULL(A_d);
  free(A_d);
}

/**
 * Asynchronously copy CSR arrays into an existing CSR matrix.
 *
 * Copy sizes come from the destination descriptor (nnz and rows + 1) and
 * must fit its allocated capacity (nnz_s / row_s).
 *
 * Accepted argument combinations:
 *  - HL_NO_VALUE:    csr_row and csr_col both NULL (nothing to do) or both
 *                    non-NULL.
 *  - HL_FLOAT_VALUE: all NULL (nothing to do), csr_val only, or all three.
 * Anything else is fatal. Returning early on "nothing to do" also leaves
 * the stored sparsity untouched.
 *
 * @param csr_matrix  destination matrix, HL_SPARSE_CSR.
 * @param csr_val     source values, or NULL.
 * @param csr_row     source row pointers, or NULL.
 * @param csr_col     source column indices, or NULL.
 * @param stream      stream handed to hl_memcpy_async.
 */
void hl_memcpy_csr_matrix(hl_sparse_matrix_s csr_matrix,
                          real *csr_val,
                          int *csr_row,
                          int *csr_col,
                          hl_stream_t stream) {
  CHECK_NOTNULL(csr_matrix);
  CHECK_EQ(csr_matrix->format, HL_SPARSE_CSR)
      << "csr_matrix is not csr format!";
  CHECK_NOTNULL(csr_matrix->matrix);

  hl_csr_matrix csr = (hl_csr_matrix)(csr_matrix->matrix);
  CHECK_LE(csr_matrix->nnz, csr->nnz_s) << "copy size " << csr_matrix->nnz
                                        << " is big than alloc size "
                                        << csr->nnz_s;

  CHECK_LE((csr_matrix->rows + 1), csr->row_s)
      << "copy size " << (csr_matrix->rows + 1) << " is big than alloc size "
      << csr->row_s;

  CHECK(csr_matrix->type == HL_FLOAT_VALUE || csr_matrix->type == HL_NO_VALUE)
      << "sparse matrix value type error!";

  const bool have_val = (csr_val != NULL);
  const bool have_row = (csr_row != NULL);
  const bool have_col = (csr_col != NULL);

  if (csr_matrix->type == HL_NO_VALUE) {
    if (have_row && have_col) {
      hl_memcpy_async(
          csr->csr_row, csr_row, (csr_matrix->rows + 1) * sizeof(int), stream);

      hl_memcpy_async(
          csr->csr_col, csr_col, (csr_matrix->nnz) * sizeof(int), stream);
    } else if (have_row || have_col) {
      /* exactly one of the two index arrays was supplied */
      LOG(FATAL) << "parameter csr_row or csr_col is null pointer!";
    } else {
      return;
    }
  } else if (csr_matrix->type == HL_FLOAT_VALUE) {
    if (!have_val && !have_row && !have_col) {
      return;
    }
    /* legal: values alone, or values together with both index arrays */
    if (have_val && have_row == have_col) {
      hl_memcpy_async(
          csr->csr_val, csr_val, (csr_matrix->nnz) * sizeof(real), stream);
      if (have_row) {
        hl_memcpy_async(csr->csr_row,
                        csr_row,
                        (csr_matrix->rows + 1) * sizeof(int),
                        stream);
        hl_memcpy_async(
            csr->csr_col, csr_col, (csr_matrix->nnz) * sizeof(int), stream);
      }
    } else {
      LOG(FATAL) << "parameter csr_row or csr_col is null pointer!";
    }
  }

  csr->sparsity = ((float)csr_matrix->nnz) / ((float)csr_matrix->rows) /
                  ((float)csr_matrix->cols);
}

/**
 * Asynchronously copy CSC arrays into an existing CSC matrix.
 *
 * Copy sizes come from the destination descriptor (nnz and cols + 1) and
 * must fit its allocated capacity (nnz_s / col_s).
 *
 * Accepted argument combinations:
 *  - HL_NO_VALUE:    csc_row and csc_col both NULL (nothing to do) or both
 *                    non-NULL.
 *  - HL_FLOAT_VALUE: all NULL (nothing to do), csc_val only, or all three.
 * Anything else is fatal. Returning early on "nothing to do" also leaves
 * the stored sparsity untouched.
 *
 * @param csc_matrix  destination matrix, HL_SPARSE_CSC, with a non-NULL
 *                    format struct.
 * @param csc_val     source values, or NULL.
 * @param csc_row     source row indices, or NULL.
 * @param csc_col     source column pointers, or NULL.
 * @param stream      stream handed to hl_memcpy_async.
 */
void hl_memcpy_csc_matrix(hl_sparse_matrix_s csc_matrix,
                          real *csc_val,
                          int *csc_row,
                          int *csc_col,
                          hl_stream_t stream) {
  CHECK_NOTNULL(csc_matrix);
  CHECK_EQ(csc_matrix->format, HL_SPARSE_CSC)
      << "csc_matrix is not csc format error!";
  /* same guard as hl_memcpy_csr_matrix: the struct is dereferenced below */
  CHECK_NOTNULL(csc_matrix->matrix);

  hl_csc_matrix csc = (hl_csc_matrix)(csc_matrix->matrix);
  CHECK_LE(csc_matrix->nnz, csc->nnz_s) << "copy size " << csc_matrix->nnz
                                        << " is big than alloc size "
                                        << csc->nnz_s;

  CHECK_LE((csc_matrix->cols + 1), csc->col_s)
      << "copy size " << (csc_matrix->cols + 1) << " is big than alloc size "
      << csc->col_s;

  CHECK(csc_matrix->type == HL_FLOAT_VALUE || csc_matrix->type == HL_NO_VALUE)
      << "sparse matrix value type error!";

  if (csc_matrix->type == HL_NO_VALUE) {
    if (csc_row == NULL && csc_col == NULL) {
      return;
    } else if (csc_row != NULL && csc_col != NULL) {
      hl_memcpy_async(
          csc->csc_row, csc_row, (csc_matrix->nnz) * sizeof(int), stream);
      hl_memcpy_async(
          csc->csc_col, csc_col, (csc_matrix->cols + 1) * sizeof(int), stream);
    } else {
      LOG(FATAL) << "parameter csc_row or csc_col is null pointer!";
    }
  } else if (csc_matrix->type == HL_FLOAT_VALUE) {
    if (csc_val == NULL && csc_row == NULL && csc_col == NULL) {
      return;
    } else if (csc_val != NULL && csc_row == NULL && csc_col == NULL) {
      /* values only: the sparsity pattern is left as it is */
      hl_memcpy_async(
          csc->csc_val, csc_val, (csc_matrix->nnz) * sizeof(real), stream);
    } else if (csc_val != NULL && csc_row != NULL && csc_col != NULL) {
      hl_memcpy_async(
          csc->csc_val, csc_val, (csc_matrix->nnz) * sizeof(real), stream);
      hl_memcpy_async(
          csc->csc_row, csc_row, (csc_matrix->nnz) * sizeof(int), stream);
      hl_memcpy_async(
          csc->csc_col, csc_col, (csc_matrix->cols + 1) * sizeof(int), stream);
    } else {
      LOG(FATAL) << "parameter csc_row or csc_col is null pointer!";
    }
  }

  csc->sparsity = ((float)csc_matrix->nnz) / ((float)csc_matrix->rows) /
                  ((float)csc_matrix->cols);
}

/**
 * Copy one sparse matrix into another of the same format.
 *
 * dst takes over src's rows, cols and nnz, then the arrays are copied by the
 * format-specific routine (which checks that dst has enough capacity).
 * Copying a value-less src into a dst that stores values is rejected.
 *
 * @param dst     destination matrix (descriptor and format struct non-NULL).
 * @param src     source matrix (descriptor and format struct non-NULL).
 * @param stream  stream forwarded to the format-specific copy.
 */
void hl_memcpy_sparse_matrix(hl_sparse_matrix_s dst,
                             hl_sparse_matrix_s src,
                             hl_stream_t stream) {
  CHECK(dst && src && dst->matrix && src->matrix)
      << "parameter dst or src is null pointer!";
  CHECK_EQ(dst->format, src->format) << "sparse matrix format does not match!";
  CHECK(dst->type != HL_FLOAT_VALUE || src->type != HL_NO_VALUE)
      << "src sparse matrix is no value, dst sparse matrix has value!";

  switch (dst->format) {
    case HL_SPARSE_CSR: {
      hl_csr_matrix from = (hl_csr_matrix)src->matrix;
      dst->rows = src->rows;
      dst->cols = src->cols;
      dst->nnz = src->nnz;
      hl_memcpy_csr_matrix(
          dst, from->csr_val, from->csr_row, from->csr_col, stream);
      break;
    }
    case HL_SPARSE_CSC: {
      hl_csc_matrix from = (hl_csc_matrix)src->matrix;
      dst->rows = src->rows;
      dst->cols = src->cols;
      dst->nnz = src->nnz;
      hl_memcpy_csc_matrix(
          dst, from->csc_val, from->csc_row, from->csc_col, stream);
      break;
    }
    default:
      LOG(FATAL) << "sparse matrix format error!";
  }
}

507 508 509 510 511 512 513
/**
 * Scale the dimM x dimN matrix c in place by beta.
 *
 * beta == 0 overwrites c with zeros instead of multiplying, so c does not
 * have to hold valid data in that case. beta == 1 leaves c untouched.
 */
static void _beta_mul_c(real *c, int dimM, int dimN, real beta) {
  if (beta == 0.0) {
    hl_gpu_apply_unary_op(unary::Zero<real>(), c, dimM, dimN, dimN);
    return;
  }

  if (beta == 1.0) {
    /* multiplying by one is a no-op, skip the launch */
    return;
  }

  hl_gpu_apply_unary_op(unary::mul_scalar<real>(beta), c, dimM, dimN, dimN);
}

L
liaogang 已提交
522 523 524 525
/**
 * Multiply a CSR sparse matrix A (optionally transposed) by a dense matrix B
 * and combine the result into the dense dimM x dimN matrix C.
 *
 * NOTE(review): presumably C = alpha * op(A) * B + beta * C; the exact
 * formula lives in the kernels, which are not defined in this file.
 *
 * @param A_d     sparse operand, HL_SPARSE_CSR; dimM x dimK when transa is
 *                HPPL_OP_N, dimK x dimM when transa is HPPL_OP_T.
 * @param transa  HPPL_OP_N or HPPL_OP_T.
 * @param B_d     dense operand.
 * @param transb  must be HPPL_OP_N.
 * @param C_d     dense result.
 * @param dimM, dimN, dimK  problem sizes, all > 0.
 * @param alpha, beta       scalars forwarded to the kernels.
 *
 * If A has no non-zero entries only C is scaled by beta. In the transposed
 * case C is scaled by beta up front and A's CSR arrays are handed to
 * KeSMatrixCscMulDense. Kernels run on STREAM_DEFAULT.
 */
void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d,
                             hl_trans_op_t transa,
                             real *B_d,
                             hl_trans_op_t transb,
                             real *C_d,
                             int dimM,
                             int dimN,
                             int dimK,
                             real alpha,
                             real beta) {
  CHECK_EQ(transb, HPPL_OP_N);
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);
  CHECK(dimM > 0 && dimN > 0 && dimK > 0);
  CHECK_EQ(A_d->format, HL_SPARSE_CSR) << "matrix format error!";

  const bool plain = (HPPL_OP_N == transa);
  const bool transposed = (HPPL_OP_T == transa);
  const bool plain_shape_ok = (A_d->rows == dimM && A_d->cols == dimK);
  const bool transposed_shape_ok = (A_d->rows == dimK && A_d->cols == dimM);
  if ((plain && !plain_shape_ok) || (transposed && !transposed_shape_ok)) {
    LOG(FATAL) << "parameter error!";
  }

  if (A_d->nnz == 0) {
    _beta_mul_c(C_d, dimM, dimN, beta);
    return;
  }

  /* nnz != 0 */
  hl_csr_matrix sparse = (hl_csr_matrix)(A_d->matrix);
  const bool pattern_only = (A_d->type == HL_NO_VALUE);
  const bool missing_values = (sparse->csr_val == NULL && !pattern_only);
  if (missing_values || sparse->csr_row == NULL || sparse->csr_col == NULL) {
    LOG(FATAL) << "parameter error!";
  }

  if (plain) {
    int grid_x = (dimN + CU_CSRMM_BLOCK_N - 1) / CU_CSRMM_BLOCK_N;
    int grid_y = (dimM + CU_CSRMM_THREAD_Y - 1) / CU_CSRMM_THREAD_Y;
    dim3 threads(CU_CSRMM_THREAD_X, CU_CSRMM_THREAD_Y);
    dim3 grid(grid_x, grid_y);

    if (pattern_only) {
      KeSMatrixCsrMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          sparse->csr_val,
          sparse->csr_col,
          sparse->csr_row,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixCsrMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          sparse->csr_val,
          sparse->csr_col,
          sparse->csr_row,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else if (transposed) {
    /* C is scaled by beta before the kernel launch on this path */
    _beta_mul_c(C_d, dimM, dimN, beta);

    int grid_x =
        (dimN + CU_CSC_MUL_DENSE_BLOCK_N - 1) / CU_CSC_MUL_DENSE_BLOCK_N;
    int grid_y =
        (dimK + CU_CSC_MUL_DENSE_BLOCK_K - 1) / CU_CSC_MUL_DENSE_BLOCK_K;
    dim3 threads(CU_CSC_MUL_DENSE_THREAD_X, CU_CSC_MUL_DENSE_THREAD_Y);
    dim3 grid(grid_x, grid_y);

    if (pattern_only) {
      KeSMatrixCscMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          sparse->csr_val,
          sparse->csr_col,
          sparse->csr_row,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixCscMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          sparse->csr_val,
          sparse->csr_col,
          sparse->csr_row,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else {
    LOG(FATAL) << "parameter transa error!";
  }

  CHECK_SYNC("hl_matrix_csr_mul_dense failed");
}

L
liaogang 已提交
630 631 632 633
/**
 * Multiply a dense matrix A by a CSC sparse matrix B (optionally transposed)
 * and combine the result into the dense dimM x dimN matrix C.
 *
 * NOTE(review): presumably C = alpha * A * op(B) + beta * C; the exact
 * formula lives in the kernels, which are not defined in this file.
 *
 * @param A_d     dense operand.
 * @param transa  must be HPPL_OP_N.
 * @param B_d     sparse operand, HL_SPARSE_CSC; dimK x dimN when transb is
 *                HPPL_OP_N, dimN x dimK when transb is HPPL_OP_T.
 * @param transb  HPPL_OP_N or HPPL_OP_T.
 * @param C_d     dense result.
 * @param dimM, dimN, dimK  problem sizes, all > 0.
 * @param alpha, beta       scalars forwarded to the kernels.
 *
 * If B has no non-zero entries only C is scaled by beta. In the transposed
 * case C is scaled by beta up front and B's CSC arrays are handed to
 * KeSMatrixDenseMulCsr with the row/column arrays swapped. Kernels run on
 * STREAM_DEFAULT.
 */
void hl_matrix_dense_mul_csc(real *A_d,
                             hl_trans_op_t transa,
                             hl_sparse_matrix_s B_d,
                             hl_trans_op_t transb,
                             real *C_d,
                             int dimM,
                             int dimN,
                             int dimK,
                             real alpha,
                             real beta) {
  CHECK_EQ(transa, HPPL_OP_N);
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  const bool plain = (transb == HPPL_OP_N);
  const bool transposed = (transb == HPPL_OP_T);
  const bool bad_dims = (dimM <= 0 || dimN <= 0 || dimK <= 0);
  const bool plain_mismatch =
      plain && (B_d->rows != dimK || B_d->cols != dimN);
  const bool transposed_mismatch =
      transposed && (B_d->rows != dimN || B_d->cols != dimK);
  if (bad_dims || plain_mismatch || transposed_mismatch) {
    LOG(FATAL) << "parameter dims error!";
  }

  CHECK_EQ(B_d->format, HL_SPARSE_CSC) << "matrix format error!";

  if (B_d->nnz == 0) {
    _beta_mul_c(C_d, dimM, dimN, beta);
    return;
  }

  /* nnz != 0 */
  hl_csc_matrix sparse = (hl_csc_matrix)(B_d->matrix);
  const bool pattern_only = (B_d->type == HL_NO_VALUE);
  const bool missing_values = (sparse->csc_val == NULL && !pattern_only);
  if (missing_values || sparse->csc_row == NULL || sparse->csc_col == NULL) {
    LOG(FATAL) << "parameter B is null!";
  }

  if (plain) {
    int grid_x = (dimM + CU_CSCMM_BLOCK_M_BEST - 1) / CU_CSCMM_BLOCK_M_BEST;
    int grid_y = (dimN + CU_CSCMM_BLOCK_N_BEST - 1) / CU_CSCMM_BLOCK_N_BEST;
    dim3 threads(CU_CSCMM_THREAD_X_BEST, CU_CSCMM_THREAD_Y_BEST);
    dim3 grid(grid_x, grid_y);

    if (pattern_only) {
      KeSMatrixDenseMulCsc<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          sparse->csc_val,
          sparse->csc_row,
          sparse->csc_col,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixDenseMulCsc<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          sparse->csc_val,
          sparse->csc_row,
          sparse->csc_col,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else if (transposed) {
    /* C is scaled by beta before the kernel launch on this path */
    _beta_mul_c(C_d, dimM, dimN, beta);

    int grid_x = 1 + (dimK - 1) / CU_DM_CSR_THREAD_X;
    int grid_y = 1 + (dimM - 1) / CU_DM_CSR_BLOCK_M;
    dim3 threads(CU_DM_CSR_THREAD_X, CU_DM_CSR_THREAD_Y);
    dim3 grid(grid_x, grid_y);

    if (pattern_only) {
      KeSMatrixDenseMulCsr<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          sparse->csc_val,
          sparse->csc_col,
          sparse->csc_row,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixDenseMulCsr<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          sparse->csc_val,
          sparse->csc_col,
          sparse->csc_row,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else {
    LOG(FATAL) << "parameter transb error!";
  }

  CHECK_SYNC("hl_matrix_dense_mul_csc failed");
}

L
liaogang 已提交
734 735 736 737
/*
 * Host-side launcher for the "dense x sparse(CSR)" product kernels that
 * update the dense matrix C_d from the dense matrix A_d and the sparse
 * matrix B_d.
 *
 * A_d     dense operand handed to the kernel; must be non-NULL.
 * transa  must be HPPL_OP_N (enforced with CHECK_EQ).
 * B_d     sparse operand; must be non-NULL and in HL_SPARSE_CSR format.
 * transb  HPPL_OP_N requires B_d to be dimK x dimN,
 *         HPPL_OP_T requires B_d to be dimN x dimK;
 *         any other value is fatal (only detected when nnz != 0).
 * C_d     dense output handed to the kernel; must be non-NULL.
 * dimM, dimN, dimK  must all be > 0.
 * alpha, beta       forwarded unchanged to the kernels.
 *                   NOTE(review): presumably C = alpha * A * op(B) + beta * C;
 *                   the kernels are defined outside this function - confirm.
 *
 * Every violated precondition aborts through CHECK / LOG(FATAL).
 * When B_d holds no non-zeros only _beta_mul_c(C_d, dimM, dimN, beta) runs.
 */
void hl_matrix_dense_mul_csr(real *A_d,
                             hl_trans_op_t transa,
                             hl_sparse_matrix_s B_d,
                             hl_trans_op_t transb,
                             real *C_d,
                             int dimM,
                             int dimN,
                             int dimK,
                             real alpha,
                             real beta) {
  CHECK_EQ(transa, HPPL_OP_N);
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  /* The stored shape of B_d has to agree with (dimK, dimN) under transb. */
  if (dimM <= 0 || dimN <= 0 || dimK <= 0 ||
      (transb == HPPL_OP_N && (B_d->rows != dimK || B_d->cols != dimN)) ||
      (transb == HPPL_OP_T && (B_d->rows != dimN || B_d->cols != dimK))) {
    LOG(FATAL) << "parameter dims error!";
  }

  CHECK_EQ(B_d->format, HL_SPARSE_CSR) << "matrix format error!";

  if (B_d->nnz == 0) {
    _beta_mul_c(C_d, dimM, dimN, beta);
    return;
  }

  /* nnz != 0 */
  hl_csr_matrix B_d2 = (hl_csr_matrix)(B_d->matrix);
  /* The value array may legitimately be NULL for HL_NO_VALUE matrices. */
  if ((B_d2->csr_val == NULL && B_d->type != HL_NO_VALUE) ||
      B_d2->csr_row == NULL || B_d2->csr_col == NULL) {
    LOG(FATAL) << "parameter B_d is null!";
  }

  /* Selects the <1> (has values) or <0> (pattern only) kernel instance. */
  const bool hasValue = (B_d->type != HL_NO_VALUE);

  if (transb == HPPL_OP_N) {
    /* This path applies beta on the host side before launching the kernel. */
    _beta_mul_c(C_d, dimM, dimN, beta);

    int blocksX = 1 + (dimK - 1) / CU_DM_CSR_THREAD_X;
    int blocksY = 1 + (dimM - 1) / CU_DM_CSR_BLOCK_M;
    dim3 threads(CU_DM_CSR_THREAD_X, CU_DM_CSR_THREAD_Y);
    dim3 grid(blocksX, blocksY);
    if (hasValue) {
      KeSMatrixDenseMulCsr<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          B_d2->csr_val,
          B_d2->csr_row,
          B_d2->csr_col,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixDenseMulCsr<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          B_d2->csr_val,
          B_d2->csr_row,
          B_d2->csr_col,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else if (transb == HPPL_OP_T) {
    /*
     * The CSR arrays are handed to the CSC kernel with the row and col
     * arrays swapped (csr_col first, then csr_row). No host-side
     * _beta_mul_c is issued on this path.
     */
    int blocksX = (dimM + CU_CSCMM_BLOCK_M_BEST - 1) / CU_CSCMM_BLOCK_M_BEST;
    int blocksY = (dimN + CU_CSCMM_BLOCK_N_BEST - 1) / CU_CSCMM_BLOCK_N_BEST;
    dim3 threads(CU_CSCMM_THREAD_X_BEST, CU_CSCMM_THREAD_Y_BEST);
    dim3 grid(blocksX, blocksY);
    if (hasValue) {
      KeSMatrixDenseMulCsc<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          B_d2->csr_val,
          B_d2->csr_col,
          B_d2->csr_row,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixDenseMulCsc<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          A_d,
          B_d2->csr_val,
          B_d2->csr_col,
          B_d2->csr_row,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else {
    LOG(FATAL) << "parameter transb error!";
  }

  CHECK_SYNC("hl_matrix_dense_mul_csr failed");
}

L
liaogang 已提交
837 838 839 840
/*
 * Host-side launcher for the "sparse(CSC) x dense" product kernels that
 * update the dense matrix C_d from the sparse matrix A_d and the dense
 * matrix B_d.
 *
 * A_d     sparse operand; must be non-NULL and in HL_SPARSE_CSC format.
 * transa  HPPL_OP_N requires A_d to be dimM x dimK,
 *         HPPL_OP_T requires A_d to be dimK x dimM;
 *         any other value is fatal (only detected when nnz != 0).
 * B_d     dense operand handed to the kernel; must be non-NULL.
 * transb  must be HPPL_OP_N (enforced with CHECK_EQ).
 * C_d     dense output handed to the kernel; must be non-NULL.
 * dimM, dimN, dimK  must all be > 0.
 * alpha, beta       forwarded unchanged to the kernels.
 *
 * When A_d holds no non-zeros only _beta_mul_c(C_d, dimM, dimN, beta) runs.
 */
void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d,
                             hl_trans_op_t transa,
                             real *B_d,
                             hl_trans_op_t transb,
                             real *C_d,
                             int dimM,
                             int dimN,
                             int dimK,
                             real alpha,
                             real beta) {
  CHECK_EQ(transb, HPPL_OP_N);
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);
  CHECK(dimM > 0 && dimN > 0 && dimK > 0) << "parameter error!";
  CHECK_EQ(A_d->format, HL_SPARSE_CSC) << "matrix format error!";

  /* The stored shape of A_d has to agree with (dimM, dimK) under transa. */
  const bool badShapeN =
      (HPPL_OP_N == transa) && (A_d->rows != dimM || A_d->cols != dimK);
  const bool badShapeT =
      (HPPL_OP_T == transa) && (A_d->rows != dimK || A_d->cols != dimM);
  if (badShapeN || badShapeT) {
    LOG(FATAL) << "parameter error!";
  }

  /* Nothing stored in A: only the beta scaling of C is issued. */
  if (A_d->nnz == 0) {
    _beta_mul_c(C_d, dimM, dimN, beta);
    return;
  }

  /* nnz != 0 */
  hl_csc_matrix cscA = (hl_csc_matrix)(A_d->matrix);
  /* The value array may legitimately be NULL for HL_NO_VALUE matrices. */
  const bool valueMissing =
      (cscA->csc_val == NULL) && (A_d->type != HL_NO_VALUE);
  if (valueMissing || cscA->csc_row == NULL || cscA->csc_col == NULL) {
    LOG(FATAL) << "parameter error!";
  }

  /* Selects the <1> (has values) or <0> (pattern only) kernel instance. */
  const bool hasValue = (A_d->type != HL_NO_VALUE);

  if (HPPL_OP_N == transa) {
    /* This path applies beta on the host side before launching the kernel. */
    _beta_mul_c(C_d, dimM, dimN, beta);

    dim3 threads(CU_CSC_MUL_DENSE_THREAD_X, CU_CSC_MUL_DENSE_THREAD_Y);
    int gridX =
        (dimN + CU_CSC_MUL_DENSE_BLOCK_N - 1) / CU_CSC_MUL_DENSE_BLOCK_N;
    int gridY =
        (dimK + CU_CSC_MUL_DENSE_BLOCK_K - 1) / CU_CSC_MUL_DENSE_BLOCK_K;
    dim3 grid(gridX, gridY);

    if (hasValue) {
      KeSMatrixCscMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          cscA->csc_val,
          cscA->csc_row,
          cscA->csc_col,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixCscMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          cscA->csc_val,
          cscA->csc_row,
          cscA->csc_col,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else if (HPPL_OP_T == transa) {
    /*
     * The transposed case hands the CSC arrays to the CSR kernel; no
     * host-side _beta_mul_c is issued on this path.
     */
    dim3 threads(CU_CSRMM_THREAD_X, CU_CSRMM_THREAD_Y);
    int gridX = (dimN + CU_CSRMM_BLOCK_N - 1) / CU_CSRMM_BLOCK_N;
    int gridY = (dimM + CU_CSRMM_THREAD_Y - 1) / CU_CSRMM_THREAD_Y;
    dim3 grid(gridX, gridY);

    if (hasValue) {
      KeSMatrixCsrMulDense<1><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          cscA->csc_val,
          cscA->csc_row,
          cscA->csc_col,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    } else {
      KeSMatrixCsrMulDense<0><<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d,
          cscA->csc_val,
          cscA->csc_row,
          cscA->csc_col,
          B_d,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
    }
  } else {
    LOG(FATAL) << "parameter transa error!";
  }

  CHECK_SYNC("hl_matrix_csc_mul_dense failed");
}

L
liaogang 已提交
945 946 947 948 949 950 951 952 953 954
/*
 * Host-side launcher for the "dense x dense -> sparse" kernels: the result
 * is written into the value array of the sparse matrix C_d (CSC or CSR).
 *
 * A_d, B_d  dense operands handed to the kernels; must be non-NULL.
 * transa, transb  translated to booleans (true iff == HPPL_OP_T) and
 *                 forwarded to the kernels.
 * C_d       sparse output; must be non-NULL, must not be HL_NO_VALUE, and
 *           must be in HL_SPARSE_CSC or HL_SPARSE_CSR format.
 * dimM, dimN, dimK  must all be > 0.
 * alpha, beta       forwarded to the kernels; additionally, when
 *                   beta != 1.0 the stored values of C_d are multiplied by
 *                   beta before the kernel launch.
 *
 * Returns immediately when C_d has no stored non-zeros.
 * For CSR output with transb == HPPL_OP_T, transa == HPPL_OP_T is not
 * supported and aborts.
 */
void hl_sparse_matrix_mul(real *A_d,
                          hl_trans_op_t transa,
                          real *B_d,
                          hl_trans_op_t transb,
                          hl_sparse_matrix_s C_d,
                          int dimM,
                          int dimN,
                          int dimK,
                          real alpha,
                          real beta) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);
  CHECK(dimM > 0 && dimN > 0 && dimK > 0) << "parameter error!";
  CHECK_NE(C_d->type, HL_NO_VALUE) << "C value type error!";

  if (C_d->nnz == 0) return;

  const bool transA = (transa == HPPL_OP_T);
  const bool transB = (transb == HPPL_OP_T);

  if (C_d->format == HL_SPARSE_CSC) {
    hl_csc_matrix C_d2 = (hl_csc_matrix)(C_d->matrix);
    if (C_d2->csc_val == NULL || C_d2->csc_row == NULL ||
        C_d2->csc_col == NULL) {
      LOG(FATAL) << "parameter error!";
    }

    /* Pre-scale the stored values of C by beta (treated as a 1 x nnz row). */
    if (beta != 1.0) {
      hl_gpu_apply_unary_op(
          unary::mul_scalar<real>(beta), C_d2->csc_val, 1, C_d->nnz, C_d->nnz);
    }

    /* One block per column of C. */
    int blocksX = dimN;
    int blocksY = 1;
    dim3 threads(CU_CSCMM_DMD2CSC_THREAD_X, 1);
    dim3 grid(blocksX, blocksY);
    KeSMatrixDenseMulDense2CSC<<<grid, threads, 0, STREAM_DEFAULT>>>(
        C_d2->csc_val,
        C_d2->csc_row,
        C_d2->csc_col,
        A_d,
        B_d,
        transA,
        transB,
        dimM,
        dimN,
        dimK,
        alpha,
        beta);
    CHECK_SYNC("hl_sparse_matrix_mul failed");
  } else {
    /* Anything that is not CSC must really be CSR before it is cast. */
    CHECK_EQ(C_d->format, HL_SPARSE_CSR) << "matrix format error!";

    hl_csr_matrix C_d2 = (hl_csr_matrix)(C_d->matrix);
    /* C_d->type != HL_NO_VALUE is already guaranteed by the CHECK_NE above. */
    if (C_d2->csr_val == NULL || C_d2->csr_row == NULL ||
        C_d2->csr_col == NULL) {
      LOG(FATAL) << "parameter error!";
    }

    /* Pre-scale the stored values of C by beta (treated as a 1 x nnz row). */
    if (beta != 1.0) {
      hl_gpu_apply_unary_op(
          unary::mul_scalar<real>(beta), C_d2->csr_val, 1, C_d->nnz, C_d->nnz);
    }

    if (!transB) {
      /* One block per row of C. */
      int blocksX = dimM;
      int blocksY = 1;
      dim3 threads(CU_CSCMM_DMD2CSR_THREAD_X, 1);
      dim3 grid(blocksX, blocksY);

      KeSMatrixDenseMulDense2CSR<<<grid, threads, 0, STREAM_DEFAULT>>>(
          C_d2->csr_val,
          C_d2->csr_row,
          C_d2->csr_col,
          A_d,
          B_d,
          transA,
          transB,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
      CHECK_SYNC("hl_sparse_matrix_mul failed");
    } else {
      /* B is transposed on this path, so the unsupported case is A^T * B^T. */
      CHECK(!transA) << "Not supported: A is trans and B is trans!";

      /* grid.x is sized from the average nnz per row (at least 1),
       * grid.y has one entry per row of C. */
      dim3 block(CU_BLOCK_SIZE, 1);
      int avgNnzPerRow = C_d->nnz / dimM;
      avgNnzPerRow = avgNnzPerRow > 0 ? avgNnzPerRow : 1;
      int gridx = DIVUP(avgNnzPerRow, CU_BLOCK_SIZE);
      dim3 grid(gridx, dimM);
      KeSMatrixDenseMulDenseTrans2CSR<<<grid, block, 0, STREAM_DEFAULT>>>(
          C_d2->csr_val,
          C_d2->csr_row,
          C_d2->csr_col,
          A_d,
          B_d,
          transA,
          transB,
          dimM,
          dimN,
          dimK,
          alpha,
          beta);
      CHECK_SYNC("hl_sparse_matrix_mul failed");
    }
  }
}

/*
 * Copies the arrays of a CSC sparse matrix into caller-provided buffers
 * using hl_memcpy_async on the given stream.
 *
 * csc_val   destination for the values; only used (and then required to be
 *           non-NULL) when the matrix type is HL_FLOAT_VALUE.
 * val_size  capacity of csc_val in elements; must be >= nnz when used.
 * csc_row   destination for the row array; must be non-NULL.
 * row_size  capacity of csc_row in elements; must be >= nnz.
 * csc_col   destination for the col array; must be non-NULL.
 * col_size  capacity of csc_col in elements; must be >= cols + 1.
 * csc_matrix  source matrix; must be non-NULL and in HL_SPARSE_CSC format.
 * stream    stream the copies are enqueued on.
 *
 * All validation is performed before the first copy is enqueued; any
 * violation aborts through CHECK / LOG(FATAL).
 */
void hl_memcpy_from_csc_matrix(real *csc_val,
                               size_t val_size,
                               int *csc_row,
                               size_t row_size,
                               int *csc_col,
                               size_t col_size,
                               hl_sparse_matrix_s csc_matrix,
                               hl_stream_t stream) {
  CHECK_NOTNULL(csc_matrix);
  CHECK_NOTNULL(csc_row);
  CHECK_NOTNULL(csc_col);

  CHECK_EQ(csc_matrix->format, HL_SPARSE_CSC)
      << "csc_matrix is not csc format error!";

  if (csc_matrix->nnz > row_size ||
      csc_matrix->cols + 1 > static_cast<int>(col_size)) {
    LOG(FATAL) << "size not match!";
  }

  /* Validate the optional value buffer up front so that nothing is enqueued
   * on the stream when the call is going to abort anyway. */
  const bool copyValues = (csc_matrix->type == HL_FLOAT_VALUE);
  if (copyValues) {
    if (csc_val == NULL) {
      LOG(FATAL) << "parameter csc_val is null pointer!";
    }
    CHECK_LE(csc_matrix->nnz, val_size) << "size not match!";
  }

  hl_csc_matrix csc = (hl_csc_matrix)(csc_matrix->matrix);
  hl_memcpy_async((void *)csc_row,
                  (void *)csc->csc_row,
                  (csc_matrix->nnz) * sizeof(int),
                  stream);
  hl_memcpy_async((void *)csc_col,
                  (void *)csc->csc_col,
                  (csc_matrix->cols + 1) * sizeof(int),
                  stream);
  if (copyValues) {
    hl_memcpy_async((void *)csc_val,
                    (void *)csc->csc_val,
                    (csc_matrix->nnz) * sizeof(real),
                    stream);
  }
}

/*
 * Copies the arrays of a CSR sparse matrix into caller-provided buffers
 * using hl_memcpy_async on the given stream.
 *
 * csr_val   destination for the values; only used (and then required to be
 *           non-NULL) when the matrix type is HL_FLOAT_VALUE.
 * val_size  capacity of csr_val in elements; must be >= nnz when used.
 * csr_row   destination for the row array; must be non-NULL.
 * row_size  capacity of csr_row in elements; must be >= rows + 1.
 * csr_col   destination for the col array; must be non-NULL.
 * col_size  capacity of csr_col in elements; must be >= nnz.
 * csr_matrix  source matrix; must be non-NULL and in HL_SPARSE_CSR format.
 * stream    stream the copies are enqueued on.
 *
 * The row and col copies are enqueued before the value buffer is validated.
 */
void hl_memcpy_from_csr_matrix(real *csr_val,
                               size_t val_size,
                               int *csr_row,
                               size_t row_size,
                               int *csr_col,
                               size_t col_size,
                               hl_sparse_matrix_s csr_matrix,
                               hl_stream_t stream) {
  CHECK_NOTNULL(csr_matrix);
  CHECK_NOTNULL(csr_row);
  CHECK_NOTNULL(csr_col);
  CHECK_EQ(csr_matrix->format, HL_SPARSE_CSR)
      << "csr_matrix is not csr format error!";

  /* Destination buffers must be able to hold nnz cols and rows + 1 rows. */
  if (csr_matrix->nnz > col_size ||
      csr_matrix->rows + 1 > static_cast<int>(row_size)) {
    LOG(FATAL) << "size not match!";
  }

  hl_csr_matrix src = (hl_csr_matrix)(csr_matrix->matrix);

  hl_memcpy_async((void *)csr_row,
                  (void *)src->csr_row,
                  (csr_matrix->rows + 1) * sizeof(int),
                  stream);

  hl_memcpy_async((void *)csr_col,
                  (void *)src->csr_col,
                  (csr_matrix->nnz) * sizeof(int),
                  stream);

  /* Values exist only for HL_FLOAT_VALUE matrices. */
  if (csr_matrix->type != HL_FLOAT_VALUE) {
    return;
  }

  if (csr_val == NULL) {
    LOG(FATAL) << "parameter csr_val is null pointer!";
  } else {
    CHECK_LE(csr_matrix->nnz, val_size) << "size not match!";
    hl_memcpy_async((void *)csr_val,
                    (void *)src->csr_val,
                    (csr_matrix->nnz) * sizeof(real),
                    stream);
  }
}

L
liaogang 已提交
1138 1139
/*
 * Format dispatcher for the sparse column sum: forwards to
 * hl_matrix_csr_column_sum for HL_SPARSE_CSR matrices and aborts for any
 * other format. B_d must be non-NULL (checked before it is dereferenced).
 */
void hl_sparse_matrix_column_sum(
    real *A_d, hl_sparse_matrix_s B_d, int dimM, int dimN, real scale) {
  /* Guard the B_d->format access below against a NULL handle. */
  CHECK_NOTNULL(B_d);

  if (B_d->format == HL_SPARSE_CSR) {
    hl_matrix_csr_column_sum(A_d, B_d, dimM, dimN, scale);
  } else {
    LOG(FATAL) << "Not support CSC format error!";
  }
}

L
liaogang 已提交
1147 1148
/*
 * Launches KeSMatrixCsrColumnSum over the non-zeros of the CSR matrix B_d,
 * passing A_d as the kernel's first argument.
 *
 * A_d   buffer handed to the kernel; must be non-NULL.
 * B_d   sparse input; must be non-NULL and have shape dimM x dimN.
 * dimM, dimN  must both be > 0.
 * scale NOTE(review): not used anywhere in this function and not forwarded
 *       to the kernel - confirm whether that is intended.
 *
 * Returns without launching anything when B_d has no non-zeros (the array
 * pointers are still validated first).
 */
void hl_matrix_csr_column_sum(
    real *A_d, hl_sparse_matrix_s B_d, int dimM, int dimN, real scale) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);

  const bool dimsPositive = (dimM > 0) && (dimN > 0);
  const bool shapeMatches = (B_d->rows == dimM) && (B_d->cols == dimN);
  if (!dimsPositive || !shapeMatches) {
    LOG(FATAL) << "parameter dims error!";
  }

  hl_csr_matrix csrB = (hl_csr_matrix)(B_d->matrix);
  /* The value array may be NULL only for HL_NO_VALUE matrices. */
  const bool valueMissing =
      (csrB->csr_val == NULL) && (B_d->type != HL_NO_VALUE);
  if (valueMissing || csrB->csr_row == NULL || csrB->csr_col == NULL) {
    LOG(FATAL) << "parameter B is null!";
  }

  if (B_d->nnz == 0) {
    return;
  }

  /* 512 threads per block; enough blocks to cover every stored non-zero. */
  int numNonZeros = B_d->nnz;
  int threadsPerBlock = 512;
  int numBlocks = DIVUP(numNonZeros, 512);
  KeSMatrixCsrColumnSum<<<numBlocks, threadsPerBlock, 0, STREAM_DEFAULT>>>(
      A_d, csrB->csr_val, csrB->csr_col, numNonZeros);

  CHECK_SYNC("hl_matrix_csr_column_sum failed");
}

L
liaogang 已提交
1173
/*
 * Format dispatcher for the sparse add-bias operation: forwards to
 * hl_matrix_csr_add_bias for HL_SPARSE_CSR matrices and aborts for any
 * other format. A_d must be non-NULL (checked before it is dereferenced).
 */
void hl_sparse_matrix_add_bias(hl_sparse_matrix_s A_d, real *B_d, real scale) {
  /* Guard the A_d->format access below against a NULL handle. */
  CHECK_NOTNULL(A_d);

  if (A_d->format == HL_SPARSE_CSR) {
    hl_matrix_csr_add_bias(A_d, B_d, scale);
  } else {
    LOG(FATAL) << "Not support CSC format error!";
  }
}

L
liaogang 已提交
1181
/*
 * Launches KeSMatrixCsrAddBias over the non-zeros of the CSR matrix A_d,
 * passing the matrix's value and col arrays together with B_d and scale.
 *
 * A_d    sparse matrix; must be non-NULL.
 * B_d    buffer handed to the kernel; must be non-NULL.
 * scale  forwarded unchanged to the kernel.
 *
 * Returns without launching anything when A_d has no non-zeros (the array
 * pointers are still validated first).
 */
void hl_matrix_csr_add_bias(hl_sparse_matrix_s A_d, real *B_d, real scale) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);

  hl_csr_matrix csrA = (hl_csr_matrix)(A_d->matrix);
  /* The value array may be NULL only for HL_NO_VALUE matrices. */
  const bool valueMissing =
      (csrA->csr_val == NULL) && (A_d->type != HL_NO_VALUE);
  if (valueMissing || csrA->csr_row == NULL || csrA->csr_col == NULL) {
    LOG(FATAL) << "parameter A_d is null!";
  }

  if (A_d->nnz == 0) {
    return;
  }

  /* 512 threads per block; enough blocks to cover every stored non-zero. */
  int numNonZeros = A_d->nnz;
  int threadsPerBlock = 512;
  int numBlocks = DIVUP(numNonZeros, 512);
  KeSMatrixCsrAddBias<<<numBlocks, threadsPerBlock, 0, STREAM_DEFAULT>>>(
      csrA->csr_val, csrA->csr_col, B_d, scale, numNonZeros);

  CHECK_SYNC("hl_sparse_matrix_add_bias failed");
}

L
liaogang 已提交
1202 1203 1204 1205 1206 1207
/*
 * Format dispatcher for the sparse add-dense operation: forwards to
 * hl_matrix_csr_add_dense for HL_SPARSE_CSR matrices and aborts for any
 * other format. A_d must be non-NULL (checked before it is dereferenced).
 */
void hl_sparse_matrix_add_dense(hl_sparse_matrix_s A_d,
                                real *B_d,
                                int dimM,
                                int dimN,
                                real alpha,
                                real beta) {
  /* Guard the A_d->format access below against a NULL handle. */
  CHECK_NOTNULL(A_d);

  if (A_d->format == HL_SPARSE_CSR) {
    hl_matrix_csr_add_dense(A_d, B_d, dimM, dimN, alpha, beta);
  } else {
    LOG(FATAL) << "Not support CSC format error!";
  }
}

L
liaogang 已提交
1215 1216 1217 1218 1219 1220
/*
 * Launches KeSMatrixCsrAddDense for the CSR matrix A_d and the dense
 * buffer B_d.
 *
 * A_d    sparse matrix; must be non-NULL and have shape dimM x dimN.
 * B_d    dense buffer handed to the kernel; must be non-NULL.
 * dimM, dimN  must both be > 0.
 * alpha, beta forwarded unchanged to the kernel.
 *
 * Returns without launching anything when A_d has no non-zeros (the array
 * pointers are still validated first).
 */
void hl_matrix_csr_add_dense(hl_sparse_matrix_s A_d,
                             real *B_d,
                             int dimM,
                             int dimN,
                             real alpha,
                             real beta) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);

  const bool dimsPositive = (dimM > 0) && (dimN > 0);
  const bool shapeMatches = (A_d->rows == dimM) && (A_d->cols == dimN);
  if (!dimsPositive || !shapeMatches) {
    LOG(FATAL) << "parameter dim error!";
  }

  hl_csr_matrix csrA = (hl_csr_matrix)(A_d->matrix);
  /* The value array may be NULL only for HL_NO_VALUE matrices. */
  const bool valueMissing =
      (csrA->csr_val == NULL) && (A_d->type != HL_NO_VALUE);
  if (valueMissing || csrA->csr_row == NULL || csrA->csr_col == NULL) {
    LOG(FATAL) << "parameter A_d is null!";
  }

  if (A_d->nnz == 0) {
    return;
  }

  /* grid.x is sized from the average nnz per row (never below 1),
   * grid.y has one entry per row; 512 threads per block. */
  int blocksPerRow = DIVUP((A_d->nnz / dimM), 512);
  if (!(blocksPerRow > 0)) {
    blocksPerRow = 1;
  }
  dim3 block(512, 1);
  dim3 grid(blocksPerRow, dimM);
  KeSMatrixCsrAddDense<<<grid, block, 0, STREAM_DEFAULT>>>(
      csrA->csr_val, csrA->csr_row, csrA->csr_col, B_d, alpha, beta, dimM, dimN);

  CHECK_SYNC("hl_sparse_matrix_add_dense failed");
}

L
liaogang 已提交
1252
/*
 * Accessor implemented entirely by the __sparse_get_return__ macro with the
 * `row` field selector.
 * NOTE(review): the macro is defined outside this function; presumably it
 * returns the row array of the matrix's underlying storage - confirm.
 */
int *hl_sparse_matrix_get_rows(hl_sparse_matrix_s sMat) {
  __sparse_get_return__(sMat, row);
}

L
liaogang 已提交
1256
/*
 * Accessor implemented entirely by the __sparse_get_return__ macro with the
 * `col` field selector.
 * NOTE(review): the macro is defined outside this function; presumably it
 * returns the col array of the matrix's underlying storage - confirm.
 */
int *hl_sparse_matrix_get_cols(hl_sparse_matrix_s sMat) {
  __sparse_get_return__(sMat, col);
}

L
liaogang 已提交
1260
/*
 * Accessor implemented entirely by the __sparse_get_return__ macro with the
 * `val` field selector.
 * NOTE(review): the macro is defined outside this function; presumably it
 * returns the value array of the matrix's underlying storage - confirm.
 */
real *hl_sparse_matrix_get_value(hl_sparse_matrix_s sMat) {
  __sparse_get_return__(sMat, val);
}