PaddlePaddle/Paddle, commit 7d653216

Authored Sep 26, 2017 by Cao Ying; committed via GitHub on Sep 26, 2017.

Merge pull request #4237 from lcy-seso/optimize_cross_entropy_kernel

optimize cross entropy kernel.

Parents: 1c0a4c90, 000d7511
Showing 7 changed files with 242 additions and 171 deletions (+242 / -171).
paddle/operators/accuracy_op.cu                              +6   -2
paddle/operators/cross_entropy_op.cc                         +51  -32
paddle/operators/cross_entropy_op.cu                         +92  -55
paddle/operators/cross_entropy_op.h                          +56  -58
paddle/operators/lookup_table_op.cu                          +8   -3
paddle/operators/top_k_op.cu                                 +6   -4
python/paddle/v2/framework/tests/test_cross_entropy_op.py   +23  -17
paddle/operators/accuracy_op.cu

@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
       return;
     }

-    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
-        num_samples, infer_width, inference_data, label_data, accuracy_data);
+    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
+        1, PADDLE_CUDA_NUM_THREADS, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            ctx.device_context())
+            .stream()>>>(num_samples, infer_width, inference_data, label_data,
+                         accuracy_data);
   }
 };
paddle/operators/cross_entropy_op.cc

@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), "Output(Y) must not be null.");
+                            "Input(Label) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
+                            "Output(Y) should be not null.");

     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");

-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(label->dims().size(), 2, "Input(Label)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(label->dims().size(), 2,
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
-    if (ctx.Attr<bool>("soft_label")) {
+    if (ctx.Attr<bool>("softLabel")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "If Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "If Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }

     ctx.Output<Tensor>("Y")->Resize({x->dims()[0], 1});
@@ -57,35 +58,38 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
-                            "Input(Label) must not be null.");
+                            "Input(Label) should be not null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
-                            "Input(Y@GRAD) must not be null.");
+                            "Input(Y@GRAD) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("X")),
+                            "Output(X@GRAD) should be not null.");

     auto x = ctx.Input<Tensor>("X");
     auto label = ctx.Input<Tensor>("Label");
     auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));

-    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
-    PADDLE_ENFORCE_EQ(dy->dims().size(), 2, "Input(Y@Grad)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank should be 2.");
+    PADDLE_ENFORCE_EQ(dy->dims().size(), 2,
+                      "Input(Y@Grad)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(label->dims().size(), 2,
-                      "Input(Label)'s rank must be 2.");
+                      "Input(Label)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(x->dims()[0], label->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Label) must "
+                      "The 1st dimension of Input(X) and Input(Label) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(x->dims()[0], dy->dims()[0],
-                      "The 1st dimension of Input(X) and Input(Y@Grad) must "
+                      "The 1st dimension of Input(X) and Input(Y@Grad) should "
                       "be equal.");
     PADDLE_ENFORCE_EQ(dy->dims()[1], 1,
-                      "The 2nd dimension of Input(Y@Grad) must be 1.");
-    if (ctx.Attr<bool>("soft_label")) {
+                      "The 2nd dimension of Input(Y@Grad) should be 1.");
+    if (ctx.Attr<bool>("softLabel")) {
       PADDLE_ENFORCE_EQ(x->dims()[1], label->dims()[1],
-                        "If Attr(soft_label) == true, The 2nd dimension of "
-                        "Input(X) and Input(Label) must be equal.");
+                        "When Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
     } else {
       PADDLE_ENFORCE_EQ(label->dims()[1], 1,
-                        "If Attr(soft_label) == false, The 2nd dimension of "
-                        "Input(Label) must be 1.");
+                        "When Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
     }
     auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -98,24 +102,39 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
   CrossEntropyOpMaker(framework::OpProto *proto,
                       framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of CrossEntropyOp");
-    AddInput("Label", "The second input of CrossEntropyOp");
-    AddOutput("Y", "The output of CrossEntropyOp");
-    AddAttr<bool>("soft_label", "Is soft label. Default zero.")
+    AddInput("X",
+             "(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
+             "where N is the batch size and D is the number of classes. "
+             "This input is a probability computed by the previous operator, "
+             "which is almost always the result of a softmax operator.");
+    AddInput("Label",
+             "(Tensor, default Tensor<int>), the ground truth which is "
+             "a 2-D tensor. "
+             "When softLabel is set to false, `Label` is a Tensor<int> with shape "
+             "[N x 1]. "
+             "When softLabel is set to true, `Label` is a Tensor<float/double> "
+             "with shape [N x K].");
+    AddOutput("Y",
+              "(Tensor, default Tensor<float>), a 2-D tensor "
+              "with shape [N x 1]. The cross entropy loss.");
+    AddAttr<bool>("softLabel",
+                  "(bool, default false), a flag to indicate whether to "
+                  "interpret the given labels as soft labels.")
         .SetDefault(false);
     AddComment(R"DOC(
 CrossEntropy Operator.

 It supports both standard cross-entropy and soft-label cross-entropy loss
 computation.
 1) One-hot cross-entropy:
-    soft_label = False, Label[i, 0] indicates the class index for sample i:
+    softLabel = false, Label[i, 0] indicates the class index for sample i:

                 Y[i] = -log(X[i, Label[i]])

 2) Soft-label cross-entropy:
-    soft_label = True, Label[i, j] indicates the soft label of class j
+    softLabel = true, Label[i, j] indicates the soft label of class j
    for sample i:

                 Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
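The DOC block above states the two loss forms the operator supports. As a reading aid only (not part of the commit), here is a minimal NumPy sketch of both; the function names are illustrative:

    import numpy as np

    def one_hot_cross_entropy(x, label_index):
        # Y[i] = -log(X[i, Label[i]]); label_index has shape [N]
        n = x.shape[0]
        return -np.log(x[np.arange(n), label_index]).reshape(n, 1)

    def soft_label_cross_entropy(x, label):
        # Y[i] = -sum_j Label[i, j] * log(X[i, j]); label has shape [N, D]
        return (-label * np.log(x)).sum(axis=1, keepdims=True)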
paddle/operators/cross_entropy_op.cu

@@ -28,26 +28,49 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -tolerable_value(log(X[i * D + label[i]]));
+    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
   }
 }

+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
 template <typename T>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
-                                       const int N, const int D) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    T sum = static_cast<T>(0);
-    for (int j = 0; j < D; j++) {
-      sum += label[i * D + j] * tolerable_value(log(X[i * D + j]));
-    }
-    Y[i] = -sum;
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
 }

-// TODO(qingqing): make zero setting an common function.
+// TODO(qingqing): make zero setting a common function.
 template <typename T>
-__global__ void zero(T* X, const int N) {
+__global__ void Zero(T* X, const int N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     X[i] = 0.0;
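As a reading aid only (not part of the commit), a rough NumPy simulation of what the new SoftCrossEntropyKernel does for one row: a CUDA block handles one row, each thread accumulates a strided partial sum of label * log(X) into shared memory, a shared-memory tree reduction folds the partial sums, and the final warp is finished with __shfl_down (sum_single_warp). All names below are illustrative, and the clamping done by TolerableValue is omitted:

    import numpy as np

    def soft_cross_entropy_block_sim(x_row, label_row, block_dim=128):
        # Strided per-thread accumulation into a shared-memory-like buffer.
        d_sum = np.zeros(block_dim)
        for tid in range(block_dim):
            d_sum[tid] = np.sum(label_row[tid::block_dim] *
                                np.log(x_row[tid::block_dim]))
        # Tree reduction over the buffer (the real kernel switches to
        # warp shuffles once the active width fits in a single warp).
        stride = block_dim // 2
        while stride >= 1:
            d_sum[:stride] += d_sum[stride:2 * stride]
            stride //= 2
        return -d_sum[0]  # equals -(label_row * np.log(x_row)).sum()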
@@ -71,13 +94,10 @@ template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
                                                const T* label, const int N,
                                                const int D) {
   // TODO(qingqing): optimize for this kernel
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    for (int j = 0; j < D; ++j) {
-      int idx = i * D + j;
-      dX[idx] = -label[idx] * dY[i] / X[idx];
-    }
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < N * D) {
+    int row_ids = ids / D;
+    dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
@@ -86,29 +106,36 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");

-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
+    const T* x_data = x->data<T>();
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());

-    int n = x->dims()[0];
-    int d = x->dims()[1];
-    int block = 512;
-    int grid = (n + block - 1) / block;
-    // TODO(qingqing) launch kernel on specified stream
-    // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      SoftCrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n,
-                                                 d);
+      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<T><<<
+          batch_size, block, block * sizeof(T),
+          reinterpret_cast<const platform::CUDADeviceContext&>(
+              ctx.device_context())
+              .stream()>>>(y_data, x_data, label_data, class_num);
     } else {
       auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      CrossEntropyKernel<T><<<grid, block>>>(y_data, x_data, label_data, n, d);
+      int block = 512;
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(y_data, x_data, label_data,
+                                           batch_size, class_num);
     }
   }
 };
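A note on the new launch configuration for the soft-label branch above: one block per sample, the block size is the largest power of two not exceeding class_num (capped at 512), and dynamic shared memory is block * sizeof(T). A small Python rendering of that block-size choice, for illustration only:

    import math

    def soft_label_block_size(class_num, cap=512):
        # Largest power of two <= class_num, capped at 512 threads per block.
        return cap if class_num > cap else 2 ** int(math.log2(class_num))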
@@ -118,33 +145,43 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
-                   "It must use GPUPlace.");
+                   "This kernel only runs on GPU device.");

-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
+    const T* dy_data =
+        ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    const T* x_data = x->data<T>();

-    int n = x->dims()[0];
-    int d = x->dims()[1];
+    int batch_size = x->dims()[0];
+    int class_num = x->dims()[1];

     int block = 512;
-    int grid = (n * d + block - 1) / block;
-    zero<T><<<grid, block>>>(dx_data, n * d);
-    grid = (n + block - 1) / block;
-    // TODO(qingqing): launch kernel on specified stream
-    // base on ExecutionContext.
-    if (ctx.Attr<bool>("soft_label")) {
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (ctx.Attr<bool>("softLabel")) {
       auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<grid, block>>>(
-          dx_data, dy_data, x_data, label_data, n, d);
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0,
+          reinterpret_cast<const platform::CUDADeviceContext&>(
+              ctx.device_context())
+              .stream()>>>(dx_data, dy_data, x_data, label_data, batch_size,
+                           class_num);
     } else {
+      Zero<T><<<grid, block, 0,
+                reinterpret_cast<const platform::CUDADeviceContext&>(
+                    ctx.device_context())
+                    .stream()>>>(dx_data, batch_size * class_num);
+
       auto* label_data = label->data<int>();
-      CrossEntropyGradientKernel<T><<<grid, block>>>(dx_data, dy_data, x_data,
-                                                     label_data, n, d);
+      grid = (batch_size + block - 1) / block;
+      CrossEntropyGradientKernel<T><<<
+          grid, block, 0,
+          reinterpret_cast<const platform::CUDADeviceContext&>(
+              ctx.device_context())
+              .stream()>>>(dx_data, dy_data, x_data, label_data, batch_size,
+                           class_num);
     }
   }
 };
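As a reading aid only (not part of the commit), the gradients these kernels produce can be written compactly in NumPy; the soft-label path is dX = -Label * dY / X with dY broadcast across classes, and the hard-label path zeroes dX and scatters -dY / X at the true class. Names below are illustrative:

    import numpy as np

    def soft_label_cross_entropy_grad(x, label, dy):
        # dX[i, j] = -Label[i, j] * dY[i] / X[i, j]; dy has shape [N, 1]
        return -label * dy / x

    def one_hot_cross_entropy_grad(x, label_index, dy):
        # dX is zero except dX[i, Label[i]] = -dY[i] / X[i, Label[i]]
        dx = np.zeros_like(x)
        rows = np.arange(x.shape[0])
        dx[rows, label_index] = -dy[:, 0] / x[rows, label_index]
        return dx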
paddle/operators/cross_entropy_op.h

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/platform/hostdevice.h"
@@ -20,53 +21,51 @@ namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

 template <typename T>
-HOSTDEVICE T tolerable_value(const T x) {
-  PADDLE_ASSERT(std::is_floating_point<T>::value);
-  const T kApproInf = 1e20;
-  if (x == INFINITY) {
-    return kApproInf;
-  }
-  if (x == -INFINITY) {
-    return -kApproInf;
-  }
-  return x;
-}
+struct TolerableValue {
+  HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
+    return x;
+  }
+};

 template <typename T>
 class CrossEntropyOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
-    auto x = ctx.Input<Tensor>("X");
-    auto y = ctx.Output<Tensor>("Y");
-
-    auto* x_data = x->data<T>();
-    y->mutable_data<T>(ctx.GetPlace());
-    auto* y_data = y->data<T>();
-
-    int batch_size = x->dims()[0];
-    int class_num = x->dims()[1];
-
-    if (ctx.Attr<bool>("soft_label")) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      int index = 0;
-      for (int i = 0; i < batch_size; ++i) {
-        T sum = static_cast<T>(0);
-        for (int j = 0; j < class_num; ++j) {
-          sum += label_data[index] * tolerable_value(std::log(x_data[index]));
-          y_data[i] = -sum;
-          index++;
-        }
-      }
+                   "This kernel only runs on CPU.");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
+    Tensor* y = ctx.Output<Tensor>("Y");
+    T* y_data = y->mutable_data<T>(ctx.GetPlace());
+
+    const int batch_size = x->dims()[0];
+    if (ctx.Attr<bool>("softLabel")) {
+      auto prob = EigenMatrix<T>::From(*x);
+      auto lbl_mat = EigenMatrix<T>::From(*labels);
+      auto loss = EigenMatrix<T>::From(*y);
+
+      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+          -((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
+                .sum(Eigen::DSizes<int, 1>(1))
+                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
     } else {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
+      const int class_num = x->dims()[1];
+      const T* x_data = x->data<T>();
+      const int* label_data = labels->data<int>();
+
       for (int i = 0; i < batch_size; ++i) {
         int index = i * class_num + label_data[i];
-        y_data[i] = -tolerable_value(std::log(x_data[index]));
+        y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
       }
     }
   }
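An aside, not part of the commit: TolerableValue clamps infinities (for example, log(0) yields -inf) to +/-1e20 so downstream arithmetic stays finite, and the Eigen path applies it element-wise via unaryExpr. A minimal NumPy analogue, with illustrative names:

    import numpy as np

    K_APPRO_INF = 1e20  # mirrors kApproInf in TolerableValue<T>

    def tolerable_value(x):
        # Replace +/-inf with +/-1e20; finite values pass through unchanged.
        return np.clip(x, -K_APPRO_INF, K_APPRO_INF)

    # e.g. the soft-label forward then reads, row-wise:
    #   loss = -(label * tolerable_value(np.log(x))).sum(axis=1, keepdims=True)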
@@ -77,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
-    auto x = ctx.Input<Tensor>("X");
-    auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    auto label = ctx.Input<Tensor>("Label");
+                   "This kernel only runs on CPU.");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const Tensor* label = ctx.Input<Tensor>("Label");
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    auto* dy_data = dy->data<T>();
-    auto* x_data = x->data<T>();
-
-    int batch_size = x->dims()[0];
     int class_num = x->dims()[1];
-
-    // TODO(qingqing): make zero setting an common function.
-    if (ctx.Attr<bool>("soft_label")) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      int index = 0;
-      for (int i = 0; i < batch_size; ++i) {
-        for (int j = 0; j < class_num; ++j) {
-          dx_data[index] = -label_data[index] * dy_data[i] / x_data[index];
-          index++;
-        }
-      }
+    if (ctx.Attr<bool>("softLabel")) {
+      auto x_mat = EigenMatrix<T>::From(*x);
+      auto dy_mat = EigenMatrix<T>::From(*dy);
+      auto lbl_mat = EigenMatrix<T>::From(*label);
+      auto dx_mat = EigenMatrix<T>::From(*dx);
+
+      dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+          -(lbl_mat * dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) /
+            x_mat);
     } else {
-      auto* label_data = label->data<int>();
+      int batch_size = x->dims()[0];
+      const T* dy_data = dy->data<T>();
+      const T* x_data = x->data<T>();
+      const int* label_data = label->data<int>();
+
+      // TODO(qingqing): make zero setting a common function.
       memset(dx_data, 0, sizeof(T) * batch_size * class_num);
+
       for (int i = 0; i < batch_size; ++i) {
         PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
         int index = i * class_num + label_data[i];
paddle/operators/lookup_table_op.cu

@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
+    LookupTable<T, 128, 8, 8><<<
+        grids, threads, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            context.device_context())
+            .stream()>>>(output, table, ids, N, K, D);
   }
 };

@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids,
-                                                      N, K, D);
+    LookupTableGrad<T, 128, 8, 8><<<
+        grids, threads, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            context.device_context())
+            .stream()>>>(d_table, d_output, ids, N, K, D);
   }
 };
paddle/operators/top_k_op.cu

@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
     // NOTE: pass lds and dim same to input width.
     // NOTE: old matrix implementation of stride is different to eigen.
-    // TODO(typhoonzero): launch kernel on specified stream.
     // TODO(typhoonzero): refine this kernel.
     dim3 threads(256, 1);
     dim3 grid(input_height, 1);
-    KeMatrixTopK<T, 5, 256><<<grid, threads>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, int(k));
+    KeMatrixTopK<T, 5, 256><<<
+        grid, threads, 0,
+        reinterpret_cast<const platform::CUDADeviceContext&>(
+            ctx.device_context())
+            .stream()>>>(output_data, output->dims()[1], indices_data,
+                         input_data, input_width, input_width, int(k));
   }
 };
python/paddle/v2/framework/tests/test_cross_entropy_op.py

@@ -4,22 +4,24 @@ from op_test import OpTest


 class TestCrossEntropyOp1(OpTest):
-    """Test standard cross-entropy, with index representation of labels.
+    """Test cross-entropy with discrete one-hot labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
         batch_size = 30
         class_num = 10
+
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label = np.random.randint(0, class_num, (batch_size, 1), dtype="int32")
+
         cross_entropy = np.asmatrix(
             [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
             dtype="float32")
+
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {'soft_label': False}
+        self.attrs = {"softLabel": False}

     def test_check_output(self):
         self.check_output()

@@ -29,13 +31,14 @@ class TestCrossEntropyOp1(OpTest):

 class TestCrossEntropyOp2(OpTest):
-    """Test soft-label cross-entropy, with vecterized soft labels.
+    """Test cross-entropy with vectorized soft labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 10
-        class_num = 5
+        batch_size = 5
+        class_num = 37
+
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label = np.random.uniform(0.1, 1.0,

@@ -43,46 +46,49 @@ class TestCrossEntropyOp2(OpTest):
         label /= label.sum(axis=1, keepdims=True)
         cross_entropy = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
-        self.inputs = {'X': X, 'Label': label}
-        self.outputs = {'Y': cross_entropy}
-        self.attrs = {'soft_label': True}
+
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"softLabel": True}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(["X"], "Y", max_relative_error=0.05)


 class TestCrossEntropyOp3(OpTest):
-    """Test one-hot cross-entropy, with vecterized one-hot representation of
-    labels.
+    """Test cross-entropy with vectorized one-hot representation of labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 30
-        class_num = 10
+        batch_size = 5
+        class_num = 17
+
         X = np.random.uniform(0.1, 1.0,
                               [batch_size, class_num]).astype("float32")
         label_index = np.random.randint(
             0, class_num, (batch_size), dtype="int32")
         label = np.zeros(X.shape)
         label[np.arange(batch_size), label_index] = 1
+
         cross_entropy = np.asmatrix(
             [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
             dtype="float32")
         cross_entropy2 = (-label * np.log(X)).sum(
             axis=1, keepdims=True).astype("float32")
+
         self.inputs = {"X": X, "Label": label}
         self.outputs = {"Y": cross_entropy}
-        self.attrs = {'soft_label': True}
+        self.attrs = {"softLabel": True}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(["X"], "Y", max_relative_error=0.05)


 if __name__ == "__main__":