Matrix.cpp 1.7 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4 5 6 7 8 9 10
#include "PaddleCAPI.h"
#include "PaddleCAPIPrivate.h"
#include "hl_cuda.h"

#define cast(v) paddle::capi::cast<paddle::capi::CMatrix>(v)
extern "C" {
int PDMatCreate(PD_Matrix* mat, uint64_t height, uint64_t width, bool useGpu) {
  auto ptr = new paddle::capi::CMatrix();
  ptr->mat = paddle::Matrix::create(height, width, false, useGpu);
  *mat = ptr;
Y
Yu Yang 已提交
11
  return kPD_NO_ERROR;
Y
Yu Yang 已提交
12 13 14 15 16
}

int PDMatCreateNone(PD_Matrix* mat) {
  auto ptr = new paddle::capi::CMatrix();
  *mat = ptr;
Y
Yu Yang 已提交
17
  return kPD_NO_ERROR;
Y
Yu Yang 已提交
18 19 20
}

int PDMatDestroy(PD_Matrix mat) {
Y
Yu Yang 已提交
21
  if (mat == nullptr) return kPD_NULLPTR;
Y
Yu Yang 已提交
22 23
  auto ptr = cast(mat);
  delete ptr;
Y
Yu Yang 已提交
24
  return kPD_NO_ERROR;
Y
Yu Yang 已提交
25 26 27
}

int PDMatCopyToRow(PD_Matrix mat, uint64_t rowID, pd_real* rowArray) {
Y
Yu Yang 已提交
28
  if (mat == nullptr) return kPD_NULLPTR;
Y
Yu Yang 已提交
29
  auto ptr = cast(mat);
Y
Yu Yang 已提交
30 31
  if (ptr->mat == nullptr) return kPD_NULLPTR;
  if (rowID >= ptr->mat->getHeight()) return kPD_OUT_OF_RANGE;
Y
Yu Yang 已提交
32 33 34 35 36 37 38
  paddle::real* buf = ptr->mat->getRowBuf(rowID);
  size_t width = ptr->mat->getWidth();
#ifndef PADDLE_ONLY_CPU
  hl_memcpy(buf, rowArray, sizeof(paddle::real) * width);
#else
  std::copy(rowArray, rowArray + width, buf);
#endif
Y
Yu Yang 已提交
39
  return kPD_NO_ERROR;
Y
Yu Yang 已提交
40 41 42
}

int PDMatGetRow(PD_Matrix mat, uint64_t rowID, pd_real** rawRowBuffer) {
Y
Yu Yang 已提交
43
  if (mat == nullptr) return kPD_NULLPTR;
Y
Yu Yang 已提交
44
  auto ptr = cast(mat);
Y
Yu Yang 已提交
45 46
  if (ptr->mat == nullptr) return kPD_NULLPTR;
  if (rowID >= ptr->mat->getHeight()) return kPD_OUT_OF_RANGE;
Y
Yu Yang 已提交
47
  *rawRowBuffer = ptr->mat->getRowBuf(rowID);
Y
Yu Yang 已提交
48
  return kPD_NO_ERROR;
Y
Yu Yang 已提交
49 50 51
}

int PDMatGetShape(PD_Matrix mat, uint64_t* height, uint64_t* width) {
Y
Yu Yang 已提交
52
  if (mat == nullptr) return kPD_NULLPTR;
Y
Yu Yang 已提交
53 54 55 56 57 58
  if (height != nullptr) {
    *height = cast(mat)->mat->getHeight();
  }
  if (width != nullptr) {
    *width = cast(mat)->mat->getWidth();
  }
Y
Yu Yang 已提交
59
  return kPD_NO_ERROR;
Y
Yu Yang 已提交
60 61
}
}