Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
6735585b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6735585b
编写于
9月 22, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix cpu kernel with soft labels.
上级
30bfaab3
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
194 addition
and
97 deletion
+194
-97
paddle/operators/accuracy_op.cu
paddle/operators/accuracy_op.cu
+6
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+44
-28
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+88
-40
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+3
-10
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+8
-3
paddle/operators/top_k_op.cu
paddle/operators/top_k_op.cu
+6
-4
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+39
-10
未找到文件。
paddle/operators/accuracy_op.cu
浏览文件 @
6735585b
...
...
@@ -69,8 +69,12 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
return
;
}
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
1
,
PADDLE_CUDA_NUM_THREADS
>>>
(
num_samples
,
infer_width
,
inference_data
,
label_data
,
accuracy_data
);
AccuracyCudaKernel
<
PADDLE_CUDA_NUM_THREADS
><<<
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
num_samples
,
infer_width
,
inference_data
,
label_data
,
accuracy_data
);
}
};
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
6735585b
...
...
@@ -23,27 +23,28 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
must not be
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
should be not
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Label) must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Y"
),
"Output(Y) must not be null."
);
"Input(Label) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Y"
),
"Output(Y) should be not null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank
must
be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank
should
be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
"Input(Label)'s rank
must
be 2."
);
"Input(Label)'s rank
should
be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label)
must
"
"The 1st dimension of Input(X) and Input(Label)
should
"
"be equal."
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"If Attr(soft_label) == true,
T
he 2nd dimension of "
"Input(X) and Input(Label)
must
be equal."
);
"If Attr(soft_label) == true,
t
he 2nd dimension of "
"Input(X) and Input(Label)
should
be equal."
);
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"If Attr(soft_label) == false,
T
he 2nd dimension of "
"Input(Label)
must
be 1."
);
"If Attr(soft_label) == false,
t
he 2nd dimension of "
"Input(Label)
should
be 1."
);
}
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
x
->
dims
()[
0
],
1
});
...
...
@@ -57,35 +58,36 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
must not be
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X)
should be not
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Label)
must not be
null."
);
"Input(Label)
should be not
null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD)
must not be
null."
);
"Input(Y@GRAD)
shoudl be not
null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
().
size
(),
2
,
"Input(Y@Grad)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank should be 2."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
().
size
(),
2
,
"Input(Y@Grad)'s rank should be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
"Input(Label)'s rank
must
be 2."
);
"Input(Label)'s rank
should
be 2."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label)
must
"
"The 1st dimension of Input(X) and Input(Label)
should
"
"be equal."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
dy
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Y@Grad)
must
"
"The 1st dimension of Input(X) and Input(Y@Grad)
should
"
"be equal."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
()[
1
],
1
,
"The 2nd dimension of Input(Y@Grad)
must
be 1."
);
"The 2nd dimension of Input(Y@Grad)
should
be 1."
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"
If Attr(soft_label) == true, T
he 2nd dimension of "
"Input(X) and Input(Label)
must
be equal."
);
"
When Attr(soft_label) == true, t
he 2nd dimension of "
"Input(X) and Input(Label)
should
be equal."
);
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"
If Attr(soft_label) == false, T
he 2nd dimension of "
"Input(Label)
must
be 1."
);
"
When Attr(soft_label) == false, t
he 2nd dimension of "
"Input(Label)
should
be 1."
);
}
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
...
@@ -98,12 +100,26 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
CrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of CrossEntropyOp"
);
AddInput
(
"Label"
,
"The second input of CrossEntropyOp"
);
AddOutput
(
"Y"
,
"The output of CrossEntropyOp"
);
AddAttr
<
bool
>
(
"soft_label"
,
"Is soft label. Default zero."
)
AddInput
(
"X"
,
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
"This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator."
);
AddInput
(
"Label"
,
"(Tensor, default Tensor<int>), the ground truth which is "
"a 1-D or 2-D tensor. "
"When soft_label is set to 0, `Label` is a Tensor<int> with shape "
"[N x 1]. "
"When soft_label is set to 1, `Label` is a Tensor<float/double> "
"with shape [N x K]."
);
AddOutput
(
"Y"
,
"(Tensor, default Tensor<float>), a 1-D tensor "
"with shape [N x 1]. The cross entropy loss."
);
AddAttr
<
bool
>
(
"soft_label"
,
"(bool, default false), a flag to indicate whether to interpretate "
"the given labels as soft labels."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
CrossEntropy Operator.
...
...
paddle/operators/cross_entropy_op.cu
浏览文件 @
6735585b
...
...
@@ -32,37 +32,71 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
}
}
template
<
typename
T
>
__device__
__forceinline__
T
sum_single_warp
(
T
val
)
{
val
+=
__shfl_down
(
val
,
16
);
val
+=
__shfl_down
(
val
,
8
);
val
+=
__shfl_down
(
val
,
4
);
val
+=
__shfl_down
(
val
,
2
);
val
+=
__shfl_down
(
val
,
1
);
return
val
;
}
// This kernel is called when the class number is less than or equal to 512.
template
<
typename
T
>
__global__
void
SoftCrossEntropyKernel1
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
class_num
)
{
int
tid
=
threadIdx
.
x
;
extern
__shared__
T
d_sum
[];
d_sum
[
tid
]
=
0
;
int
cur_idx
=
tid
;
int
next_idx
=
blockIdx
.
x
*
class_num
+
tid
;
while
(
cur_idx
<
class_num
)
{
d_sum
[
tid
]
+=
TolerableValue
<
T
>
()(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
blockDim
.
x
;
cur_idx
+=
blockDim
.
x
;
}
__syncthreads
();
for
(
unsigned
int
stride
=
blockDim
.
x
>>
1
;
stride
>=
32
;
stride
>>=
1
)
{
if
(
tid
<
stride
)
d_sum
[
tid
]
+=
d_sum
[
tid
+
stride
];
__syncthreads
();
}
T
val
=
d_sum
[
tid
];
val
=
sum_single_warp
<
T
>
(
val
);
if
(
tid
==
0
)
Y
[
blockIdx
.
x
]
=
-
val
;
}
// This kernel is called when the class number is larger than 512.
template
<
typename
T
,
int
BlockSize
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
__global__
void
SoftCrossEntropyKernel
2
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
class_num
)
{
int
tid
=
threadIdx
.
x
;
__shared__
T
d_sum
[
BlockSize
];
int
next_idx
=
blockIdx
.
x
*
D
+
tid
;
int
next_idx
=
blockIdx
.
x
*
class_num
+
tid
;
d_sum
[
tid
]
=
0
;
int
cur_idx
=
tid
;
while
(
cur_idx
<
D
)
{
while
(
cur_idx
<
class_num
)
{
d_sum
[
tid
]
+=
TolerableValue
<
T
>
()(
std
::
log
(
X
[
next_idx
]))
*
label
[
next_idx
];
next_idx
+=
BlockSize
;
cur_idx
+=
BlockSize
;
}
__syncthreads
();
for
(
int
stride
=
BlockSize
>>
1
;
stride
>
0
;
stride
>>=
1
)
{
for
(
unsigned
int
stride
=
BlockSize
>>
1
;
stride
>=
32
;
stride
>>=
1
)
{
if
(
tid
<
stride
)
d_sum
[
tid
]
+=
d_sum
[
tid
+
stride
];
__syncthreads
();
if
(
tid
<
stride
)
{
next_idx
=
tid
+
stride
;
d_sum
[
tid
]
+=
d_sum
[
next_idx
];
}
}
__syncthreads
();
if
(
tid
==
0
)
{
Y
[
blockIdx
.
x
]
=
-
d_sum
[
0
]
;
}
T
val
=
d_sum
[
tid
];
val
=
sum_single_warp
<
T
>
(
val
)
;
if
(
tid
==
0
)
Y
[
blockIdx
.
x
]
=
-
val
;
}
// TODO(qingqing): make zero setting a
n
common function.
// TODO(qingqing): make zero setting a common function.
template
<
typename
T
>
__global__
void
zero
(
T
*
X
,
const
int
N
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
...
...
@@ -88,11 +122,9 @@ template <typename T>
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
int
row_ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
col_ids
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
ids
=
row_ids
*
D
+
col_ids
;
int
ids
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
ids
<
N
*
D
)
{
int
row_ids
=
ids
/
D
;
dX
[
ids
]
=
-
label
[
ids
]
*
dY
[
row_ids
]
/
X
[
ids
];
}
}
...
...
@@ -112,20 +144,34 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
grid
=
d
;
SoftCrossEntropyKernel
<
T
,
512
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
if
(
class_num
>
512
)
{
SoftCrossEntropyKernel2
<
T
,
512
><<<
batch_size
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
y_data
,
x_data
,
label_data
,
class_num
);
}
else
{
int
block_size
=
pow
(
2
,
int
(
std
::
log2
(
class_num
)));
SoftCrossEntropyKernel1
<
T
><<<
batch_size
,
block_size
,
block_size
*
sizeof
(
T
),
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
y_data
,
x_data
,
label_data
,
class_num
);
}
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
int
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
y_data
,
x_data
,
label_data
,
batch_size
,
class_num
);
}
}
};
...
...
@@ -148,25 +194,27 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dx_data
,
n
*
d
);
grid
=
(
n
+
block
-
1
)
/
block
;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
zero
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
n
*
d
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
int
block_x
=
32
;
int
block_y
=
32
;
dim3
block
(
block_x
,
block_y
);
dim3
grid
((
n
+
block_x
-
1
)
/
block_x
,
(
d
+
block_y
-
1
)
/
block_y
);
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
}
else
{
auto
*
label_data
=
label
->
data
<
int
>
();
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
}
}
};
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
6735585b
...
...
@@ -31,12 +31,8 @@ struct TolerableValue {
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
if
(
x
==
INFINITY
)
return
kApproInf
;
if
(
x
==
-
INFINITY
)
return
-
kApproInf
;
return
x
;
}
};
...
...
@@ -58,11 +54,8 @@ class CrossEntropyOpKernel : public framework::OpKernel {
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
labels
);
auto
loss
=
EigenMatrix
<
T
>::
From
(
*
y
);
// loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
// prob.log().unaryExpr(TolerableValue<T>());
loss
.
device
(
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
-
((
lbl_mat
*
prob
.
log
())
-
((
lbl_mat
*
prob
.
log
()
.
unaryExpr
(
TolerableValue
<
T
>
())
)
.
sum
(
Eigen
::
DSizes
<
int
,
1
>
(
1
))
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
)));
}
else
{
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
6735585b
...
...
@@ -77,7 +77,10 @@ class LookupTableCUDAKernel : public framework::OpKernel {
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
.
device_context
())
.
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
}
};
...
...
@@ -102,8 +105,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
.
device_context
())
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
};
...
...
paddle/operators/top_k_op.cu
浏览文件 @
6735585b
...
...
@@ -301,14 +301,16 @@ class TopkOpCUDAKernel : public framework::OpKernel {
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): launch kernel on specified stream.
// TODO(typhoonzero): refine this kernel.
dim3
threads
(
256
,
1
);
dim3
grid
(
input_height
,
1
);
KeMatrixTopK
<
T
,
5
,
256
><<<
grid
,
threads
>>>
(
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
int
(
k
));
KeMatrixTopK
<
T
,
5
,
256
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
int
(
k
));
}
};
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
6735585b
...
...
@@ -19,7 +19,7 @@ class TestCrossEntropyOp1(OpTest):
dtype
=
"float32"
)
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
False
}
self
.
attrs
=
{
"soft_label"
:
False
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -34,7 +34,8 @@ class TestCrossEntropyOp2(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
13
batch_size
=
5
# this setting tests threads in more than one wrap.
class_num
=
37
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
...
...
@@ -43,9 +44,9 @@ class TestCrossEntropyOp2(OpTest):
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
cross_entropy
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
True
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft_label"
:
True
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
@@ -61,8 +62,9 @@ class TestCrossEntropyOp3(OpTest):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
13
class_num
=
37
batch_size
=
5
# this setting tests all threads in one wrap.
class_num
=
17
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label_index
=
np
.
random
.
randint
(
...
...
@@ -74,9 +76,36 @@ class TestCrossEntropyOp3(OpTest):
dtype
=
"float32"
)
cross_entropy2
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
True
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft_label"
:
True
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Y"
,
max_relative_error
=
0.05
)
class
TestCrossEntropyOp4
(
OpTest
):
"""Test soft-label cross-entropy.
This unittest tests the gpu kernel for layer size excesses 512.
"""
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
2
class_num
=
517
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
cross_entropy
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft_label"
:
True
}
def
test_check_output
(
self
):
self
.
check_output
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录