机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit 201c2bcf
Authored Sep 23, 2017 by caoying03

delete redundant codes.
Parent: 6735585b
Showing 2 changed files with 18 additions and 77 deletions (+18 -77)
paddle/operators/cross_entropy_op.cu (+10 -45)
python/paddle/v2/framework/tests/test_cross_entropy_op.py (+8 -32)

paddle/operators/cross_entropy_op.cu
@@ -42,9 +42,8 @@ __device__ __forceinline__ T sum_single_warp(T val) {
   return val;
 }

-// This kernel is called when the class number is less than or equal to 512.
 template <typename T>
-__global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
-                                        const int class_num) {
+__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
+                                       const int class_num) {
   int tid = threadIdx.x;
   extern __shared__ T d_sum[];
@@ -69,33 +68,6 @@ __global__ void SoftCrossEntropyKernel1(T* Y, const T* X, const T* label,
   if (tid == 0) Y[blockIdx.x] = -val;
 }

-// This kernel is called when the class number is larger than 512.
-template <typename T, int BlockSize>
-__global__ void SoftCrossEntropyKernel2(T* Y, const T* X, const T* label,
-                                        const int class_num) {
-  int tid = threadIdx.x;
-  __shared__ T d_sum[BlockSize];
-  int next_idx = blockIdx.x * class_num + tid;
-
-  d_sum[tid] = 0;
-  int cur_idx = tid;
-  while (cur_idx < class_num) {
-    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
-    next_idx += BlockSize;
-    cur_idx += BlockSize;
-  }
-  __syncthreads();
-
-  for (unsigned int stride = BlockSize >> 1; stride >= 32; stride >>= 1) {
-    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
-    __syncthreads();
-  }
-
-  T val = d_sum[tid];
-  val = sum_single_warp<T>(val);
-  if (tid == 0) Y[blockIdx.x] = -val;
-}
-
 // TODO(qingqing): make zero setting a common function.
 template <typename T>
 __global__ void zero(T* X, const int N) {
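
For reference, the deleted SoftCrossEntropyKernel2 computed one row of the soft-label loss per CUDA block: each thread accumulated a block-strided partial sum of log(X) * label, a shared-memory tree reduction combined the partials down to warp width, and sum_single_warp finished the sum. The unchanged body of the surviving kernel evidently covers the same case, which is what makes this one redundant. Below is a minimal sequential Python sketch of that scheme (illustrative names, not Paddle APIs; it reduces the tree all the way to one element instead of handing the last warp to sum_single_warp, and it omits the TolerableValue clamping of log):

import math

def soft_cross_entropy_row(x_row, label_row, block_size=512):
    # Each "thread" tid accumulates a block-strided partial sum,
    # mirroring the deleted while (cur_idx < class_num) loop.
    partial = [0.0] * block_size
    for tid in range(block_size):
        for idx in range(tid, len(x_row), block_size):
            partial[tid] += math.log(x_row[idx]) * label_row[idx]
    # Tree reduction, mirroring d_sum[tid] += d_sum[tid + stride].
    stride = block_size >> 1
    while stride >= 1:
        for tid in range(stride):
            partial[tid] += partial[tid + stride]
        stride >>= 1
    return -partial[0]  # the kernel stores Y[blockIdx.x] = -val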
@@ -146,26 +118,19 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
     int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
-    int block = 512;

     if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      if (class_num > 512) {
-        SoftCrossEntropyKernel2<T, 512><<<batch_size, block, 0,
-            reinterpret_cast<const platform::CUDADeviceContext&>(
-                ctx.device_context())
-                .stream()>>>(y_data, x_data, label_data, class_num);
-      } else {
-        int block_size = pow(2, int(std::log2(class_num)));
-        SoftCrossEntropyKernel1<
-            T><<<batch_size, block_size, block_size * sizeof(T),
-                 reinterpret_cast<const platform::CUDADeviceContext&>(
-                     ctx.device_context())
-                     .stream()>>>(y_data, x_data, label_data, class_num);
-      }
+      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   ctx.device_context())
+                   .stream()>>>(y_data, x_data, label_data, class_num);
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+      int block = 512;
       int grid = (batch_size + block - 1) / block;
       CrossEntropyKernel<T><<<
           grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
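
With the two kernels merged, the soft-label branch above differs only in launch configuration: batch_size blocks (one sample row per block), block threads, and block * sizeof(T) bytes of dynamic shared memory for the extern __shared__ d_sum buffer. A quick Python stand-in for the block-size expression, checked against the class counts the tests use (the function name is illustrative):

import math

# Stand-in for: block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
# i.e. round class_num down to a power of two, capped at 512 threads.
def soft_label_block_size(class_num):
    if class_num > 512:
        return 512
    return 2 ** int(math.log2(class_num))

assert soft_label_block_size(17) == 16    # TestCrossEntropyOp3
assert soft_label_block_size(37) == 32    # TestCrossEntropyOp2
assert soft_label_block_size(517) == 512  # capped; threads then stride over classes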

python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -4,19 +4,21 @@ from op_test import OpTest

 class TestCrossEntropyOp1(OpTest):
-    """Test standard cross-entropy, with index representation of
-    labels.
+    """Test cross-entropy with discrete one-hot labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
         batch_size = 30
         class_num = 10
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
         cross_entropy = np.asmatrix(
             [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
             dtype="float32")
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
         self.attrs = {"soft_label": False}
@@ -29,14 +31,14 @@ class TestCrossEntropyOp1(OpTest):

 class TestCrossEntropyOp2(OpTest):
-    """Test soft-label cross-entropy, with vecterized soft labels.
+    """Test cross-entropy with vectorized soft labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
         batch_size = 5
+        # this setting tests threads in more than one wrap.
         class_num = 37
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label = np.random.uniform(0.1, 1.0,
@@ -44,6 +46,7 @@ class TestCrossEntropyOp2(OpTest):
         label /= label.sum(axis=1, keepdims=True)
         cross_entropy = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
         self.attrs = {"soft_label": True}
@@ -56,15 +59,14 @@ class TestCrossEntropyOp2(OpTest):

 class TestCrossEntropyOp3(OpTest):
-    """Test one-hot cross-entropy, with vecterized one-hot representation of
-    labels.
+    """Test cross-entropy with vectorized one-hot representation of labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
         batch_size = 5
+        # this setting tests all threads in one wrap.
         class_num = 17
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label_index = np.random.randint(
@@ -76,33 +78,7 @@ class TestCrossEntropyOp3(OpTest):
             dtype="float32")
         cross_entropy2 = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(["X"], "Y", max_relative_error=0.05)
-
-
-class TestCrossEntropyOp4(OpTest):
-    """Test soft-label cross-entropy.
-    This unittest tests the gpu kernel for layer size excesses 512.
-    """
-
-    def setUp(self):
-        self.op_type = "cross_entropy"
-        batch_size = 2
-        class_num = 517
-        X = np.random.uniform(0.1, 1.0,
-                              [batch_size, class_num]).astype("float32")
-        label = np.random.uniform(0.1, 1.0,
-                                  [batch_size, class_num]).astype("float32")
-        label /= label.sum(axis=1, keepdims=True)
-        cross_entropy = (-label * np.log(X)).sum(
-            axis=1, keepdims=True).astype("float32")

         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
         self.attrs = {"soft_label": True}
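
The surviving tests pin down both label encodings, and TestCrossEntropyOp3 ties them together: it builds a one-hot label matrix, computes the soft-label loss cross_entropy2 = (-label * np.log(X)).sum(axis=1, keepdims=True), and expects it to match the index-based output cross_entropy. For one-hot labels the two formulas coincide, since the row sum keeps exactly one term. A small self-contained NumPy check of that identity (variable names are illustrative):

import numpy as np

batch_size, class_num = 5, 17
X = np.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32")
label_index = np.random.randint(0, class_num, (batch_size, ))
label = np.zeros_like(X)
label[np.arange(batch_size), label_index] = 1.0

# soft-label formula vs. index-based formula
dense = (-label * np.log(X)).sum(axis=1)
sparse = -np.log(X[np.arange(batch_size), label_index])
assert np.allclose(dense, sparse)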