PaddlePaddle / Paddle

Commit eaa3fd45 (unverified)
Authored Feb 09, 2022 by sneaxiy; committed via GitHub on Feb 09, 2022
add more int type support for softmax_with_cross_entropy (#39409)
Parent: 8d87b3bc

Showing 6 changed files with 307 additions and 154 deletions (+307 -154)
Changed files:

paddle/fluid/operators/math/cross_entropy.cc                                +67 -36
paddle/fluid/operators/math/cross_entropy.cu                                +48 -12
paddle/fluid/operators/softmax_with_cross_entropy_op.cu                     +99 -85
paddle/fluid/operators/softmax_with_cross_entropy_op.h                      +64 -19
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py   +27 -1
python/paddle/nn/functional/loss.py                                          +2 -1

Each hunk below is shown as the resulting (post-change) code.
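In user terms, the change lets the hard-label path of softmax_with_cross_entropy accept integer label tensors other than int64 (the new tests below exercise int32, int16, int8 and uint8). A minimal usage sketch, assuming the dygraph API of this branch; shapes and values are only illustrative:

```python
import paddle
import paddle.nn.functional as F

logits = paddle.rand([4, 10], dtype="float32")
# Before this change hard labels effectively had to be int64; int32 (and the
# other integer dtypes added here) should now be dispatched by the kernel.
labels = paddle.randint(0, 10, shape=[4, 1], dtype="int32")

loss = F.softmax_with_cross_entropy(logits, labels, soft_label=False, axis=-1)
print(loss.shape)  # [4, 1]
```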
paddle/fluid/operators/math/cross_entropy.cc

@@ -30,59 +30,90 @@ template <typename T, int MajorType = Eigen::RowMajor,

```cpp
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T>
struct HardLabelCrossEntropyCPUFunctorImpl {
 public:
  HardLabelCrossEntropyCPUFunctorImpl(framework::Tensor* out,
                                      const framework::Tensor* prob,
                                      const framework::Tensor* labels,
                                      const int ignore_index,
                                      const int axis_dim)
      : out_(out),
        prob_(prob),
        labels_(labels),
        ignore_index_(ignore_index),
        axis_dim_(axis_dim) {}

  template <typename U>
  void apply() const {
    const int batch_size = prob_->dims()[0];
    const int num_classes = prob_->dims()[1];
    const int num_remain = num_classes / axis_dim_;

    const T* prob_data = prob_->template data<T>();
    T* loss_data = out_->template data<T>();

    const auto* label_data = labels_->template data<U>();
    for (int i = 0; i < batch_size; ++i) {
      for (int j = 0; j < num_remain; j++) {
        int lbl = static_cast<int>(label_data[i * num_remain + j]);
        if (lbl != ignore_index_) {
          PADDLE_ENFORCE_GE(lbl, 0,
                            platform::errors::OutOfRange(
                                "label value should >= 0 when label "
                                "value(%f) not equal to ignore_index(%f)",
                                lbl, ignore_index_));
          PADDLE_ENFORCE_LT(
              lbl, axis_dim_,
              platform::errors::OutOfRange(
                  "label value should less than the shape of axis dimension "
                  "when label value(%f) not equal to ignore_index(%f), But "
                  "received label value as %ld and shape of axis dimension "
                  "is %d",
                  lbl, ignore_index_, lbl, axis_dim_));
        }
        int index = i * num_classes + lbl * num_remain + j;
        int loss_idx = i * num_remain + j;
        loss_data[loss_idx] =
            lbl == ignore_index_
                ? 0
                : -math::TolerableValue<T>()(std::log(prob_data[index]));
      }
    }
  }

 private:
  framework::Tensor* out_;
  const framework::Tensor* prob_;
  const framework::Tensor* labels_;
  const int ignore_index_;
  const int axis_dim_;
};

template <typename T>
class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& ctx,
                  framework::Tensor* out, const framework::Tensor* prob,
                  const framework::Tensor* labels, const bool softLabel,
                  const int ignore_index, const int axis_dim) {
    if (softLabel) {
      const int batch_size = prob->dims()[0];
      const int num_classes = prob->dims()[1];
      const int num_remain = num_classes / axis_dim;
      Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim,
                                              num_remain);
      auto in = EigenMatrix<T>::From(*prob);
      auto lbl = EigenMatrix<T>::From(*labels);
      auto loss = EigenMatrix<T>::From(*out);

      loss.device(*ctx.eigen_device()) =
          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
                .reshape(batch_axis_remain)
                .sum(Eigen::DSizes<int, 1>(1)));
    } else {
      HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(
          out, prob, labels, ignore_index, axis_dim);
      framework::VisitIntDataType(labels->type(), functor_impl);
    }
  }
};
```
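The hard-label branch above computes, per sample, loss = -log(prob[label]), zeroing entries whose label equals ignore_index (the num_remain loop generalizes this to an arbitrary axis). A NumPy sketch of the plain 2-D case; the function name and shapes are mine and purely illustrative:

```python
import numpy as np

def hard_label_cross_entropy(prob, labels, ignore_index=-100):
    """prob: (batch, num_classes) softmax output; labels: (batch,) integer class ids."""
    batch_size = prob.shape[0]
    loss = np.zeros((batch_size, 1), dtype=prob.dtype)
    for i in range(batch_size):
        lbl = int(labels[i])
        # entries marked with ignore_index contribute zero loss
        loss[i, 0] = 0.0 if lbl == ignore_index else -np.log(prob[i, lbl])
    return loss

# Labels may use any integer dtype (int8/int16/int32/int64/uint8) after this change.
prob = np.full((2, 4), 0.25, dtype=np.float32)
labels = np.array([1, 3], dtype=np.int16)
print(hard_label_cross_entropy(prob, labels))  # each entry is -log(0.25) ≈ 1.386
```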
paddle/fluid/operators/math/cross_entropy.cu

@@ -21,18 +21,19 @@ namespace paddle {

```cpp
namespace operators {
namespace math {

template <typename T, typename LabelT>
__global__ void CrossEntropyKernel(T* Y, const T* X, const LabelT* label,
                                   const int N, const int D,
                                   const int ignore_index) {
  CUDA_KERNEL_LOOP(i, N) {
    auto lbl = static_cast<int64_t>(label[i]);
    PADDLE_ENFORCE(lbl >= 0 && lbl < D || lbl == ignore_index,
                   "The value of label[%d] expected >= 0 and < %ld, or == %ld, "
                   "but got %ld. Please check input value.",
                   i, D, ignore_index, lbl);
    Y[i] = ignore_index == lbl
               ? static_cast<T>(0)
               : -math::TolerableValue<T>()(real_log(X[i * D + lbl]));
  }
}
```

@@ -54,6 +55,43 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,

```cpp
  }
}

template <typename T>
struct HardLabelCrossEntropyCUDAFunctorImpl {
 public:
  HardLabelCrossEntropyCUDAFunctorImpl(T* loss_data, const T* prob_data,
                                       const void* label_data,
                                       const int batch_size,
                                       const int class_num,
                                       const int ignore_index,
                                       const int block_size,
                                       gpuStream_t stream)
      : loss_data_(loss_data),
        prob_data_(prob_data),
        label_data_(label_data),
        batch_size_(batch_size),
        class_num_(class_num),
        ignore_index_(ignore_index),
        block_size_(block_size),
        stream_(stream) {}

  template <typename U>
  void apply() const {
    int grid_size = (batch_size_ + block_size_ - 1) / block_size_;
    CrossEntropyKernel<T, U><<<grid_size, block_size_, 0, stream_>>>(
        loss_data_, prob_data_, static_cast<const U*>(label_data_),
        batch_size_, class_num_, ignore_index_);
  }

 private:
  T* loss_data_;
  const T* prob_data_;
  const void* label_data_;
  const int batch_size_;
  const int class_num_;
  const int ignore_index_;
  const int block_size_;
  gpuStream_t stream_;
};

template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
 public:
```

@@ -81,12 +119,10 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {

```cpp
      SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
          loss_data, prob_data, label_data, class_num);
    } else {
      HardLabelCrossEntropyCUDAFunctorImpl<T> functor(
          loss_data, prob_data, labels->data(), batch_size, class_num,
          ignore_index, kMaxBlockDim, ctx.stream());
      framework::VisitDataType(labels->type(), functor);
    }
  }
};
```
paddle/fluid/operators/softmax_with_cross_entropy_op.cu

(This diff is collapsed in the source view and is not reproduced here: +99 -85.)
paddle/fluid/operators/softmax_with_cross_entropy_op.h

@@ -24,6 +24,48 @@ namespace operators {

```cpp
using Tensor = framework::Tensor;

template <typename T, typename Visitor>
struct SoftmaxWithCrossEntropyFunctor {
 public:
  SoftmaxWithCrossEntropyFunctor(const framework::ExecutionContext& context,
                                 const framework::Tensor& labels,
                                 const bool soft_label, const Visitor& visitor)
      : context_(context),
        labels_(labels),
        soft_label_(soft_label),
        visitor_(visitor) {}

  template <typename U>
  void apply() const {
    visitor_.template Apply<U>(context_, labels_, soft_label_);
  }

 private:
  const framework::ExecutionContext& context_;
  const framework::Tensor& labels_;
  const bool soft_label_;
  const Visitor& visitor_;
};

template <typename T, typename Visitor>
static void RunSoftmaxWithCrossEntropyFunctor(
    const framework::ExecutionContext& context, const Visitor& visitor) {
  const auto* labels = context.Input<framework::Tensor>("Label");
  const bool soft_label = context.Attr<bool>("soft_label");
  SoftmaxWithCrossEntropyFunctor<T, Visitor> functor(context, *labels,
                                                     soft_label, visitor);
  auto dtype = labels->type();
  if (soft_label) {
    PADDLE_ENFORCE_EQ(
        dtype, framework::DataTypeTrait<T>::DataType(),
        platform::errors::InvalidArgument(
            "The Input(Label) should be with the "
            "same data type as Input(Logits)."));
    functor.template apply<T>();
  } else {
    framework::VisitIntDataType(dtype, functor);
  }
}

template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
 public:
```

@@ -32,14 +74,14 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {

```cpp
        platform::is_cpu_place(context.GetPlace()), true,
        platform::errors::Unimplemented("This kernel only runs on CPU."));
    const bool use_softmax = context.Attr<bool>("use_softmax");
    const Tensor* labels = context.Input<Tensor>("Label");
    const bool soft_label = context.Attr<bool>("soft_label");

    // do not with softmax op, and input is softmax
    if (!use_softmax) {
      const Tensor* softmax = context.Input<Tensor>("Logits");
      Tensor* softmax_out = context.Output<Tensor>("Softmax");
      Tensor* loss = context.Output<Tensor>("Loss");
      const int rank = softmax->dims().size();
      const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
      int axis_dim = softmax->dims()[axis];
```

@@ -86,10 +128,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {

```cpp
    }

    const Tensor* logits = context.Input<Tensor>("Logits");
    Tensor* softmax = context.Output<Tensor>("Softmax");
    Tensor* loss = context.Output<Tensor>("Loss");

    const int rank = logits->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
```

@@ -132,9 +172,14 @@ template <typename T>

```cpp
template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
  }

  template <typename LabelT>
  static void Apply(const framework::ExecutionContext& context,
                    const framework::Tensor& labels, const bool soft_label) {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Loss"));
    Tensor* logit_grad =
        context.Output<Tensor>(framework::GradVarName("Logits"));
    const Tensor* softmax = context.Input<Tensor>("Softmax");
```

@@ -143,7 +188,6 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {

```cpp
      framework::TensorCopy(*softmax, context.GetPlace(),
                            context.device_context(), logit_grad);
    }
    auto ignore_index = context.Attr<int>("ignore_index");
    const int rank = logit_grad->dims().size();
```

@@ -166,7 +210,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {

```cpp
    const int d = SizeFromAxis(axis, logit_grad->dims());
    Tensor logit_grad_2d, labels_2d, out_grad_2d;
    logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
    labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
    out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
    auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
    auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
```

@@ -183,23 +227,24 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {

```cpp
          logit_grad_mat;
    } else {
      // use_softmax step2
      const auto* label_data = labels.template data<LabelT>();
      T* logit_grad_data = logit_grad->template data<T>();
      const T* out_grad_data = out_grad->template data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // for each sample_1_dim
        for (int j = 0; j < remain; j++) {  // for each sample_other_dims
          int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                     // remain=1 and j=0, so, idx = i
          auto lbl = static_cast<int64_t>(label_data[idx]);
          if (lbl == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // only for this sample's label_idx, the label is 1, others is 0,
            // so, only compute this label_idx's class
            logit_grad_data[i * d + lbl * remain + j] =
                (-1 / logit_grad_data[i * d + lbl * remain + j]) *
                out_grad_data[idx];
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              if (k !=
```

@@ -233,15 +278,16 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {

```cpp
          logit_grad_mat *  // element_wise multiply
          out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

      const auto* label_data = labels.template data<LabelT>();
      T* logit_grad_data = logit_grad->template data<T>();
      const T* out_grad_data = out_grad->template data<T>();
      const int remain = d / axis_dim;
      for (int i = 0; i < n; ++i) {         // for each sample_1_dim
        for (int j = 0; j < remain; j++) {  // for each sample_other_dims
          int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                     // remain=1 and j=0, so, idx = i
          auto lbl = static_cast<int64_t>(label_data[idx]);
          if (lbl == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          }
```

@@ -258,8 +304,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {

```cpp
            // out_grad_data[idx]
            // means: dy/dp * dy= ( p - y ) * dy
            logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
          }
        }
      }
```
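The functor/visitor pair above exists so that the kernel's Compute method can stay label-dtype-agnostic: RunSoftmaxWithCrossEntropyFunctor inspects the runtime label dtype and invokes Apply<LabelT> with the matching integer type. A rough Python analogue of that dispatch; the names and the dtype table are illustrative only, not Paddle API:

```python
import numpy as np

# Illustrative stand-in for framework::VisitIntDataType: pick the concrete
# integer type from the runtime dtype, then invoke the dtype-templated visitor.
_INT_DTYPES = {"uint8": np.uint8, "int8": np.int8, "int16": np.int16,
               "int32": np.int32, "int64": np.int64}

def run_with_label_dtype(labels, visitor):
    np_type = _INT_DTYPES.get(str(labels.dtype))
    if np_type is None:
        raise TypeError("hard labels must use an integer dtype, got %s" % labels.dtype)
    # visitor.apply plays the role of SoftmaxWithCrossEntropyKernel::Apply<LabelT>
    return visitor.apply(labels.astype(np_type))

class LossVisitor:
    def __init__(self, prob):
        self.prob = prob
    def apply(self, labels):
        # hard-label cross entropy for the chosen integer label type
        return -np.log(self.prob[np.arange(len(labels)), labels])

prob = np.full((3, 5), 0.2, dtype=np.float32)
print(run_with_label_dtype(np.array([0, 2, 4], dtype=np.int16), LossVisitor(prob)))
```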
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py

@@ -16,6 +16,7 @@ from __future__ import print_function

```python
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest
```

@@ -58,6 +59,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):

```python
        self.shape = [41, 37]
        self.use_softmax = True

    def hard_label_dtype(self):
        return "int64"

    def setUp(self):
        self.initParams()
```

@@ -72,7 +76,8 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):

```python
        else:
            axis_dim = self.shape[self.axis]
            self.shape[self.axis] = 1
        labels = np.random.randint(
            0, axis_dim, self.shape, dtype=self.hard_label_dtype())

        loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
                             self.ignore_index)
```

@@ -107,6 +112,26 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):

```python
        self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001)


class TestSoftmaxWithCrossEntropyOpInt32(TestSoftmaxWithCrossEntropyOp):
    def hard_label_dtype(self):
        return "int32"


class TestSoftmaxWithCrossEntropyOpInt16(TestSoftmaxWithCrossEntropyOp):
    def hard_label_dtype(self):
        return "int16"


class TestSoftmaxWithCrossEntropyOpInt8(TestSoftmaxWithCrossEntropyOp):
    def hard_label_dtype(self):
        return "int8"


class TestSoftmaxWithCrossEntropyOpUInt8(TestSoftmaxWithCrossEntropyOp):
    def hard_label_dtype(self):
        return "uint8"


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
        TestSoftmaxWithCrossEntropyOp):
    def initParams(self):
```

@@ -711,4 +736,5 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):

```python
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
```
python/paddle/nn/functional/loss.py

@@ -1783,7 +1783,8 @@ def cross_entropy(input,

```python
        fluid.data_feeder.check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'softmax_cross_entropy')
        fluid.data_feeder.check_variable_and_dtype(
            label, 'label',
            ['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
            'softmax_cross_entropy')
        attrs = {
            'soft_label': soft_label,
```
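With the widened check above, a hedged sketch of calling the public API with a non-int64 hard-label tensor (assuming the dygraph path dispatches to the updated kernel; the dtype list is taken from the change above):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

logits = paddle.to_tensor(np.random.rand(8, 5).astype("float32"))
# 'uint8' is one of the newly accepted hard-label dtypes
# ('uint8', 'int8', 'int16', 'int32', 'int64').
labels = paddle.to_tensor(np.random.randint(0, 5, size=(8,), dtype="uint8"))

loss = F.cross_entropy(logits, labels)  # default reduction='mean'
print(float(loss))
```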