Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
79def5e6
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
79def5e6
编写于
9月 28, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine CrossEntropyFunctor
上级
c634a848
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
19 addition
and
34 deletion
+19
-34
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+1
-13
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+3
-3
paddle/operators/math/cross_entropy.cc
paddle/operators/math/cross_entropy.cc
+3
-3
paddle/operators/math/cross_entropy.cu
paddle/operators/math/cross_entropy.cu
+9
-11
paddle/operators/math/cross_entropy.h
paddle/operators/math/cross_entropy.h
+1
-3
paddle/operators/softmax_with_cross_entropy_op.h
paddle/operators/softmax_with_cross_entropy_op.h
+2
-1
未找到文件。
paddle/operators/cross_entropy_op.cu
浏览文件 @
79def5e6
...
...
@@ -18,14 +18,6 @@ namespace paddle {
namespace
operators
{
namespace
{
// TODO(qingqing): make zero setting a common function.
template
<
typename
T
>
__global__
void
Zero
(
T
*
X
,
const
int
N
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
X
[
i
]
=
0.0
;
}
}
template
<
typename
T
>
__global__
void
CrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
...
...
@@ -99,11 +91,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
else
{
Zero
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
batch_size
*
class_num
);
math
::
SetConstant
<
platform
::
GPUPlace
,
T
>
(
ctx
.
device_context
(),
dx
,
0
);
auto
*
label_data
=
label
->
data
<
int
>
();
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyGradientKernel
<
T
><<<
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
79def5e6
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cross_entropy.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -37,7 +38,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
CrossEntropyFunctor
<
platform
::
CPUPlace
,
T
>
()(
ctx
,
y
,
x
,
labels
,
ctx
.
Attr
<
bool
>
(
"softLabel"
));
ctx
.
device_context
()
,
y
,
x
,
labels
,
ctx
.
Attr
<
bool
>
(
"softLabel"
));
}
};
...
...
@@ -69,8 +70,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int
*
label_data
=
label
->
data
<
int
>
();
// TODO(qingqing): make zero setting a common function.
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
math
::
SetConstant
<
platform
::
CPUPlace
,
T
>
(
ctx
.
device_context
(),
dx
,
0
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
...
...
paddle/operators/math/cross_entropy.cc
浏览文件 @
79def5e6
...
...
@@ -26,8 +26,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template
<
typename
T
>
class
CrossEntropyFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
void
operator
()(
const
platform
::
DeviceContext
&
ctx
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
const
framework
::
Tensor
*
labels
,
const
bool
softLabel
)
{
const
int
batch_size
=
prob
->
dims
()[
0
];
if
(
softLabel
)
{
...
...
@@ -35,7 +35,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
auto
lbl
=
EigenMatrix
<
T
>::
From
(
*
labels
);
auto
loss
=
EigenMatrix
<
T
>::
From
(
*
out
);
loss
.
device
(
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
loss
.
device
(
*
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
-
((
lbl
*
in
.
log
().
unaryExpr
(
math
::
TolerableValue
<
T
>
()))
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
))
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
)));
...
...
paddle/operators/math/cross_entropy.cu
浏览文件 @
79def5e6
...
...
@@ -74,8 +74,8 @@ using Tensor = framework::Tensor;
template
<
typename
T
>
class
CrossEntropyFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
void
operator
()(
const
framework
::
DeviceContext
&
ctx
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
const
framework
::
Tensor
*
labels
,
bool
softLabel
)
{
const
T
*
prob_data
=
prob
->
data
<
T
>
();
T
*
loss_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -87,20 +87,18 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
const
T
*
label_data
=
labels
->
data
<
T
>
();
int
block
=
class_num
>
512
?
512
:
pow
(
2
,
int
(
std
::
log2
(
class_num
)));
SoftCrossEntropyKernel
<
T
><<<
batch_size
,
block
,
block
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
class_num
);
SoftCrossEntropyKernel
<
T
><<<
batch_size
,
block
,
block
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
class_num
);
}
else
{
const
int
*
label_data
=
labels
->
data
<
int
>
();
int
block
=
512
;
int
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
batch_size
,
class_num
);
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
batch_size
,
class_num
);
}
}
};
...
...
paddle/operators/math/cross_entropy.h
浏览文件 @
79def5e6
...
...
@@ -37,9 +37,7 @@ struct TolerableValue {
template
<
typename
Place
,
typename
T
>
class
CrossEntropyFunctor
{
public:
// (TODO caoying) it is much better to use DeviceContext as the first
// parameter.
void
operator
()(
const
framework
::
ExecutionContext
&
context
,
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
const
framework
::
Tensor
*
labels
,
const
bool
softLabel
);
};
...
...
paddle/operators/softmax_with_cross_entropy_op.h
浏览文件 @
79def5e6
...
...
@@ -42,7 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
math
::
SoftmaxFunctor
<
platform
::
CPUPlace
,
T
>
()(
context
,
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CPUPlace
,
T
>
()(
context
,
loss
,
softmax
,
labels
,
context
.
Attr
<
bool
>
(
"softLabel"
));
context
.
device_context
(),
loss
,
softmax
,
labels
,
context
.
Attr
<
bool
>
(
"softLabel"
));
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录