Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e97b8987
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e97b8987
编写于
11月 15, 2017
作者:
D
dzhwinter
提交者:
GitHub
11月 15, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"fix accuracy kernel bug" (#5673)
* "fix accuracy kernel bug" * "relauch ci"
上级
f95c291b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
21 addition
and
10 deletion
+21
-10
paddle/operators/accuracy_op.cu
paddle/operators/accuracy_op.cu
+13
-10
paddle/platform/gpu_info.cc
paddle/platform/gpu_info.cc
+5
-0
paddle/platform/gpu_info.h
paddle/platform/gpu_info.h
+3
-0
未找到文件。
paddle/operators/accuracy_op.cu
浏览文件 @
e97b8987
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/reduce.h>
#include "paddle/operators/accuracy_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
int
num_samples
=
static_cast
<
int
>
(
inference
->
dims
()[
0
]);
size_t
infer_width
=
inference
->
dims
()[
1
];
PADDLE_ENFORCE
(
cudaMemset
(
accuracy_data
,
0
,
sizeof
(
float
))
);
// cudaMemset((void**)&correct_data, 0, sizeof(float)
);
auto
stream
=
ctx
.
cuda_device_context
().
stream
(
);
platform
::
GpuMemsetAsync
(
accuracy_data
,
0
,
sizeof
(
float
),
stream
);
if
(
num_samples
==
0
)
{
return
;
}
cudaMemcpy
(
total_data
,
&
num_samples
,
sizeof
(
int
),
cudaMemcpyHostToDevice
);
platform
::
GpuMemcpyAsync
(
total_data
,
&
num_samples
,
sizeof
(
int
),
cudaMemcpyHostToDevice
,
stream
);
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
num_samples
,
infer_width
,
indices_data
,
label_data
,
correct_data
,
accuracy_data
);
int
d_num_samples
,
d_num_correct
;
float
d_accuracy
;
cudaMemcpy
(
&
d_num_correct
,
correct_data
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
);
cudaMemcpy
(
&
d_num_samples
,
total_data
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
);
cudaMemcpy
(
&
d_accuracy
,
accuracy_data
,
sizeof
(
float
),
cudaMemcpyDeviceToHost
);
platform
::
GpuMemcpyAsync
(
&
d_num_correct
,
correct_data
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
,
stream
);
platform
::
GpuMemcpyAsync
(
&
d_num_samples
,
total_data
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
,
stream
);
platform
::
GpuMemcpyAsync
(
&
d_accuracy
,
accuracy_data
,
sizeof
(
float
),
cudaMemcpyDeviceToHost
,
stream
);
}
};
...
...
paddle/platform/gpu_info.cc
浏览文件 @
e97b8987
...
...
@@ -109,5 +109,10 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
cudaMemcpyPeerAsync
(
dst
,
dst_device
,
src
,
src_device
,
count
,
stream
),
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"
);
}
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
)
{
PADDLE_ENFORCE
(
cudaMemsetAsync
(
dst
,
value
,
count
,
stream
),
"cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync"
);
}
}
// namespace platform
}
// namespace paddle
paddle/platform/gpu_info.h
浏览文件 @
e97b8987
...
...
@@ -60,6 +60,9 @@ void GpuMemcpySync(void *dst, const void *src, size_t count,
void
GpuMemcpyPeer
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
);
//! Set memory dst with value count size asynchronously
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
);
}
// namespace platform
}
// namespace paddle
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录