Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6d60352e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d60352e
编写于
9月 13, 2017
作者:
X
Xinghai Sun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add soft-label support for cross-entropy operator.
上级
0f42e564
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
205 addition
and
99 deletion
+205
-99
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+38
-26
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+81
-38
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+61
-31
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-1
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+23
-2
python/paddle/v2/framework/tests/test_mnist.py
python/paddle/v2/framework/tests/test_mnist.py
+1
-1
未找到文件。
paddle/operators/cross_entropy_op.cc
浏览文件 @
6d60352e
...
...
@@ -17,48 +17,62 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
class
Onehot
CrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
class
CrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"
l
abel"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"
L
abel"
);
PADDLE_ENFORCE_EQ
(
X
->
dims
().
size
(),
2
,
"X's dimension must be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
1
,
"label's dimension must be 1."
);
PADDLE_ENFORCE_EQ
(
X
->
dims
()[
0
],
label
->
dims
()[
0
]);
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
X
->
dims
()[
0
]});
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"X's rank must be 2."
);
PADDLE_ASSERT
(
label
->
dims
().
size
()
==
1
||
label
->
dims
().
size
()
==
2
);
if
(
label
->
dims
().
size
()
==
2
)
{
// soft cross entropy
PADDLE_ENFORCE_EQ
(
x
->
dims
(),
label
->
dims
());
}
else
{
// normal cross entropy
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
]);
}
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
x
->
dims
()[
0
]});
}
};
class
Onehot
CrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
class
CrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
d
X
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
d
X
->
Resize
(
X
->
dims
());
d
x
->
Resize
(
x
->
dims
());
}
};
class
Onehot
CrossEntropyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
CrossEntropyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
Onehot
CrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
CrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of
Onehot
CrossEntropyOp"
);
AddInput
(
"
label"
,
"The second input of Onehot
CrossEntropyOp"
);
AddOutput
(
"Y"
,
"The output of
Onehot
CrossEntropyOp"
);
AddInput
(
"X"
,
"The first input of CrossEntropyOp"
);
AddInput
(
"
Label"
,
"The second input of
CrossEntropyOp"
);
AddOutput
(
"Y"
,
"The output of CrossEntropyOp"
);
AddComment
(
R"DOC(
Onehot
CrossEntropy Operator.
CrossEntropy Operator.
Y[i] = -log(X[i][j])
The second input (Label tensor) supports two kinds of shapes:
1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Please make sure that in this case the summuation of each row of Label
equals one. If each row of Label has only one non-zero element (equals 1),
it degenerates to a standard one-hot representation.
)DOC"
);
}
};
...
...
@@ -66,10 +80,8 @@ OnehotCrossEntropy Operator.
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOp
,
ops
::
OnehotCrossEntropyOpMaker
,
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpKernel
<
float
>
);
REGISTER_OP
(
cross_entropy
,
ops
::
CrossEntropyOp
,
ops
::
CrossEntropyOpMaker
,
cross_entropy_grad
,
ops
::
CrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOpKernel
<
float
>
);
paddle/operators/cross_entropy_op.cu
浏览文件 @
6d60352e
...
...
@@ -21,17 +21,16 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
__host__
__device__
T
clipping_log
(
const
T
x
)
{
__host__
__device__
T
tolerable_value
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
T
v
=
log
(
x
);
if
(
v
==
INFINITY
)
{
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
v
==
-
INFINITY
)
{
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
v
;
return
x
;
}
template
<
typename
T
>
...
...
@@ -42,7 +41,20 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
clipping_log
(
X
[
i
*
D
+
label
[
i
]]);
Y
[
i
]
=
-
tolerable_value
(
log
(
X
[
i
*
D
+
label
[
i
]]));
}
}
template
<
typename
T
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
D
;
j
++
)
{
sum
+=
label
[
i
*
D
+
j
]
*
log
(
X
[
i
*
D
+
j
]);
}
Y
[
i
]
=
-
tolerable_value
(
sum
);
}
}
...
...
@@ -69,57 +81,89 @@ __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
}
template
<
typename
T
>
class
OnehotCrossEntropyOpCUDAKernel
:
public
framework
::
OpKernel
{
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
int
idx
=
i
*
D
+
j
;
dX
[
idx
]
=
-
label
[
idx
]
*
dY
[
i
]
/
X
[
idx
];
}
}
}
template
<
typename
T
>
class
CrossEntropyOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
Ydata
=
Y
->
data
<
T
>
();
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
int
N
=
X
->
dims
()[
0
];
int
D
=
X
->
dims
()[
1
];
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
+
block
-
1
)
/
block
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
Ydata
,
Xdata
,
label_data
,
N
,
D
);
int
label_rank
=
label
->
dims
().
size
();
if
(
label_rank
==
2
)
{
// soft cross entropy
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
SoftCrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
}
else
{
// normal cross entropy
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
}
}
};
template
<
typename
T
>
class
Onehot
CrossEntropyGradientOpCUDAKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyGradientOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
X
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
Y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
l
abel"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
L
abel"
);
auto
*
dXdata
=
dX
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
auto
*
dYdata
=
dY
->
template
data
<
T
>();
auto
*
Xdata
=
X
->
template
data
<
T
>();
auto
*
label_data
=
label
->
data
<
int
>
();
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
int
N
=
X
->
dims
()[
0
];
int
D
=
X
->
dims
()[
1
];
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
*
D
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dXdata
,
N
*
D
);
grid
=
(
N
+
block
-
1
)
/
block
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dx_data
,
n
*
d
);
grid
=
(
n
+
block
-
1
)
/
block
;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dXdata
,
dYdata
,
Xdata
,
label_data
,
N
,
D
);
int
label_rank
=
label
->
dims
().
size
();
if
(
label_rank
==
2
)
{
// soft cross entropy
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
}
else
{
// normal cross entropy
auto
*
label_data
=
label
->
data
<
int
>
();
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
}
}
};
...
...
@@ -127,7 +171,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOpCUDAKernel
<
float
>
);
paddle/operators/cross_entropy_op.h
浏览文件 @
6d60352e
...
...
@@ -40,56 +40,86 @@ inline T tolerable_value(const T x) {
}
template
<
typename
T
>
class
Onehot
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
T
*
Ydata
=
Y
->
data
<
T
>
();
int
batch_size
=
X
->
dims
()[
0
];
int
class_num
=
X
->
dims
()[
1
];
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
int
label_rank
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
dims
().
size
();
if
(
label_rank
==
2
)
{
// soft cross entropy
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
int
index
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
sum
+=
label_data
[
index
]
*
std
::
log
(
x_data
[
index
]);
y_data
[
i
]
=
-
tolerable_value
(
sum
);
index
++
;
}
}
}
else
{
// normal cross entropy
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
Ydata
[
i
]
=
-
tolerable_value
(
std
::
log
(
Xdata
[
index
]));
y_data
[
i
]
=
-
tolerable_value
(
std
::
log
(
x_data
[
index
]));
}
}
}
};
template
<
typename
T
>
class
Onehot
CrossEntropyGradientOpKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyGradientOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
X
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
Y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
l
abel"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
L
abel"
);
auto
*
dXdata
=
dX
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
auto
*
dYdata
=
dY
->
template
data
<
T
>();
auto
*
Xdata
=
X
->
template
data
<
T
>();
auto
*
label_data
=
label
->
data
<
int
>
();
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
const
int
batch_size
=
X
->
dims
()[
0
];
const
int
class_num
=
X
->
dims
()[
1
];
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
int
label_rank
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
dims
().
size
();
// TODO(qingqing): make zero setting an common function.
memset
(
dXdata
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
if
(
label_rank
==
2
)
{
// soft cross entropy
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
int
index
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
dx_data
[
index
]
=
-
label_data
[
index
]
*
dy_data
[
i
]
/
x_data
[
index
];
index
++
;
}
}
}
else
{
// normal cross entropy
auto
*
label_data
=
label
->
data
<
int
>
();
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
int
index
=
i
*
class_num
+
label_data
[
i
];
dXdata
[
index
]
=
-
tolerable_value
(
dYdata
[
i
]
/
Xdata
[
index
]);
dx_data
[
index
]
=
-
dy_data
[
i
]
/
x_data
[
index
];
}
}
}
};
...
...
paddle/pybind/pybind.cc
浏览文件 @
6d60352e
...
...
@@ -32,7 +32,7 @@ limitations under the License. */
namespace
py
=
pybind11
;
USE_OP
(
add
);
USE_OP
(
onehot_
cross_entropy
);
USE_OP
(
cross_entropy
);
USE_OP
(
sgd
);
USE_OP
(
mul
);
USE_OP
(
mean
);
...
...
python/paddle/v2/framework/tests/test_cross_entropy_op.py
浏览文件 @
6d60352e
...
...
@@ -5,13 +5,13 @@ from op_test import OpTest
class
TestCrossEntropy
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"
onehot_
cross_entropy"
self
.
op_type
=
"cross_entropy"
batch_size
=
30
class_num
=
10
X
=
numpy
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
(
class_num
/
2
)
*
numpy
.
ones
(
batch_size
).
astype
(
"int32"
)
self
.
inputs
=
{
'X'
:
X
,
'
l
abel'
:
label
}
self
.
inputs
=
{
'X'
:
X
,
'
L
abel'
:
label
}
Y
=
[]
for
i
in
range
(
0
,
batch_size
):
Y
.
append
(
-
numpy
.
log
(
X
[
i
][
label
[
i
]]))
...
...
@@ -24,5 +24,26 @@ class TestCrossEntropy(OpTest):
self
.
check_grad
([
'X'
],
'Y'
)
class
TestCrossEntropySoftLabel
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
30
class_num
=
10
X
=
numpy
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
numpy
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
Y
=
(
-
label
*
numpy
.
log
(
X
)).
sum
(
axis
=
1
)
self
.
outputs
=
{
'Y'
:
numpy
.
array
(
Y
).
astype
(
"float32"
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
,
max_relative_error
=
0.05
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/framework/tests/test_mnist.py
浏览文件 @
6d60352e
...
...
@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
def
cross_entropy_layer
(
net
,
input
,
label
):
cost_name
=
"cross_entropy_%d"
%
uniq_id
()
cross_entropy_op
=
Operator
(
"
onehot_
cross_entropy"
,
X
=
input
,
label
=
label
,
Y
=
cost_name
)
"cross_entropy"
,
X
=
input
,
label
=
label
,
Y
=
cost_name
)
net
.
append_op
(
cross_entropy_op
)
scope
.
new_var
(
cost_name
)
net
.
infer_shape
(
scope
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录