Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8f6c8780
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8f6c8780
编写于
8月 19, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Replace functor by function.
上级
70285cce
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
14 addition
and
15 deletion
+14
-15
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+12
-13
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+1
-1
python/paddle/v2/framework/tests/op_test_util.py
python/paddle/v2/framework/tests/op_test_util.py
+1
-1
未找到文件。
paddle/operators/cross_entropy_op.cu
浏览文件 @
8f6c8780
...
...
@@ -21,19 +21,18 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
struct
clipping_log
{
__host__
__device__
T
operator
()(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
x
;
__host__
__device__
T
clipping_log
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
T
v
=
log
(
x
);
if
(
v
==
INFINITY
)
{
return
kApproInf
;
}
};
if
(
v
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
v
;
}
template
<
typename
T
>
__global__
void
CrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
int
*
label
,
...
...
@@ -43,7 +42,7 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
clipping_log
<
T
>
()
(
X
[
i
*
D
+
label
[
i
]]);
Y
[
i
]
=
-
clipping_log
(
X
[
i
*
D
+
label
[
i
]]);
}
}
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
8f6c8780
...
...
@@ -21,7 +21,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
T
tolerable_value
(
const
T
x
)
{
inline
T
tolerable_value
(
const
T
x
)
{
static_assert
(
std
::
is_floating_point
<
T
>::
value
,
"tolerable_value works only on float, "
"double and double double."
);
...
...
python/paddle/v2/framework/tests/op_test_util.py
浏览文件 @
8f6c8780
...
...
@@ -65,7 +65,7 @@ class OpTestMeta(type):
expect
=
self
.
outputs
[
out_name
]
self
.
assertTrue
(
numpy
.
allclose
(
actual
,
expect
,
atol
=
1e-0
4
),
actual
,
expect
,
atol
=
1e-0
5
),
"output name: "
+
out_name
+
"has diff"
)
obj
.
test_all
=
test_all
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录