Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
26475cd9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
26475cd9
编写于
8月 15, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use clipping log in cuda kernel, making it same with CPU.
上级
6f7a8260
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
23 addition
and
7 deletion
+23
-7
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+17
-2
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+2
-1
python/paddle/v2/framework/tests/op_test_util.py
python/paddle/v2/framework/tests/op_test_util.py
+2
-1
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+2
-3
未找到文件。
paddle/operators/cross_entropy_op.cu
浏览文件 @
26475cd9
...
...
@@ -20,6 +20,21 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
struct
clipping_log
{
__host__
__device__
T
operator
()(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
x
;
}
};
template
<
typename
T
>
__global__
void
CrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
int
*
label
,
const
int
N
,
const
int
D
)
{
...
...
@@ -28,10 +43,11 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
log
(
X
[
i
*
D
+
label
[
i
]]);
Y
[
i
]
=
-
clipping_log
<
T
>
()
(
X
[
i
*
D
+
label
[
i
]]);
}
}
// TODO(qingqing): make zero setting an common function.
template
<
typename
T
>
__global__
void
zero
(
T
*
X
,
const
int
N
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
...
...
@@ -98,7 +114,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
int
D
=
X
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
*
D
+
block
-
1
)
/
block
;
// TODO(qingqing): make zero an common function.
zero
<
T
><<<
grid
,
block
>>>
(
dXdata
,
N
*
D
);
grid
=
(
N
+
block
-
1
)
/
block
;
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
26475cd9
...
...
@@ -21,7 +21,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
T
tolerable_value
(
T
x
)
{
T
tolerable_value
(
const
T
x
)
{
static_assert
(
std
::
is_floating_point
<
T
>::
value
,
"tolerable_value works only on float, "
"double and double double."
);
...
...
@@ -85,6 +85,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
const
int
batch_size
=
X
->
dims
()[
0
];
const
int
class_num
=
X
->
dims
()[
1
];
// TODO(qingqing): make zero setting an common function.
memset
(
dXdata
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
...
...
python/paddle/v2/framework/tests/op_test_util.py
浏览文件 @
26475cd9
...
...
@@ -64,7 +64,8 @@ class OpTestMeta(type):
actual
=
numpy
.
array
(
scope
.
find_var
(
out_name
).
get_tensor
())
expect
=
self
.
outputs
[
out_name
]
self
.
assertTrue
(
numpy
.
allclose
(
actual
,
expect
),
numpy
.
allclose
(
actual
,
expect
,
atol
=
1e-04
),
"output name: "
+
out_name
+
"has diff"
)
obj
.
test_all
=
test_all
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
26475cd9
...
...
@@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
# TODO this unit test is not passed
self
.
type
=
"onehot_cross_entropy"
batch_size
=
10
0
batch_size
=
3
0
class_num
=
10
X
=
numpy
.
random
.
random
((
batch_size
,
class_num
)).
astype
(
"float32"
)
label
=
5
*
numpy
.
ones
(
batch_size
).
astype
(
"int32"
)
...
...
@@ -24,7 +23,7 @@ class TestCrossEntropy(unittest.TestCase):
class
CrossEntropyGradOpTest
(
GradientChecker
):
def
test_check_grad
(
self
):
op
=
create_op
(
"onehot_cross_entropy"
)
batch_size
=
10
0
batch_size
=
3
0
class_num
=
10
inputs
=
{
"X"
:
numpy
.
random
.
uniform
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录