Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
141b8dbc
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
141b8dbc
编写于
9月 21, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update the backward kernel.
上级
a3a8a090
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
59 addition
and
47 deletion
+59
-47
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+20
-16
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+39
-31
未找到文件。
paddle/operators/cross_entropy_op.cu
浏览文件 @
141b8dbc
...
...
@@ -28,27 +28,27 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
tolerable_value
(
log
(
X
[
i
*
D
+
label
[
i
]]));
Y
[
i
]
=
-
TolerableValue
<
T
>
()
(
log
(
X
[
i
*
D
+
label
[
i
]]));
}
}
template
<
typename
T
,
int
b
lockSize
>
template
<
typename
T
,
int
B
lockSize
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
int
tid
=
threadIdx
.
x
;
__shared__
T
d_sum
[
b
lockSize
];
__shared__
T
d_sum
[
B
lockSize
];
int
next_idx
=
blockIdx
.
x
*
D
+
tid
;
d_sum
[
tid
]
=
0
;
int
cur_idx
=
tid
;
while
(
cur_idx
<
D
)
{
d_sum
[
tid
]
+=
tolerable_value
(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
b
lockSize
;
cur_idx
+=
b
lockSize
;
d_sum
[
tid
]
+=
TolerableValue
<
T
>
()
(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
B
lockSize
;
cur_idx
+=
B
lockSize
;
}
__syncthreads
();
for
(
int
stride
=
b
lockSize
>>
1
;
stride
>
0
;
stride
>>=
1
)
{
for
(
int
stride
=
B
lockSize
>>
1
;
stride
>
0
;
stride
>>=
1
)
{
__syncthreads
();
if
(
tid
<
stride
)
{
next_idx
=
tid
+
stride
;
...
...
@@ -88,13 +88,12 @@ template <typename T>
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
// TOOD(qingqing): optimize for this kernel
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
int
idx
=
i
*
D
+
j
;
dX
[
idx
]
=
-
label
[
idx
]
*
dY
[
i
]
/
X
[
idx
];
}
int
row_ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
col_ids
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
ids
=
row_ids
*
D
+
col_ids
;
if
(
ids
<
N
*
D
)
{
dX
[
ids
]
=
-
label
[
ids
]
*
dY
[
row_ids
]
/
X
[
ids
];
}
}
...
...
@@ -103,7 +102,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"
It must use GPUPla
ce."
);
"
This kernel only runs on GPU devi
ce."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
...
...
@@ -136,7 +135,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"
It must use GPUPla
ce."
);
"
This kernel only runs on GPU devi
ce."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
...
@@ -156,6 +155,11 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
int
block_x
=
32
;
int
block_y
=
32
;
dim3
block
(
block_x
,
block_y
);
dim3
grid
((
n
+
block_x
-
1
)
/
block_x
,
(
d
+
block_y
-
1
)
/
block_y
);
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
141b8dbc
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
...
...
@@ -20,19 +21,25 @@ namespace paddle {
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenMatrix
=
framework
::
EigenMatrix
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
HOSTDEVICE
T
tolerable_value
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
struct
TolerableValue
{
HOSTDEVICE
T
operator
()(
const
T
&
x
)
const
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
x
;
}
return
x
;
}
};
template
<
typename
T
>
class
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
...
...
@@ -40,33 +47,34 @@ class CrossEntropyOpKernel : public framework::OpKernel {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
x_data
=
x
->
data
<
T
>
();
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
const
int
batch_size
=
x
->
dims
()[
0
];
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
int
index
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
sum
+=
label_data
[
index
]
*
tolerable_value
(
std
::
log
(
x_data
[
index
]));
y_data
[
i
]
=
-
sum
;
index
++
;
}
}
auto
prob
=
EigenMatrix
<
T
>::
From
(
*
x
);
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
labels
);
auto
loss
=
EigenMatrix
<
T
>::
From
(
*
y
);
// loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
// prob.log().unaryExpr(TolerableValue<T>());
loss
.
device
(
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
-
((
lbl_mat
*
prob
.
log
())
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
))
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
)));
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
const
int
class_num
=
x
->
dims
()[
1
];
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
y_data
=
y
->
data
<
T
>
();
const
int
*
label_data
=
labels
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
y_data
[
i
]
=
-
tolerable_value
(
std
::
log
(
x_data
[
index
]));
y_data
[
i
]
=
-
TolerableValue
<
T
>
()
(
std
::
log
(
x_data
[
index
]));
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录