Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8580dce3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8580dce3
编写于
9月 18, 2017
作者:
武
武毅
提交者:
GitHub
9月 18, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine accuracy_op CUDA kernel (#4097)
* refind accuracy_op * follow comments * follow comments
上级
59c48f98
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
35 addition
and
17 deletion
+35
-17
paddle/operators/accuracy_op.cu
paddle/operators/accuracy_op.cu
+25
-13
paddle/platform/cuda_helper.h
paddle/platform/cuda_helper.h
+5
-0
python/paddle/v2/framework/tests/test_accuracy_op.py
python/paddle/v2/framework/tests/test_accuracy_op.py
+5
-4
未找到文件。
paddle/operators/accuracy_op.cu
浏览文件 @
8580dce3
...
@@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include "paddle/operators/accuracy_op.h"
#include "paddle/operators/accuracy_op.h"
#include "paddle/platform/cuda_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
__global__
void
AccuracySingleKernel
(
const
int
N
,
const
int
D
,
const
int
top_k
,
template
<
int
BlockSize
>
const
int
*
Xdata
,
const
int
*
labelData
,
__global__
void
AccuracyCudaKernel
(
const
int
N
,
const
int
D
,
const
int
*
Xdata
,
float
*
accuracy
)
{
const
int
*
labeldata
,
float
*
accuracy
)
{
int
correct
=
0
;
int
count
=
0
;
for
(
int
row
=
0
;
row
<
N
;
row
++
)
{
__shared__
int
total
[
BlockSize
];
const
int
label
=
labelData
[
row
];
for
(
int
col
=
0
;
col
<
D
;
col
++
)
{
// support only 1 block
const
int
pred
=
Xdata
[
row
*
D
+
col
];
for
(
int
i
=
threadIdx
.
x
;
i
<
(
N
);
i
+=
BlockSize
)
{
if
(
pred
==
label
)
{
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
++
correct
;
if
(
Xdata
[
i
*
D
+
j
]
==
labeldata
[
i
])
{
++
count
;
break
;
break
;
}
}
}
}
}
}
*
accuracy
=
static_cast
<
float
>
(
correct
)
/
static_cast
<
float
>
(
N
);
total
[
threadIdx
.
x
]
=
count
;
__syncthreads
();
// reduce the count with init value 0, and output accuracy.
int
result
=
thrust
::
reduce
(
thrust
::
device
,
total
,
total
+
BlockSize
,
0
);
if
(
threadIdx
.
x
==
0
)
{
*
accuracy
=
static_cast
<
float
>
(
result
)
/
static_cast
<
float
>
(
N
);
}
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
...
@@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return
;
return
;
}
}
Accuracy
SingleKernel
<<<
1
,
1
>>>
(
num_samples
,
infer_width
,
1
,
inference_data
,
Accuracy
CudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
1
,
PADDLE_CUDA_NUM_THREADS
>>>
(
label_data
,
accuracy_data
);
num_samples
,
infer_width
,
inference_data
,
label_data
,
accuracy_data
);
}
}
};
};
...
...
paddle/platform/cuda_helper.h
浏览文件 @
8580dce3
...
@@ -24,6 +24,11 @@ namespace platform {
...
@@ -24,6 +24,11 @@ namespace platform {
#define USE_CUDA_ATOMIC(op, T) \
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
// Default thread count per block(or block size).
// TODO(typhoonzero): need to benchmark against setting this value
// to 1024.
constexpr
int
PADDLE_CUDA_NUM_THREADS
=
512
;
// For atomicAdd.
// For atomicAdd.
USE_CUDA_ATOMIC
(
Add
,
float
);
USE_CUDA_ATOMIC
(
Add
,
float
);
...
...
python/paddle/v2/framework/tests/test_accuracy_op.py
浏览文件 @
8580dce3
...
@@ -6,16 +6,17 @@ from op_test import OpTest
...
@@ -6,16 +6,17 @@ from op_test import OpTest
class
TestAccuracyOp
(
OpTest
):
class
TestAccuracyOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"accuracy"
self
.
op_type
=
"accuracy"
infer
=
np
.
random
.
randint
(
0
,
2
,
(
32
,
1
)).
astype
(
"int"
)
n
=
8192
label
=
np
.
random
.
randint
(
0
,
2
,
(
32
,
)).
astype
(
"int"
)
infer
=
np
.
random
.
randint
(
0
,
2
,
(
n
,
1
)).
astype
(
"int"
)
label
=
np
.
random
.
randint
(
0
,
2
,
(
n
,
)).
astype
(
"int"
)
self
.
inputs
=
{
'Inference'
:
infer
,
"Label"
:
label
}
self
.
inputs
=
{
'Inference'
:
infer
,
"Label"
:
label
}
num_correct
=
0
num_correct
=
0
for
rowid
in
xrange
(
32
):
for
rowid
in
xrange
(
n
):
for
ele
in
infer
[
rowid
]:
for
ele
in
infer
[
rowid
]:
if
ele
==
label
[
rowid
]:
if
ele
==
label
[
rowid
]:
num_correct
+=
1
num_correct
+=
1
break
break
self
.
outputs
=
{
'Accuracy'
:
[
num_correct
/
32.0
]}
self
.
outputs
=
{
'Accuracy'
:
[
num_correct
/
float
(
n
)
]}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录