Commit e63a8b64, authored by liujuncheng

Add math op support



Former-commit-id: 4df0bbebe027306da6bd14f41dbe01c0f3f3781e
Parent c9a4c9e2
......@@ -66,6 +66,19 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE tra
ldc);
}
// Specialization of Gemm for half: alpha/beta are converted to float because the
// cublasGemmEx compute type is CUDA_R_32F (FP16 inputs/output, FP32 accumulation, Tensor Core math).
template<>
void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
          enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const half* alpha,
          const half* a, const half* b, const half* beta, half* c) {
  const float alpha_f = __half2float(*alpha);
  const float beta_f = __half2float(*beta);
  OF_CUBLAS_CHECK(cublasGemmEx(ctx->cublas_tensor_op_math_handle(), CblasTrans2CublasTrans(trans_a),
                               CblasTrans2CublasTrans(trans_b), m, n, k, &alpha_f, a, CUDA_R_16F,
                               (trans_a == CblasNoTrans) ? m : k, b, CUDA_R_16F,
                               (trans_b == CblasNoTrans) ? k : n, &beta_f, c, CUDA_R_16F, m,
                               CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
}
void HGemmWithFloat(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
const float* alpha, const half* a, const half* b, const float* beta, half* c) {
......@@ -176,6 +189,34 @@ void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
#endif
}
#if CUDA_VERSION >= 9010
// Specialization of BatchedGemmImpl for half: prepares the per-batch device pointer arrays,
// then calls cublasGemmBatchedEx with FP16 data, FP32 accumulation, and Tensor Core math.
template<>
void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
                     const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
                     int batch_size, int m, int n, int k, const half* alpha, const half* a,
                     const half* b, const half* beta, half* c, half** buf) {
  float alpha_f = __half2float(*alpha);
  float beta_f = __half2float(*beta);
  int a_stride, b_stride, c_stride;
  int lda, ldb, ldc;
  cublasOperation_t cublas_trans_a, cublas_trans_b;
  half** dev_a_ptrs;
  half** dev_b_ptrs;
  half** dev_c_ptrs;
  std::tie(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, cublas_trans_b, dev_a_ptrs,
           dev_b_ptrs, dev_c_ptrs) =
      PrepareToCallBatchedGemm<half>(ctx, trans_a, trans_b, batch_size, m, n, k, a, b, c, buf);
  OF_CUBLAS_CHECK(cublasGemmBatchedEx(
      ctx->cublas_tensor_op_math_handle(), CblasTrans2CublasTrans(trans_a),
      CblasTrans2CublasTrans(trans_b), m, n, k, &alpha_f,
      reinterpret_cast<const void**>(const_cast<const half**>(dev_a_ptrs)), CUDA_R_16F,
      (trans_a == CblasNoTrans) ? m : k,
      reinterpret_cast<const void**>(const_cast<const half**>(dev_b_ptrs)), CUDA_R_16F,
      (trans_b == CblasNoTrans) ? k : n, &beta_f, reinterpret_cast<void**>(dev_c_ptrs), CUDA_R_16F,
      m, batch_size, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
}
#endif
void BatchedHGemmWithFloatImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, int batch_size, int m, int n,
......
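For reference, the following is a minimal standalone sketch (not part of this commit) of the cublasGemmEx call pattern introduced above: FP16 inputs and output, FP32 accumulation, and Tensor Core math enabled on the handle. The matrix sizes, the CHECK_CUBLAS macro, and the direct cublasSetMathMode call are illustrative assumptions; the commit itself obtains an equivalent handle through ctx->cublas_tensor_op_math_handle(). The batched path in the diff follows the same pattern via cublasGemmBatchedEx.

// Build with: nvcc -lcublas gemm_half_sketch.cu
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Illustrative error-check macro (assumption; the commit uses OF_CUBLAS_CHECK).
#define CHECK_CUBLAS(call)                                          \
  do {                                                              \
    cublasStatus_t s = (call);                                      \
    if (s != CUBLAS_STATUS_SUCCESS) {                               \
      std::printf("cuBLAS error %d at line %d\n", s, __LINE__);     \
      std::exit(1);                                                 \
    }                                                               \
  } while (0)

int main() {
  const int m = 64, n = 64, k = 64;  // illustrative sizes; multiples of 8 favor Tensor Cores
  std::vector<half> host_a(m * k), host_b(k * n), host_c(m * n);
  for (auto& v : host_a) v = __float2half(1.0f);
  for (auto& v : host_b) v = __float2half(1.0f);

  half *a, *b, *c;
  cudaMalloc(&a, m * k * sizeof(half));
  cudaMalloc(&b, k * n * sizeof(half));
  cudaMalloc(&c, m * n * sizeof(half));
  cudaMemcpy(a, host_a.data(), m * k * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(b, host_b.data(), k * n * sizeof(half), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  CHECK_CUBLAS(cublasCreate(&handle));
  // Rough equivalent of ctx->cublas_tensor_op_math_handle(): allow Tensor Core kernels.
  CHECK_CUBLAS(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));

  // alpha/beta are passed as float because the compute type is CUDA_R_32F,
  // matching the __half2float conversions in the diff.
  const float alpha_f = 1.0f, beta_f = 0.0f;
  // Column-major, no transpose: lda = m, ldb = k, ldc = m, i.e. the same leading-dimension
  // logic as the (trans == CblasNoTrans) ? m : k expressions above.
  CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha_f,
                            a, CUDA_R_16F, m, b, CUDA_R_16F, k, &beta_f,
                            c, CUDA_R_16F, m, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));

  cudaMemcpy(host_c.data(), c, m * n * sizeof(half), cudaMemcpyDeviceToHost);
  std::printf("c[0] = %f (expected %d)\n", __half2float(host_c[0]), k);

  cublasDestroy(handle);
  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}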