Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
693de9f0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
693de9f0
编写于
12月 07, 2022
作者:
W
WangZhen
提交者:
GitHub
12月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix accuracy fp16 kernel return fp32 tensor error (#48803)
上级
93b7ccf5
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
6 addition
and
6 deletion
+6
-6
paddle/phi/kernels/gpu/accuracy_kernel.cu
paddle/phi/kernels/gpu/accuracy_kernel.cu
+6
-6
未找到文件。
paddle/phi/kernels/gpu/accuracy_kernel.cu
浏览文件 @
693de9f0
...
...
@@ -26,13 +26,13 @@
namespace
phi
{
using
phi
::
PADDLE_CUDA_NUM_THREADS
;
template
<
int
BlockSize
>
template
<
int
BlockSize
,
typename
T
>
__global__
void
AccuracyCudaKernel
(
const
int
N
,
const
int
D
,
const
int64_t
*
Xdata
,
const
int64_t
*
labeldata
,
int
*
correct_data
,
float
*
accuracy
,
T
*
accuracy
,
int
*
total_data
)
{
int
count
=
0
;
__shared__
int
total
[
BlockSize
];
...
...
@@ -64,7 +64,7 @@ __global__ void AccuracyCudaKernel(const int N,
#endif
if
(
threadIdx
.
x
==
0
)
{
*
correct_data
=
result
;
*
accuracy
=
static_cast
<
float
>
(
result
)
/
static_cast
<
float
>
(
N
);
*
accuracy
=
static_cast
<
T
>
(
result
)
/
static_cast
<
T
>
(
N
);
*
total_data
=
N
;
}
}
...
...
@@ -84,18 +84,18 @@ void AccuracyRawKernel(const Context& dev_ctx,
int
*
correct_data
=
dev_ctx
.
template
Alloc
<
int
>(
correct
);
int
*
total_data
=
dev_ctx
.
template
Alloc
<
int
>(
total
);
float
*
accuracy_data
=
dev_ctx
.
template
Alloc
<
float
>(
accuracy
);
T
*
accuracy_data
=
dev_ctx
.
template
Alloc
<
T
>(
accuracy
);
int
num_samples
=
static_cast
<
int
>
(
inference
.
dims
()[
0
]);
size_t
infer_width
=
inference
.
dims
()[
1
];
auto
stream
=
dev_ctx
.
stream
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
accuracy_data
,
0
,
sizeof
(
float
),
stream
);
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
accuracy_data
,
0
,
sizeof
(
T
),
stream
);
if
(
num_samples
==
0
)
{
return
;
}
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
>
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
,
T
>
<<<
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num_samples
,
infer_width
,
indices_data
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录