Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
000d7511
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
000d7511
编写于
9月 25, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix backward op.
上级
201c2bcf
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
76 addition
and
73 deletion
+76
-73
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+20
-17
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+27
-25
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+25
-28
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+4
-3
未找到文件。
paddle/operators/cross_entropy_op.cc
浏览文件 @
000d7511
...
@@ -37,13 +37,13 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -37,13 +37,13 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label) should "
"The 1st dimension of Input(X) and Input(Label) should "
"be equal."
);
"be equal."
);
if
(
ctx
.
Attr
<
bool
>
(
"soft
_l
abel"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"soft
L
abel"
))
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"If Attr(soft
_l
abel) == true, the 2nd dimension of "
"If Attr(soft
L
abel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."
);
"Input(X) and Input(Label) should be equal."
);
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"If Attr(soft
_l
abel) == false, the 2nd dimension of "
"If Attr(soft
L
abel) == false, the 2nd dimension of "
"Input(Label) should be 1."
);
"Input(Label) should be 1."
);
}
}
...
@@ -63,6 +63,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -63,6 +63,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"Input(Label) should be not null."
);
"Input(Label) should be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) shoudl be not null."
);
"Input(Y@GRAD) shoudl be not null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
framework
::
GradVarName
(
"X"
)),
"Output(X@GRAD) should be not null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
...
@@ -80,13 +82,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -80,13 +82,13 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"be equal."
);
"be equal."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
dy
->
dims
()[
1
],
1
,
"The 2nd dimension of Input(Y@Grad) should be 1."
);
"The 2nd dimension of Input(Y@Grad) should be 1."
);
if
(
ctx
.
Attr
<
bool
>
(
"soft
_l
abel"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"soft
L
abel"
))
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"When Attr(soft
_l
abel) == true, the 2nd dimension of "
"When Attr(soft
L
abel) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal."
);
"Input(X) and Input(Label) should be equal."
);
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"When Attr(soft
_l
abel) == false, the 2nd dimension of "
"When Attr(soft
L
abel) == false, the 2nd dimension of "
"Input(Label) should be 1."
);
"Input(Label) should be 1."
);
}
}
...
@@ -105,18 +107,19 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -105,18 +107,19 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"where N is the batch size and D is the number of classes. "
"where N is the batch size and D is the number of classes. "
"This input is a probability computed by the previous operator, "
"This input is a probability computed by the previous operator, "
"which is almost always the result of a softmax operator."
);
"which is almost always the result of a softmax operator."
);
AddInput
(
"Label"
,
AddInput
(
"(Tensor, default Tensor<int>), the ground truth which is "
"Label"
,
"a 1-D or 2-D tensor. "
"(Tensor, default Tensor<int>), the ground truth which is "
"When soft_label is set to 0, `Label` is a Tensor<int> with shape "
"a 2-D tensor. "
"[N x 1]. "
"When softLabel is set to false, `Label` is a Tensor<int> with shape "
"When soft_label is set to 1, `Label` is a Tensor<float/double> "
"[N x 1]. "
"with shape [N x K]."
);
"When softLabel is set to true, `Label` is a Tensor<float/double> "
"with shape [N x K]."
);
AddOutput
(
"Y"
,
AddOutput
(
"Y"
,
"(Tensor, default Tensor<float>), a
1
-D tensor "
"(Tensor, default Tensor<float>), a
2
-D tensor "
"with shape [N x 1]. The cross entropy loss."
);
"with shape [N x 1]. The cross entropy loss."
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"soft
_l
abel"
,
"soft
L
abel"
,
"(bool, default false), a flag to indicate whether to interpretate "
"(bool, default false), a flag to indicate whether to interpretate "
"the given labels as soft labels."
)
"the given labels as soft labels."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
...
@@ -126,12 +129,12 @@ CrossEntropy Operator.
...
@@ -126,12 +129,12 @@ CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
computation.
1) One-hot cross-entropy:
1) One-hot cross-entropy:
soft
_label = F
alse, Label[i, 0] indicates the class index for sample i:
soft
Label = f
alse, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy:
2) Soft-label cross-entropy:
soft
_label = T
rue, Label[i, j] indicates the soft label of class j
soft
Label = t
rue, Label[i, j] indicates the soft label of class j
for sample i:
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
...
...
paddle/operators/cross_entropy_op.cu
浏览文件 @
000d7511
...
@@ -70,7 +70,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
...
@@ -70,7 +70,7 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
// TODO(qingqing): make zero setting a common function.
// TODO(qingqing): make zero setting a common function.
template
<
typename
T
>
template
<
typename
T
>
__global__
void
z
ero
(
T
*
X
,
const
int
N
)
{
__global__
void
Z
ero
(
T
*
X
,
const
int
N
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
X
[
i
]
=
0.0
;
X
[
i
]
=
0.0
;
...
@@ -108,18 +108,17 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
...
@@ -108,18 +108,17 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
"This kernel only runs on GPU device."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y
"
);
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label
"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label
"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y
"
);
auto
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
int
batch_size
=
x
->
dims
()[
0
];
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
int
class_num
=
x
->
dims
()[
1
];
if
(
ctx
.
Attr
<
bool
>
(
"soft
_l
abel"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"soft
L
abel"
))
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
int
block
=
class_num
>
512
?
512
:
pow
(
2
,
int
(
std
::
log2
(
class_num
)));
int
block
=
class_num
>
512
?
512
:
pow
(
2
,
int
(
std
::
log2
(
class_num
)));
...
@@ -148,38 +147,41 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
...
@@ -148,38 +147,41 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
"This kernel only runs on GPU device."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
dy_data
=
auto
*
dy_data
=
dy
->
data
<
T
>
();
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
))
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
T
*
x_data
=
x
->
data
<
T
>
();
int
n
=
x
->
dims
()[
0
];
int
batch_size
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
class_num
=
x
->
dims
()[
1
];
int
block
=
512
;
int
block
=
512
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
int
grid
=
(
batch_size
*
class_num
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
if
(
ctx
.
Attr
<
bool
>
(
"softLabel"
))
{
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
n
*
d
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
*
label_data
=
label
->
data
<
T
>
();
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
batch_size
,
class_num
);
}
else
{
}
else
{
Zero
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
batch_size
*
class_num
);
auto
*
label_data
=
label
->
data
<
int
>
();
auto
*
label_data
=
label
->
data
<
int
>
();
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyGradientKernel
<
T
><<<
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
grid
,
block
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
ctx
.
device_context
())
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
.
stream
()
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
batch_size
,
class_num
);
}
}
}
}
};
};
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
000d7511
...
@@ -42,14 +42,14 @@ class CrossEntropyOpKernel : public framework::OpKernel {
...
@@ -42,14 +42,14 @@ class CrossEntropyOpKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"
It must use CPUPlace
."
);
"
This kernel only runs on CPU
."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
const
Tensor
*
labels
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Tensor
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
const
int
batch_size
=
x
->
dims
()[
0
];
const
int
batch_size
=
x
->
dims
()[
0
];
if
(
ctx
.
Attr
<
bool
>
(
"soft
_l
abel"
))
{
if
(
ctx
.
Attr
<
bool
>
(
"soft
L
abel"
))
{
auto
prob
=
EigenMatrix
<
T
>::
From
(
*
x
);
auto
prob
=
EigenMatrix
<
T
>::
From
(
*
x
);
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
labels
);
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
labels
);
auto
loss
=
EigenMatrix
<
T
>::
From
(
*
y
);
auto
loss
=
EigenMatrix
<
T
>::
From
(
*
y
);
...
@@ -60,9 +60,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
...
@@ -60,9 +60,7 @@ class CrossEntropyOpKernel : public framework::OpKernel {
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
)));
.
reshape
(
Eigen
::
DSizes
<
int
,
2
>
(
batch_size
,
1
)));
}
else
{
}
else
{
const
int
class_num
=
x
->
dims
()[
1
];
const
int
class_num
=
x
->
dims
()[
1
];
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
T
*
y_data
=
y
->
data
<
T
>
();
const
int
*
label_data
=
labels
->
data
<
int
>
();
const
int
*
label_data
=
labels
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
...
@@ -78,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
...
@@ -78,33 +76,32 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
"This kernel only runs on CPU."
);
const
Tensor
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
const
Tensor
*
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
Tensor
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
int
class_num
=
x
->
dims
()[
1
];
if
(
ctx
.
Attr
<
bool
>
(
"softLabel"
))
{
// TODO(qingqing): make zero setting an common function.
auto
x_mat
=
EigenMatrix
<
T
>::
From
(
*
x
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
auto
dy_mat
=
EigenMatrix
<
T
>::
From
(
*
dy
);
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
auto
lbl_mat
=
EigenMatrix
<
T
>::
From
(
*
label
);
int
index
=
0
;
auto
dx_mat
=
EigenMatrix
<
T
>::
From
(
*
dx
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
dx_mat
.
device
(
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
())
=
dx_data
[
index
]
=
-
label_data
[
index
]
*
dy_data
[
i
]
/
x_data
[
index
];
-
(
lbl_mat
*
dy_mat
.
broadcast
(
Eigen
::
DSizes
<
int
,
2
>
(
1
,
class_num
))
/
index
++
;
x_mat
);
}
}
}
else
{
}
else
{
auto
*
label_data
=
label
->
data
<
int
>
();
int
batch_size
=
x
->
dims
()[
0
];
const
T
*
dy_data
=
dy
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int
*
label_data
=
label
->
data
<
int
>
();
// TODO(qingqing): make zero setting a common function.
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
int
index
=
i
*
class_num
+
label_data
[
i
];
int
index
=
i
*
class_num
+
label_data
[
i
];
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
000d7511
...
@@ -21,7 +21,7 @@ class TestCrossEntropyOp1(OpTest):
...
@@ -21,7 +21,7 @@ class TestCrossEntropyOp1(OpTest):
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft
_l
abel"
:
False
}
self
.
attrs
=
{
"soft
L
abel"
:
False
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -49,7 +49,7 @@ class TestCrossEntropyOp2(OpTest):
...
@@ -49,7 +49,7 @@ class TestCrossEntropyOp2(OpTest):
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft
_l
abel"
:
True
}
self
.
attrs
=
{
"soft
L
abel"
:
True
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
@@ -73,6 +73,7 @@ class TestCrossEntropyOp3(OpTest):
...
@@ -73,6 +73,7 @@ class TestCrossEntropyOp3(OpTest):
0
,
class_num
,
(
batch_size
),
dtype
=
"int32"
)
0
,
class_num
,
(
batch_size
),
dtype
=
"int32"
)
label
=
np
.
zeros
(
X
.
shape
)
label
=
np
.
zeros
(
X
.
shape
)
label
[
np
.
arange
(
batch_size
),
label_index
]
=
1
label
[
np
.
arange
(
batch_size
),
label_index
]
=
1
cross_entropy
=
np
.
asmatrix
(
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
X
[
i
][
label_index
[
i
]])]
for
i
in
range
(
X
.
shape
[
0
])],
[[
-
np
.
log
(
X
[
i
][
label_index
[
i
]])]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float32"
)
dtype
=
"float32"
)
...
@@ -81,7 +82,7 @@ class TestCrossEntropyOp3(OpTest):
...
@@ -81,7 +82,7 @@ class TestCrossEntropyOp3(OpTest):
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft
_l
abel"
:
True
}
self
.
attrs
=
{
"soft
L
abel"
:
True
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
self
.
check_output
()
self
.
check_output
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录