Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a3a8a090
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a3a8a090
编写于
9月 20, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize cross entropy kernel by using reduce.
上级
414a7a1e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
22 deletion
+41
-22
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+27
-9
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+14
-13
未找到文件。
paddle/operators/cross_entropy_op.cu
浏览文件 @
a3a8a090
...
...
@@ -32,16 +32,33 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
}
}
template
<
typename
T
>
template
<
typename
T
,
int
blockSize
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
D
;
j
++
)
{
sum
+=
label
[
i
*
D
+
j
]
*
tolerable_value
(
log
(
X
[
i
*
D
+
j
]));
int
tid
=
threadIdx
.
x
;
__shared__
T
d_sum
[
blockSize
];
int
next_idx
=
blockIdx
.
x
*
D
+
tid
;
d_sum
[
tid
]
=
0
;
int
cur_idx
=
tid
;
while
(
cur_idx
<
D
)
{
d_sum
[
tid
]
+=
tolerable_value
(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
blockSize
;
cur_idx
+=
blockSize
;
}
__syncthreads
();
for
(
int
stride
=
blockSize
>>
1
;
stride
>
0
;
stride
>>=
1
)
{
__syncthreads
();
if
(
tid
<
stride
)
{
next_idx
=
tid
+
stride
;
d_sum
[
tid
]
+=
d_sum
[
next_idx
];
}
Y
[
i
]
=
-
sum
;
}
__syncthreads
();
if
(
tid
==
0
)
{
Y
[
blockIdx
.
x
]
=
-
d_sum
[
0
];
}
}
...
...
@@ -104,8 +121,9 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
// base on ExecutionContext.
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
SoftCrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
grid
=
d
;
SoftCrossEntropyKernel
<
T
,
512
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
a3a8a090
...
...
@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
dtype
=
"float32"
)
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
0
}
self
.
attrs
=
{
"soft_label"
:
0
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -34,8 +34,8 @@ class TestCrossEntropyOp2(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
1
0
class_num
=
5
batch_size
=
1
3
class_num
=
37
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
np
.
random
.
uniform
(
0.1
,
1.0
,
...
...
@@ -43,15 +43,16 @@ class TestCrossEntropyOp2(OpTest):
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
cross_entropy
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
1
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft_label"
:
1
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
self
.
check_grad
([
"X"
],
"Y"
,
max_relative_error
=
0.05
)
class
TestCrossEntropyOp3
(
OpTest
):
...
...
@@ -61,8 +62,8 @@ class TestCrossEntropyOp3(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
30
class_num
=
10
batch_size
=
13
class_num
=
37
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label_index
=
np
.
random
.
randint
(
...
...
@@ -74,15 +75,15 @@ class TestCrossEntropyOp3(OpTest):
dtype
=
"float32"
)
cross_entropy2
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
1
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft_label"
:
1
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
self
.
check_grad
([
"X"
],
"Y"
,
max_relative_error
=
0.05
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录