Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
e63a8b64
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
e63a8b64
编写于
12月 03, 2020
作者:
L
liujuncheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add math op support
Former-commit-id: 4df0bbebe027306da6bd14f41dbe01c0f3f3781e
上级
c9a4c9e2
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
41 addition
and
0 deletion
+41
-0
oneflow/core/kernel/util/cuda_blas_interface.cu
oneflow/core/kernel/util/cuda_blas_interface.cu
+41
-0
未找到文件。
oneflow/core/kernel/util/cuda_blas_interface.cu
浏览文件 @
e63a8b64
...
...
@@ -66,6 +66,19 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE tra
ldc
);
}
template
<
>
void
Gemm
(
DeviceCtx
*
ctx
,
const
enum
CBLAS_ORDER
order
,
enum
CBLAS_TRANSPOSE
trans_a
,
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
half
*
alpha
,
const
half
*
a
,
const
half
*
b
,
const
half
*
beta
,
half
*
c
)
{
const
float
alpha_f
=
__half2float
(
*
alpha
);
const
float
beta_f
=
__half2float
(
*
beta
);
OF_CUBLAS_CHECK
(
cublasGemmEx
(
ctx
->
cublas_tensor_op_math_handle
(),
CblasTrans2CublasTrans
(
trans_a
),
CblasTrans2CublasTrans
(
trans_b
),
m
,
n
,
k
,
&
alpha_f
,
a
,
CUDA_R_16F
,
(
trans_a
==
CblasNoTrans
)
?
m
:
k
,
b
,
CUDA_R_16F
,
(
trans_b
==
CblasNoTrans
)
?
k
:
n
,
&
beta_f
,
c
,
CUDA_R_16F
,
m
,
CUDA_R_32F
,
CUBLAS_GEMM_DFALT_TENSOR_OP
));
}
void
HGemmWithFloat
(
DeviceCtx
*
ctx
,
const
enum
CBLAS_ORDER
order
,
enum
CBLAS_TRANSPOSE
trans_a
,
enum
CBLAS_TRANSPOSE
trans_b
,
const
int
m
,
const
int
n
,
const
int
k
,
const
float
*
alpha
,
const
half
*
a
,
const
half
*
b
,
const
float
*
beta
,
half
*
c
)
{
...
...
@@ -176,6 +189,34 @@ void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
#endif
}
#if CUDA_VERSION >= 9010
template
<
>
void
BatchedGemmImpl
(
DeviceCtx
*
ctx
,
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
int
batch_size
,
int
m
,
int
n
,
int
k
,
const
half
*
alpha
,
const
half
*
a
,
const
half
*
b
,
const
half
*
beta
,
half
*
c
,
half
**
buf
)
{
float
alpha_f
=
__half2float
(
*
alpha
);
float
beta_f
=
__half2float
(
*
beta
);
int
a_stride
,
b_stride
,
c_stride
;
int
lda
,
ldb
,
ldc
;
cublasOperation_t
cublas_trans_a
,
cublas_trans_b
;
half
**
dev_a_ptrs
;
half
**
dev_b_ptrs
;
half
**
dev_c_ptrs
;
std
::
tie
(
a_stride
,
b_stride
,
c_stride
,
lda
,
ldb
,
ldc
,
cublas_trans_a
,
cublas_trans_b
,
dev_a_ptrs
,
dev_b_ptrs
,
dev_c_ptrs
)
=
PrepareToCallBatchedGemm
<
half
>
(
ctx
,
trans_a
,
trans_b
,
batch_size
,
m
,
n
,
k
,
a
,
b
,
c
,
buf
);
OF_CUBLAS_CHECK
(
cublasGemmBatchedEx
(
ctx
->
cublas_tensor_op_math_handle
(),
CblasTrans2CublasTrans
(
trans_a
),
CblasTrans2CublasTrans
(
trans_b
),
m
,
n
,
k
,
&
alpha_f
,
reinterpret_cast
<
const
void
**>
(
const_cast
<
const
half
**>
(
dev_a_ptrs
)),
CUDA_R_16F
,
(
trans_a
==
CblasNoTrans
)
?
m
:
k
,
reinterpret_cast
<
const
void
**>
(
const_cast
<
const
half
**>
(
dev_b_ptrs
)),
CUDA_R_16F
,
(
trans_b
==
CblasNoTrans
)
?
k
:
n
,
&
beta_f
,
reinterpret_cast
<
void
**>
(
dev_c_ptrs
),
CUDA_R_16F
,
m
,
batch_size
,
CUDA_R_32F
,
CUBLAS_GEMM_DFALT_TENSOR_OP
));
}
#endif
void
BatchedHGemmWithFloatImpl
(
DeviceCtx
*
ctx
,
const
enum
CBLAS_ORDER
order
,
const
enum
CBLAS_TRANSPOSE
trans_a
,
const
enum
CBLAS_TRANSPOSE
trans_b
,
int
batch_size
,
int
m
,
int
n
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录