提交 e427cc49 编写于 作者: S Shiyuan Shang-Guan 提交者: Will Zhang

cublas_template (#152)

* add cblas_template

* update cblas_gemm

* add CBLAS_ORDER

* update blas

* update dot product

* add cublas_template
上级 578540b1
......@@ -2,48 +2,128 @@
namespace oneflow {
// level 1 vector and vector
// dot product
// float specialization of cublas_dot: forwards every argument, unchanged and
// in order, to cublasSdot and CHECK_EQs the returned status against
// CUBLAS_STATUS_SUCCESS. Returns nothing; `result` is handed to cublasSdot as
// its last argument (per the cuBLAS API that is where the product is stored).
template<>
void cublas_dot<float>(
    cublasHandle_t handle, int n, const float* x, int incx, const float* y,
    int incy, float* result) {
  CHECK_EQ(cublasSdot(handle, n, x, incx, y, incy, result),
           CUBLAS_STATUS_SUCCESS);
}
// double specialization of cublas_dot: forwards every argument, unchanged and
// in order, to cublasDdot and CHECK_EQs the returned status against
// CUBLAS_STATUS_SUCCESS. Returns nothing; `result` is handed to cublasDdot as
// its last argument (per the cuBLAS API that is where the product is stored).
template<>
void cublas_dot<double>(
    cublasHandle_t handle, int n, const double* x, int incx, const double* y,
    int incy, double* result) {
  CHECK_EQ(cublasDdot(handle, n, x, incx, y, incy, result),
           CUBLAS_STATUS_SUCCESS);
}
// Swap x and y -- float specialization.
// Passes all six arguments straight through to cublasSswap and CHECK_EQs the
// status it returns against CUBLAS_STATUS_SUCCESS.
template<>
void cublas_swap<float>(cublasHandle_t handle,
                        int n,
                        float* x,
                        int incx,
                        float* y,
                        int incy) {
  CHECK_EQ(cublasSswap(handle, n, x, incx, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// Swap x and y -- double specialization.
// Passes all six arguments straight through to cublasDswap and CHECK_EQs the
// status it returns against CUBLAS_STATUS_SUCCESS.
template<>
void cublas_swap<double>(cublasHandle_t handle,
                         int n,
                         double* x,
                         int incx,
                         double* y,
                         int incy) {
  CHECK_EQ(cublasDswap(handle, n, x, incx, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// Copy x into y -- float specialization.
// Hands (handle, n, x, incx, y, incy) to cublasScopy in that order and
// CHECK_EQs the returned status against CUBLAS_STATUS_SUCCESS.
template<>
void cublas_copy<float>(cublasHandle_t handle,
                        int n,
                        const float* x,
                        int incx,
                        float* y,
                        int incy) {
  CHECK_EQ(cublasScopy(handle, n, x, incx, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// Copy x into y -- double specialization.
// Hands (handle, n, x, incx, y, incy) to cublasDcopy in that order and
// CHECK_EQs the returned status against CUBLAS_STATUS_SUCCESS.
template<>
void cublas_copy<double>(cublasHandle_t handle,
                         int n,
                         const double* x,
                         int incx,
                         double* y,
                         int incy) {
  CHECK_EQ(cublasDcopy(handle, n, x, incx, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// y = a*x + y
// float specialization of cublas_axpy: forwards every argument, unchanged and
// in order, to cublasSaxpy and CHECK_EQs the returned status against
// CUBLAS_STATUS_SUCCESS. The scalar `a` is supplied through the pointer
// `alpha`.
template<>
void cublas_axpy<float>(
    cublasHandle_t handle, int n, const float* alpha, const float* x, int incx,
    float* y, int incy) {
  CHECK_EQ(cublasSaxpy(handle, n, alpha, x, incx, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// double specialization of cublas_axpy (y = a*x + y): forwards every argument,
// unchanged and in order, to cublasDaxpy and CHECK_EQs the returned status
// against CUBLAS_STATUS_SUCCESS. The scalar `a` is supplied through the
// pointer `alpha`.
template<>
void cublas_axpy<double>(
    cublasHandle_t handle, int n, const double* alpha, const double* x, int incx,
    double* y, int incy) {
  CHECK_EQ(cublasDaxpy(handle, n, alpha, x, incx, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// Scale x in place: x = a*x -- float specialization.
// The scalar comes in through `alpha`. All arguments go to cublasSscal as-is,
// and the status it returns is CHECK_EQ'd against CUBLAS_STATUS_SUCCESS.
template<>
void cublas_scal<float>(cublasHandle_t handle,
                        int n,
                        const float* alpha,
                        float* x,
                        int incx) {
  CHECK_EQ(cublasSscal(handle, n, alpha, x, incx),
           CUBLAS_STATUS_SUCCESS);
}
// Scale x in place: x = a*x -- double specialization.
// The scalar comes in through `alpha`. All arguments go to cublasDscal as-is,
// and the status it returns is CHECK_EQ'd against CUBLAS_STATUS_SUCCESS.
template<>
void cublas_scal<double>(cublasHandle_t handle,
                         int n,
                         const double* alpha,
                         double* x,
                         int incx) {
  CHECK_EQ(cublasDscal(handle, n, alpha, x, incx),
           CUBLAS_STATUS_SUCCESS);
}
// level 2 matrix and vector
// Matrix-vector multiply -- float specialization.
// Thin wrapper: the twelve arguments are passed to cublasSgemv in declaration
// order, and the status it returns is CHECK_EQ'd against
// CUBLAS_STATUS_SUCCESS.
template<>
void cublas_gemv<float>(cublasHandle_t handle,
                        cublasOperation_t trans,
                        int m,
                        int n,
                        const float* alpha,
                        const float* A,
                        int lda,
                        const float* x,
                        int incx,
                        const float* beta,
                        float* y,
                        int incy) {
  CHECK_EQ(cublasSgemv(
               handle, trans, m, n,
               alpha, A, lda,
               x, incx,
               beta, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// Matrix-vector multiply -- double specialization.
// Thin wrapper: the twelve arguments are passed to cublasDgemv in declaration
// order, and the status it returns is CHECK_EQ'd against
// CUBLAS_STATUS_SUCCESS.
template<>
void cublas_gemv<double>(cublasHandle_t handle,
                         cublasOperation_t trans,
                         int m,
                         int n,
                         const double* alpha,
                         const double* A,
                         int lda,
                         const double* x,
                         int incx,
                         const double* beta,
                         double* y,
                         int incy) {
  CHECK_EQ(cublasDgemv(
               handle, trans, m, n,
               alpha, A, lda,
               x, incx,
               beta, y, incy),
           CUBLAS_STATUS_SUCCESS);
}
// level 3 matrix and matrix
// Matrix-matrix multiply -- float specialization.
// Thin wrapper: the fourteen arguments are passed to cublasSgemm in
// declaration order, and the status it returns is CHECK_EQ'd against
// CUBLAS_STATUS_SUCCESS.
template<>
void cublas_gemm<float>(cublasHandle_t handle,
                        cublasOperation_t cuTransA,
                        cublasOperation_t cuTransB,
                        int m,
                        int n,
                        int k,
                        const float* alpha,
                        const float* A,
                        int lda,
                        const float* B,
                        int ldb,
                        const float* beta,
                        float* C,
                        int ldc) {
  CHECK_EQ(cublasSgemm(
               handle, cuTransA, cuTransB,
               m, n, k,
               alpha, A, lda,
               B, ldb,
               beta, C, ldc),
           CUBLAS_STATUS_SUCCESS);
}
// Matrix-matrix multiply -- double specialization.
// Thin wrapper: the fourteen arguments are passed to cublasDgemm in
// declaration order, and the status it returns is CHECK_EQ'd against
// CUBLAS_STATUS_SUCCESS.
template<>
void cublas_gemm<double>(cublasHandle_t handle,
                         cublasOperation_t cuTransA,
                         cublasOperation_t cuTransB,
                         int m,
                         int n,
                         int k,
                         const double* alpha,
                         const double* A,
                         int lda,
                         const double* B,
                         int ldb,
                         const double* beta,
                         double* C,
                         int ldc) {
  CHECK_EQ(cublasDgemm(
               handle, cuTransA, cuTransB,
               m, n, k,
               alpha, A, lda,
               B, ldb,
               beta, C, ldc),
           CUBLAS_STATUS_SUCCESS);
}
} // namespace oneflow
......@@ -5,20 +5,59 @@
namespace oneflow {
// level 1 vector and vector
// dot product
template<typename floating_point_type>
void cublas_gemm(
const cublasHandle_t& cublas_handle, const cublasOperation_t cuTransA,
const cublasOperation_t cuTransB, const int M, const int N, const int K,
const floating_point_type* alpha, const floating_point_type* A,
const int lda, const floating_point_type* B, const int ldb,
const floating_point_type* beta, floating_point_type* C, const int ldc);
void cublas_dot(
cublasHandle_t handle, int n,
const floating_point_type* x, int incx,
const floating_point_type* y, int incy, floating_point_type* result);
// swap x and y
// Declaration of the cublas_swap function template, parameterized on the
// element type of x and y. Only the declaration appears here.
// NOTE(review): presumably meant to be used through explicit specializations
// (float / double) defined in the implementation file -- confirm.
template<typename floating_point_type>
void cublas_swap(
cublasHandle_t handle, int n,
floating_point_type* x, int incx, floating_point_type* y, int incy);
// copy x into y
// Declaration of the cublas_copy function template, parameterized on the
// element type. x is the read-only (const) source, y the destination. Only
// the declaration appears here.
// NOTE(review): presumably meant to be used through explicit specializations
// (float / double) defined in the implementation file -- confirm.
template<typename floating_point_type>
void cublas_copy(
cublasHandle_t handle, int n,
const floating_point_type* x, int incx,
floating_point_type* y, int incy);
// y = a*x + y
// Declaration of the cublas_axpy function template, parameterized on the
// element type. The scalar `a` is supplied through the pointer `alpha`; x is
// read-only (const) and y is the non-const operand. Only the declaration
// appears here.
// NOTE(review): presumably meant to be used through explicit specializations
// (float / double) defined in the implementation file -- confirm.
template<typename floating_point_type>
void cublas_axpy(
    cublasHandle_t handle, int n,
    const floating_point_type* alpha,
    const floating_point_type* x, int incx,
    floating_point_type* y, int incy);
// x = a*x
// Declaration of the cublas_scal function template, parameterized on the
// element type. The scalar `a` is supplied through the pointer `alpha`; x is
// the non-const vector operand. Only the declaration appears here.
// NOTE(review): presumably meant to be used through explicit specializations
// (float / double) defined in the implementation file -- confirm.
template<typename floating_point_type>
void cublas_scal(
cublasHandle_t handle, int n,
const floating_point_type* alpha, floating_point_type* x, int incx);
// level 2 matrix and vector
// matrix vector multiply
// Declaration of the cublas_gemv function template, parameterized on the
// element type. alpha, A, x and beta are read-only (const) inputs; y is the
// only non-const pointer. Only the declaration appears here.
// NOTE(review): presumably meant to be used through explicit specializations
// (float / double) defined in the implementation file -- confirm.
template<typename floating_point_type>
void cublas_gemv(
cublasHandle_t handle, cublasOperation_t trans, int m, int n,
const floating_point_type* alpha, const floating_point_type* A, int lda,
const floating_point_type* x, int incx, const floating_point_type* beta,
floating_point_type* y, int incy);
// level 3 matrix and matrix
// matrix matrix multiply
// Declaration of the cublas_gemm function template, parameterized on the
// element type. alpha, A, B and beta are read-only (const) inputs; C is the
// only non-const pointer. Only the declaration appears here.
// NOTE(review): presumably meant to be used through explicit specializations
// (float / double) defined in the implementation file -- confirm.
template<typename floating_point_type>
void cublas_gemm(
cublasHandle_t handle, cublasOperation_t cuTransA,
cublasOperation_t cuTransB, int m, int n, int k,
const floating_point_type* alpha, const floating_point_type* A, int lda,
const floating_point_type* B, int ldb,
const floating_point_type* beta, floating_point_type* C, int ldc);
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册