Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2eeaaa7d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2eeaaa7d
编写于
2月 27, 2023
作者:
Y
Yiqun Liu
提交者:
GitHub
2月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add PADDLE_THROW in ToCudaDataType and polish codes. (#50922)
上级
3669868d
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
15 deletion
+10
-15
paddle/phi/backends/gpu/cuda/cuda_helper.h
paddle/phi/backends/gpu/cuda/cuda_helper.h
+7
-0
paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
+3
-15
未找到文件。
paddle/phi/backends/gpu/cuda/cuda_helper.h
浏览文件 @
2eeaaa7d
...
...
@@ -18,7 +18,9 @@
#include <cuda_runtime.h> // NOLINT
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
namespace
phi
{
namespace
backends
{
...
...
@@ -87,6 +89,11 @@ cudaDataType_t ToCudaDataType() {
}
else
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
bfloat16
>::
value
)
{
return
CUDA_R_16BF
;
#endif
}
else
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"DataType %s is unsupported for CUDA."
,
paddle
::
experimental
::
DataTypeToString
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
())));
}
}
...
...
paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h
浏览文件 @
2eeaaa7d
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <cuda_runtime_api.h>
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
...
...
@@ -27,19 +28,6 @@ namespace funcs {
enum
MatmulImplType
{
kImplWithCublas
=
1
,
kImplWithCublasLt
=
2
};
template
<
typename
T
>
cudaDataType_t
ConvertToCudaDataType
()
{
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
return
CUDA_R_32F
;
}
else
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
return
CUDA_R_64F
;
}
else
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
return
CUDA_R_16F
;
}
else
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
bfloat16
>::
value
)
{
return
CUDA_R_16BF
;
}
}
template
<
typename
T
>
cublasComputeType_t
GetCudaComputeType
()
{
if
(
std
::
is_same
<
T
,
double
>::
value
)
{
...
...
@@ -68,8 +56,8 @@ struct MatmulDescriptor {
int64_t
stride_out
=
0
)
{
using
MT
=
typename
phi
::
dtype
::
MPTypeTrait
<
T
>::
Type
;
cudaDataType_t
mat_type
=
Convert
ToCudaDataType
<
T
>
();
cudaDataType_t
scale_type
=
Convert
ToCudaDataType
<
MT
>
();
cudaDataType_t
mat_type
=
phi
::
backends
::
gpu
::
ToCudaDataType
<
T
>
();
cudaDataType_t
scale_type
=
phi
::
backends
::
gpu
::
ToCudaDataType
<
MT
>
();
cublasComputeType_t
compute_type
=
GetCudaComputeType
<
T
>
();
// Create operation desciriptor; see cublasLtMatmulDescAttributes_t for
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录