Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5b42d2b2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5b42d2b2
编写于
9月 19, 2017
作者:
Q
qingqing01
提交者:
GitHub
9月 19, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4081 from xinghai-sun/soft_label_cross_entropy
Add soft-label support for cross-entropy operator.
上级
de8aaf6c
19de8ae1
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
515 addition
and
117 deletion
+515
-117
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+147
-0
paddle/operators/cross_entropy_op.cu
paddle/operators/cross_entropy_op.cu
+158
-0
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+117
-0
paddle/operators/onehot_cross_entropy_op.cc
paddle/operators/onehot_cross_entropy_op.cc
+0
-85
python/paddle/v2/framework/tests/test_cross_entropy_op.py
python/paddle/v2/framework/tests/test_cross_entropy_op.py
+89
-0
python/paddle/v2/framework/tests/test_mnist.py
python/paddle/v2/framework/tests/test_mnist.py
+4
-2
python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py
...paddle/v2/framework/tests/test_onehot_cross_entropy_op.py
+0
-30
未找到文件。
paddle/operators/cross_entropy_op.cc
0 → 100644
浏览文件 @
5b42d2b2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/cross_entropy_op.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
LoDTensor
;
class
CrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Label) must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Y"
),
"Output(Y) must not be null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
"Input(Label)'s rank must be 2."
);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
0
||
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal."
);
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal."
);
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1."
);
}
ctx
.
Output
<
LoDTensor
>
(
"Y"
)
->
Resize
({
x
->
dims
()[
0
],
1
});
}
};
class
CrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Label"
),
"Input(Label) must not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) must not be null."
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
().
size
(),
2
,
"Input(Y@Grad)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
2
,
"Input(Label)'s rank must be 2."
);
// TODO(xinghai-sun): remove this check after swtiching to bool
PADDLE_ENFORCE
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
0
||
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
label
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Label) must "
"be equal."
);
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
0
],
dy
->
dims
()[
0
],
"The 1st dimension of Input(X) and Input(Y@Grad) must "
"be equal."
);
PADDLE_ENFORCE_EQ
(
dy
->
dims
()[
1
],
1
,
"The 2nd dimension of Input(Y@Grad) must be 1."
);
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
PADDLE_ENFORCE_EQ
(
x
->
dims
()[
1
],
label
->
dims
()[
1
],
"If Attr(soft_label) == 1, The 2nd dimension of "
"Input(X) and Input(Label) must be equal."
);
}
else
{
PADDLE_ENFORCE_EQ
(
label
->
dims
()[
1
],
1
,
"If Attr(soft_label) == 0, The 2nd dimension of "
"Input(Label) must be 1."
);
}
auto
dx
=
ctx
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
dx
->
Resize
(
x
->
dims
());
}
};
class
CrossEntropyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of CrossEntropyOp"
);
AddInput
(
"Label"
,
"The second input of CrossEntropyOp"
);
AddOutput
(
"Y"
,
"The output of CrossEntropyOp"
);
AddAttr
<
int
>
(
"soft_label"
,
"Is soft label. Default zero."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
CrossEntropy Operator.
It supports both standard cross-entropy and soft-label cross-entropy loss
computation.
1) One-hot cross-entropy:
soft_label = 0, Label[i, 0] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]])
2) Soft-label cross-entropy:
soft_label = 1, Label[i, j] indicates the soft label of class j
for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Please make sure that in this case the summuation of each row of Label
equals one.
3) One-hot cross-entropy with vecterized Input(Label):
As a special case of 2), when each row of Input(Label) has only one
non-zero element (equals 1), soft-label cross-entropy degenerates to a
one-hot cross-entropy with one-hot label representation.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
cross_entropy
,
ops
::
CrossEntropyOp
,
ops
::
CrossEntropyOpMaker
,
cross_entropy_grad
,
ops
::
CrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOpKernel
<
float
>
);
paddle/operators/
onehot_
cross_entropy_op.cu
→
paddle/operators/cross_entropy_op.cu
浏览文件 @
5b42d2b2
...
...
@@ -13,27 +13,13 @@
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
__host__
__device__
T
clipping_log
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
T
v
=
log
(
x
);
if
(
v
==
INFINITY
)
{
return
kApproInf
;
}
if
(
v
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
v
;
}
template
<
typename
T
>
__global__
void
CrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
int
*
label
,
const
int
N
,
const
int
D
)
{
...
...
@@ -42,7 +28,20 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
clipping_log
(
X
[
i
*
D
+
label
[
i
]]);
Y
[
i
]
=
-
tolerable_value
(
log
(
X
[
i
*
D
+
label
[
i
]]));
}
}
template
<
typename
T
>
__global__
void
SoftCrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
D
;
j
++
)
{
sum
+=
label
[
i
*
D
+
j
]
*
tolerable_value
(
log
(
X
[
i
*
D
+
j
]));
}
Y
[
i
]
=
-
sum
;
}
}
...
...
@@ -69,57 +68,84 @@ __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
}
template
<
typename
T
>
class
OnehotCrossEntropyOpCUDAKernel
:
public
framework
::
OpKernel
{
__global__
void
SoftCrossEntropyGradientKernel
(
T
*
dX
,
const
T
*
dY
,
const
T
*
X
,
const
T
*
label
,
const
int
N
,
const
int
D
)
{
// TOOD(qingqing): optimize for this kernel
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
int
idx
=
i
*
D
+
j
;
dX
[
idx
]
=
-
label
[
idx
]
*
dY
[
i
]
/
X
[
idx
];
}
}
}
template
<
typename
T
>
class
CrossEntropyOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
Ydata
=
Y
->
data
<
T
>
();
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
label
=
ctx
.
Input
<
Tensor
>
(
"Label"
);
auto
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
int
N
=
X
->
dims
()[
0
];
int
D
=
X
->
dims
()[
1
];
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
+
block
-
1
)
/
block
;
int
grid
=
(
n
+
block
-
1
)
/
block
;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
Ydata
,
Xdata
,
label_data
,
N
,
D
);
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
SoftCrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
CrossEntropyKernel
<
T
><<<
grid
,
block
>>>
(
y_data
,
x_data
,
label_data
,
n
,
d
);
}
}
};
template
<
typename
T
>
class
Onehot
CrossEntropyGradientOpCUDAKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyGradientOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"It must use GPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
X
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
Y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
l
abel"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
L
abel"
);
auto
*
dXdata
=
dX
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
auto
*
dYdata
=
dY
->
template
data
<
T
>();
auto
*
Xdata
=
X
->
template
data
<
T
>();
auto
*
label_data
=
label
->
data
<
int
>
();
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
int
N
=
X
->
dims
()[
0
];
int
D
=
X
->
dims
()[
1
];
int
n
=
x
->
dims
()[
0
];
int
d
=
x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
*
D
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dXdata
,
N
*
D
);
grid
=
(
N
+
block
-
1
)
/
block
;
int
grid
=
(
n
*
d
+
block
-
1
)
/
block
;
zero
<
T
><<<
grid
,
block
>>>
(
dx_data
,
n
*
d
);
grid
=
(
n
+
block
-
1
)
/
block
;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dXdata
,
dYdata
,
Xdata
,
label_data
,
N
,
D
);
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
auto
*
label_data
=
label
->
data
<
T
>
();
SoftCrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
}
else
{
auto
*
label_data
=
label
->
data
<
int
>
();
CrossEntropyGradientKernel
<
T
><<<
grid
,
block
>>>
(
dx_data
,
dy_data
,
x_data
,
label_data
,
n
,
d
);
}
}
};
...
...
@@ -127,7 +153,6 @@ class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
cross_entropy
,
ops
::
CrossEntropyOpCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
cross_entropy_grad
,
ops
::
CrossEntropyGradientOpCUDAKernel
<
float
>
);
paddle/operators/
onehot_
cross_entropy_op.h
→
paddle/operators/cross_entropy_op.h
浏览文件 @
5b42d2b2
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/platform/hostdevice.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -21,75 +22,93 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
inline
T
tolerable_value
(
const
T
x
)
{
static_assert
(
std
::
is_floating_point
<
T
>::
value
,
"tolerable_value works only on float, "
"double and double double."
);
HOSTDEVICE
T
tolerable_value
(
const
T
x
)
{
PADDLE_ASSERT
(
std
::
is_floating_point
<
T
>::
value
);
const
T
kApproInf
=
1e20
;
if
(
x
==
INFINITY
)
{
return
kApproInf
;
}
if
(
x
==
-
INFINITY
)
{
return
-
kApproInf
;
}
return
x
;
}
template
<
typename
T
>
class
Onehot
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
x_data
=
x
->
data
<
T
>
();
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
y_data
=
y
->
data
<
T
>
();
T
*
Ydata
=
Y
->
data
<
T
>
();
int
batch_size
=
X
->
dims
()[
0
];
int
class_num
=
X
->
dims
()[
1
];
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
int
index
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
T
sum
=
static_cast
<
T
>
(
0
);
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
sum
+=
label_data
[
index
]
*
tolerable_value
(
std
::
log
(
x_data
[
index
]));
y_data
[
i
]
=
-
sum
;
index
++
;
}
}
}
else
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
int
>
();
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
int
index
=
i
*
class_num
+
label_data
[
i
];
Ydata
[
i
]
=
-
tolerable_value
(
std
::
log
(
Xdata
[
index
]));
y_data
[
i
]
=
-
tolerable_value
(
std
::
log
(
x_data
[
index
]));
}
}
}
};
template
<
typename
T
>
class
Onehot
CrossEntropyGradientOpKernel
:
public
framework
::
OpKernel
{
class
CrossEntropyGradientOpKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
X
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
Y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
l
abel"
);
auto
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
d
x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
d
y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
label
=
ctx
.
Input
<
Tensor
>
(
"
L
abel"
);
auto
*
dXdata
=
dX
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
auto
*
dYdata
=
dY
->
template
data
<
T
>();
auto
*
Xdata
=
X
->
template
data
<
T
>();
auto
*
label_data
=
label
->
data
<
int
>
();
auto
*
dx_data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
dy_data
=
dy
->
data
<
T
>
();
auto
*
x_data
=
x
->
data
<
T
>
();
const
int
batch_size
=
X
->
dims
()[
0
];
const
int
class_num
=
X
->
dims
()[
1
];
int
batch_size
=
x
->
dims
()[
0
];
int
class_num
=
x
->
dims
()[
1
];
// TODO(qingqing): make zero setting an common function.
memset
(
dXdata
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
if
(
ctx
.
Attr
<
int
>
(
"soft_label"
)
==
1
)
{
auto
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"Label"
)
->
data
<
T
>
();
int
index
=
0
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
j
=
0
;
j
<
class_num
;
++
j
)
{
dx_data
[
index
]
=
-
label_data
[
index
]
*
dy_data
[
i
]
/
x_data
[
index
];
index
++
;
}
}
}
else
{
auto
*
label_data
=
label
->
data
<
int
>
();
memset
(
dx_data
,
0
,
sizeof
(
T
)
*
batch_size
*
class_num
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ASSERT
(
label_data
[
i
]
>=
0
||
label_data
[
i
]
<
class_num
);
int
index
=
i
*
class_num
+
label_data
[
i
];
dXdata
[
index
]
=
-
tolerable_value
(
dYdata
[
i
]
/
Xdata
[
index
]);
dx_data
[
index
]
=
-
dy_data
[
i
]
/
x_data
[
index
];
}
}
}
};
...
...
paddle/operators/onehot_cross_entropy_op.cc
已删除
100644 → 0
浏览文件 @
de8aaf6c
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/onehot_cross_entropy_op.h"
namespace
paddle
{
namespace
operators
{
class
OnehotCrossEntropyOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) of OnehotCrossEntropyOp should not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"label"
),
"Input(label) of OnehotCrossEntropyOp should not be null."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Y"
),
"Output(Y) of OnehotCrossEntropyOp should not be null."
);
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"label"
);
PADDLE_ENFORCE_EQ
(
X
->
dims
().
size
(),
2
,
"X's dimension must be 2."
);
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
1
,
"label's dimension must be 1."
);
PADDLE_ENFORCE_EQ
(
X
->
dims
()[
0
],
label
->
dims
()[
0
]);
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Y"
)
->
Resize
({
X
->
dims
()[
0
],
1
});
}
};
class
OnehotCrossEntropyGradientOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
dX
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
dX
->
Resize
(
X
->
dims
());
}
};
class
OnehotCrossEntropyOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
OnehotCrossEntropyOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The first input of OnehotCrossEntropyOp"
);
AddInput
(
"label"
,
"The second input of OnehotCrossEntropyOp"
);
AddOutput
(
"Y"
,
"The output of OnehotCrossEntropyOp"
);
AddComment
(
R"DOC(
OnehotCrossEntropy Operator.
Y[i] = -log(X[i][j])
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOp
,
ops
::
OnehotCrossEntropyOpMaker
,
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOp
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy
,
ops
::
OnehotCrossEntropyOpKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
onehot_cross_entropy_grad
,
ops
::
OnehotCrossEntropyGradientOpKernel
<
float
>
);
python/paddle/v2/framework/tests/test_cross_entropy_op.py
0 → 100644
浏览文件 @
5b42d2b2
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestCrossEntropyOp1
(
OpTest
):
"""Test standard cross-entropy, with index representation of labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
30
class_num
=
10
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
,
1
),
dtype
=
"int32"
)
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
X
[
i
][
label
[
i
][
0
]])]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float32"
)
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
0
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Y"
)
class
TestCrossEntropyOp2
(
OpTest
):
"""Test soft-label cross-entropy, with vecterized soft labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
10
class_num
=
5
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label
/=
label
.
sum
(
axis
=
1
,
keepdims
=
True
)
cross_entropy
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
1
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
class
TestCrossEntropyOp3
(
OpTest
):
"""Test one-hot cross-entropy, with vecterized one-hot representation of
labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
30
class_num
=
10
X
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
label_index
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
),
dtype
=
"int32"
)
label
=
np
.
zeros
(
X
.
shape
)
label
[
np
.
arange
(
batch_size
),
label_index
]
=
1
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
X
[
i
][
label_index
[
i
]])]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float32"
)
cross_entropy2
=
(
-
label
*
np
.
log
(
X
)).
sum
(
axis
=
1
,
keepdims
=
True
).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
X
,
'Label'
:
label
}
self
.
outputs
=
{
'Y'
:
cross_entropy
}
self
.
attrs
=
{
'soft_label'
:
1
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/framework/tests/test_mnist.py
浏览文件 @
5b42d2b2
...
...
@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
def
cross_entropy_layer
(
net
,
input
,
label
):
cost_name
=
"cross_entropy_%d"
%
uniq_id
()
cross_entropy_op
=
Operator
(
"
onehot_cross_entropy"
,
X
=
input
,
l
abel
=
label
,
Y
=
cost_name
)
"
cross_entropy"
,
X
=
input
,
L
abel
=
label
,
Y
=
cost_name
)
net
.
append_op
(
cross_entropy_op
)
scope
.
new_var
(
cost_name
)
net
.
infer_shape
(
scope
)
...
...
@@ -181,7 +181,7 @@ def error_rate(predict, label):
images
=
data_layer
(
name
=
"pixel"
,
dims
=
[
BATCH_SIZE
,
784
])
labels
=
data_layer
(
name
=
"label"
,
dims
=
[
BATCH_SIZE
])
labels
=
data_layer
(
name
=
"label"
,
dims
=
[
BATCH_SIZE
,
1
])
fc1
=
fc_layer
(
net
=
forward_net
,
input
=
images
,
size
=
100
,
act
=
"sigmoid"
)
fc2
=
fc_layer
(
net
=
forward_net
,
input
=
fc1
,
size
=
100
,
act
=
"sigmoid"
)
predict
=
fc_layer
(
net
=
forward_net
,
input
=
fc2
,
size
=
10
,
act
=
"softmax"
)
...
...
@@ -215,6 +215,7 @@ def test(cost_name):
for
data
in
test_reader
():
image_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
label_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int32"
)
label_data
=
numpy
.
expand_dims
(
label_data
,
axis
=
1
)
feed_data
(
images
,
image_data
)
feed_data
(
labels
,
label_data
)
...
...
@@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM):
for
data
in
train_reader
():
image_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
0
],
data
)).
astype
(
"float32"
)
label_data
=
numpy
.
array
(
map
(
lambda
x
:
x
[
1
],
data
)).
astype
(
"int32"
)
label_data
=
numpy
.
expand_dims
(
label_data
,
axis
=
1
)
feed_data
(
images
,
image_data
)
feed_data
(
labels
,
label_data
)
...
...
python/paddle/v2/framework/tests/test_onehot_cross_entropy_op.py
已删除
100644 → 0
浏览文件 @
de8aaf6c
import
unittest
import
numpy
from
op_test
import
OpTest
class
TestOnehotCrossEntropyOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"onehot_cross_entropy"
batch_size
=
30
class_num
=
10
X
=
numpy
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float32"
)
labels
=
numpy
.
random
.
randint
(
0
,
class_num
,
batch_size
,
dtype
=
"int32"
)
cross_entropy
=
numpy
.
asmatrix
(
[[
-
numpy
.
log
(
X
[
i
][
labels
[
i
]])]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float32"
)
self
.
inputs
=
{
"X"
:
X
,
"label"
:
labels
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Y"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录