Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
faf8ad24
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
faf8ad24
编写于
9月 11, 2018
作者:
B
Bai Yifan
提交者:
GitHub
9月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ignore_index in cross_entropy op (#13217)
* add ignore index * update api.spec * enhance softmax_with_cross_entropy
上级
94b66bdb
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
148 addition
and
29 deletion
+148
-29
paddle/fluid/API.spec
paddle/fluid/API.spec
+2
-2
paddle/fluid/operators/cross_entropy_op.cc
paddle/fluid/operators/cross_entropy_op.cc
+5
-0
paddle/fluid/operators/cross_entropy_op.h
paddle/fluid/operators/cross_entropy_op.h
+17
-9
paddle/fluid/operators/math/cross_entropy.cc
paddle/fluid/operators/math/cross_entropy.cc
+7
-2
paddle/fluid/operators/math/cross_entropy.cu
paddle/fluid/operators/math/cross_entropy.cu
+10
-5
paddle/fluid/operators/math/cross_entropy.h
paddle/fluid/operators/math/cross_entropy.h
+2
-1
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+6
-0
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+7
-4
paddle/fluid/operators/softmax_with_cross_entropy_op.h
paddle/fluid/operators/softmax_with_cross_entropy_op.h
+2
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+17
-5
python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
+29
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+9
-0
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
...uid/tests/unittests/test_softmax_with_cross_entropy_op.py
+35
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
faf8ad24
...
...
@@ -100,7 +100,7 @@ paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_att
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label'
], varargs=None, keywords=None, defaults=(False,
))
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label'
, 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100
))
paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None))
...
...
@@ -142,7 +142,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label'
], varargs=None, keywords=None, defaults=(False,
))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label'
, 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100
))
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
...
...
paddle/fluid/operators/cross_entropy_op.cc
浏览文件 @
faf8ad24
...
...
@@ -138,6 +138,11 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"ignore_index"
,
"(int, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient."
"Only valid if soft_label is set to False"
)
.
SetDefault
(
-
100
);
AddComment
(
R"DOC(
CrossEntropy Operator.
...
...
paddle/fluid/operators/cross_entropy_op.h
浏览文件 @
faf8ad24
...
...
@@ -40,7 +40,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
math
::
CrossEntropyFunctor
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
&
y_2d
,
&
x_2d
,
&
labels_2d
,
ctx
.
Attr
<
bool
>
(
"soft_label"
));
ctx
.
Attr
<
bool
>
(
"soft_label"
)
,
ctx
.
Attr
<
int
>
(
"ignore_index"
)
);
}
};
...
...
@@ -74,14 +74,20 @@ class XeGradFunctor {
const
T
*
dy
,
// NOLINT
const
T
*
x
,
// NOLINT
const
int64_t
*
label
,
// NOLINT
size_t
num_classes
)
:
dx_
(
dx
),
dy_
(
dy
),
x_
(
x
),
label_
(
label
),
num_classes_
(
num_classes
)
{}
size_t
num_classes
,
size_t
ignore_index
)
:
dx_
(
dx
),
dy_
(
dy
),
x_
(
x
),
label_
(
label
),
num_classes_
(
num_classes
),
ignore_index_
(
ignore_index
)
{}
HOSTDEVICE
void
operator
()(
size_t
sample_id
)
{
auto
x_is_true_offset
=
sample_id
*
num_classes_
+
label_
[
sample_id
];
for
(
size_t
x_offset
=
sample_id
*
num_classes_
;
x_offset
<
(
sample_id
+
1
)
*
num_classes_
;
++
x_offset
)
{
dx_
[
x_offset
]
=
x_offset
!=
x_is_true_offset
dx_
[
x_offset
]
=
(
x_offset
!=
x_is_true_offset
||
label_
[
sample_id
]
==
ignore_index_
)
?
static_cast
<
T
>
(
0
)
:
-
dy_
[
sample_id
]
/
x_
[
x_offset
];
}
...
...
@@ -93,6 +99,7 @@ class XeGradFunctor {
const
T
*
x_
;
const
int64_t
*
label_
;
size_t
num_classes_
;
size_t
ignore_index_
;
};
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -109,6 +116,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
// unnecessary to convert tensors to 2-D views.
int
rank
=
x
->
dims
().
size
();
int64_t
class_num
=
x
->
dims
()[
rank
-
1
];
int64_t
ignore_index
=
ctx
.
Attr
<
int
>
(
"ignore_index"
);
if
(
ctx
.
Attr
<
bool
>
(
"soft_label"
))
{
XeSoftlabelGradFunctor
<
T
>
functor
(
dx_data
,
dy
->
data
<
T
>
(),
x
->
data
<
T
>
(),
label
->
data
<
T
>
(),
...
...
@@ -118,9 +126,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
static_cast
<
size_t
>
(
dx
->
numel
()));
for_range
(
functor
);
}
else
{
XeGradFunctor
<
T
>
functor
(
dx_data
,
dy
->
data
<
T
>
(),
x
->
data
<
T
>
(),
label
->
data
<
int64_t
>
(),
static_cast
<
size_t
>
(
class_num
));
XeGradFunctor
<
T
>
functor
(
dx_data
,
dy
->
data
<
T
>
(),
x
->
data
<
T
>
(),
label
->
data
<
int64_t
>
(),
static_cast
<
size_t
>
(
class_num
),
static_cast
<
size_t
>
(
ignore_index
));
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
static_cast
<
size_t
>
(
dy
->
numel
()));
...
...
paddle/fluid/operators/math/cross_entropy.cc
浏览文件 @
faf8ad24
...
...
@@ -28,7 +28,8 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
const
framework
::
Tensor
*
labels
,
const
bool
softLabel
)
{
const
framework
::
Tensor
*
labels
,
const
bool
softLabel
,
const
int
ignore_index
)
{
const
int
batch_size
=
prob
->
dims
()[
0
];
if
(
softLabel
)
{
auto
in
=
EigenMatrix
<
T
>::
From
(
*
prob
);
...
...
@@ -49,8 +50,12 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
int
lbl
=
label_data
[
i
];
PADDLE_ENFORCE_GE
(
lbl
,
0
);
PADDLE_ENFORCE_LT
(
lbl
,
class_num
);
PADDLE_ENFORCE
((
lbl
>=
0
&&
lbl
<
class_num
)
||
lbl
==
ignore_index
);
int
index
=
i
*
class_num
+
lbl
;
loss_data
[
i
]
=
-
math
::
TolerableValue
<
T
>
()(
std
::
log
(
prob_data
[
index
]));
loss_data
[
i
]
=
lbl
==
ignore_index
?
0
:
-
math
::
TolerableValue
<
T
>
()(
std
::
log
(
prob_data
[
index
]));
}
}
}
...
...
paddle/fluid/operators/math/cross_entropy.cu
浏览文件 @
faf8ad24
...
...
@@ -23,11 +23,14 @@ namespace math {
namespace
{
template
<
typename
T
>
__global__
void
CrossEntropyKernel
(
T
*
Y
,
const
T
*
X
,
const
int64_t
*
label
,
const
int
N
,
const
int
D
)
{
const
int
N
,
const
int
D
,
const
int
ignore_index
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
N
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
);
Y
[
i
]
=
-
math
::
TolerableValue
<
T
>
()(
log
(
X
[
i
*
D
+
label
[
i
]]));
PADDLE_ASSERT
(
label
[
i
]
>=
0
&&
label
[
i
]
<
D
||
label
[
i
]
==
ignore_index
);
Y
[
i
]
=
ignore_index
==
label
[
i
]
?
0
:
-
math
::
TolerableValue
<
T
>
()(
log
(
X
[
i
*
D
+
label
[
i
]]));
}
}
...
...
@@ -57,7 +60,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
const
framework
::
Tensor
*
labels
,
bool
softLabel
)
{
const
framework
::
Tensor
*
labels
,
bool
softLabel
,
const
int
ignore_index
)
{
const
T
*
prob_data
=
prob
->
data
<
T
>
();
T
*
loss_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -77,7 +81,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
int
block
=
512
;
int
grid
=
(
batch_size
+
block
-
1
)
/
block
;
CrossEntropyKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
loss_data
,
prob_data
,
label_data
,
batch_size
,
class_num
);
loss_data
,
prob_data
,
label_data
,
batch_size
,
class_num
,
ignore_index
);
}
}
};
...
...
paddle/fluid/operators/math/cross_entropy.h
浏览文件 @
faf8ad24
...
...
@@ -38,7 +38,8 @@ class CrossEntropyFunctor {
public:
void
operator
()(
const
DeviceContext
&
context
,
framework
::
Tensor
*
out
,
const
framework
::
Tensor
*
prob
,
const
framework
::
Tensor
*
labels
,
const
bool
softLabel
);
const
framework
::
Tensor
*
labels
,
const
bool
softLabel
,
const
int
ignore_index
);
};
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
faf8ad24
...
...
@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels."
)
.
SetDefault
(
false
);
AddAttr
<
int
>
(
"ignore_index"
,
"(int, default -100), Specifies a target value that is ignored and"
"does not contribute to the input gradient. Only valid if soft_label"
"is set to False"
)
.
SetDefault
(
-
100
);
AddComment
(
R"DOC(
Softmax With Cross Entropy Operator.
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cu
浏览文件 @
faf8ad24
...
...
@@ -26,7 +26,8 @@ using Tensor = framework::Tensor;
namespace
{
template
<
typename
T
>
__global__
void
CrossEntropyGrad
(
T
*
logit_grad
,
const
int64_t
*
labels
,
const
int
batch_size
,
const
int
class_num
)
{
const
int
batch_size
,
const
int
class_num
,
const
int
ignore_index
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
batch_size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
idx
=
i
*
class_num
+
labels
[
i
];
...
...
@@ -260,6 +261,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto
*
loss_data
=
loss
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
soft_label
=
context
.
Attr
<
bool
>
(
"soft_label"
);
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
if
(
soft_label
)
{
int
batch_size
=
logits
->
dims
()[
0
];
int
feature_size
=
logits
->
dims
()[
1
];
...
...
@@ -272,7 +274,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
math
::
SoftmaxCUDNNFunctor
<
T
>
()(
context
.
cuda_device_context
(),
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
);
context
.
cuda_device_context
(),
loss
,
softmax
,
labels
,
false
,
ignore_index
);
}
}
};
...
...
@@ -295,7 +298,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
const
int
class_num
=
logit_grad
->
dims
()[
1
];
int
block
=
512
;
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
if
(
context
.
Attr
<
bool
>
(
"soft_label"
))
{
int
grid
=
(
batch_size
*
class_num
+
block
-
1
)
/
block
;
const
T
*
label_data
=
labels
->
data
<
T
>
();
...
...
@@ -305,7 +308,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
int
grid
=
(
batch_size
+
block
-
1
)
/
block
;
const
int64_t
*
label_data
=
labels
->
data
<
int64_t
>
();
CrossEntropyGrad
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
label_data
,
batch_size
,
class_num
);
logit_grad_data
,
label_data
,
batch_size
,
class_num
,
ignore_index
);
int
num
=
batch_size
*
class_num
;
grid
=
(
num
+
block
-
1
)
/
block
;
Scale
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
logit_grad_data
,
loss_grad_data
,
num
,
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.h
浏览文件 @
faf8ad24
...
...
@@ -45,7 +45,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
math
::
SoftmaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
()(
dev_ctx
,
logits
,
softmax
);
math
::
CrossEntropyFunctor
<
platform
::
CPUDeviceContext
,
T
>
()(
dev_ctx
,
loss
,
softmax
,
labels
,
context
.
Attr
<
bool
>
(
"soft_label"
));
dev_ctx
,
loss
,
softmax
,
labels
,
context
.
Attr
<
bool
>
(
"soft_label"
),
context
.
Attr
<
int
>
(
"ignore_index"
));
}
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
faf8ad24
...
...
@@ -968,7 +968,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
return
out
def
cross_entropy
(
input
,
label
,
soft_label
=
False
):
def
cross_entropy
(
input
,
label
,
soft_label
=
False
,
ignore_index
=-
100
):
"""
**Cross Entropy Layer**
...
...
@@ -1012,7 +1012,10 @@ def cross_entropy(input, label, soft_label=False):
tensor<float/double> with shape [N x D].
soft_label (bool): a flag indicating whether to
interpretate the given labels as soft
labels, default `False`.
labels. Default: `False`.
ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if soft_label is set to False. Default: -100
Returns:
A 2-D tensor with shape [N x 1], the cross entropy loss.
...
...
@@ -1037,7 +1040,8 @@ def cross_entropy(input, label, soft_label=False):
inputs
=
{
'X'
:
[
input
],
'Label'
:
[
label
]},
outputs
=
{
'Y'
:
[
out
]},
attrs
=
{
"soft_label"
:
soft_label
})
attrs
=
{
"soft_label"
:
soft_label
,
"ignore_index"
:
ignore_index
})
return
out
...
...
@@ -4242,7 +4246,10 @@ def multiplex(inputs, index):
return
out
def
softmax_with_cross_entropy
(
logits
,
label
,
soft_label
=
False
):
def
softmax_with_cross_entropy
(
logits
,
label
,
soft_label
=
False
,
ignore_index
=-
100
):
"""
**Softmax With Cross Entropy Operator.**
...
...
@@ -4284,6 +4291,10 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
soft_label is set to true, Label is a Tensor<float/double> with
soft_label (bool): A flag to indicate whether to interpretate the given
labels as soft labels. By default, `soft_label` is set to False.
ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if soft_label is set to False. Default: -100
Returns:
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
...
...
@@ -4305,7 +4316,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
'Label'
:
label
},
outputs
=
{
'Softmax'
:
softmax
,
'Loss'
:
loss
},
attrs
=
{
'soft_label'
:
soft_label
})
attrs
=
{
'soft_label'
:
soft_label
,
'ignore_index'
:
ignore_index
})
return
loss
...
...
python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
浏览文件 @
faf8ad24
...
...
@@ -209,5 +209,34 @@ class TestCrossEntropyOp6(OpTest):
[
"X"
],
"Y"
,
max_relative_error
=
0.05
,
numeric_grad_delta
=
0.001
)
class
TestCrossEntropyOp7
(
OpTest
):
"""Test cross-entropy with ignore index.
"""
def
setUp
(
self
):
self
.
op_type
=
"cross_entropy"
batch_size
=
30
class_num
=
10
ignore_index
=
3
X
=
randomize_probability
(
batch_size
,
class_num
,
dtype
=
'float64'
)
label
=
np
.
random
.
randint
(
0
,
class_num
,
(
batch_size
,
1
),
dtype
=
"int64"
)
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
X
[
i
][
label
[
i
][
0
]])]
if
label
[
i
][
0
]
!=
ignore_index
else
[
0
]
for
i
in
range
(
X
.
shape
[
0
])],
dtype
=
"float64"
)
self
.
inputs
=
{
"X"
:
X
,
"Label"
:
label
}
self
.
outputs
=
{
"Y"
:
cross_entropy
}
self
.
attrs
=
{
"soft_label"
:
False
,
"ignore_index"
:
ignore_index
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Y"
,
numeric_grad_delta
=
0.001
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
faf8ad24
...
...
@@ -556,6 +556,15 @@ class TestBook(unittest.TestCase):
out
=
layers
.
sequence_enumerate
(
input
=
x
,
win_size
=
2
,
pad_value
=
0
)
print
(
str
(
program
))
def
test_cross_entropy
(
self
):
program
=
Program
()
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
"x"
,
shape
=
[
30
,
10
],
dtype
=
"float32"
)
label
=
layers
.
data
(
name
=
"label"
,
shape
=
[
30
,
1
],
dtype
=
"int32"
)
mode
=
'channel'
out
=
layers
.
cross_entropy
(
x
,
label
,
False
,
4
)
self
.
assertIsNotNone
(
out
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
浏览文件 @
faf8ad24
...
...
@@ -88,5 +88,40 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
self
.
check_grad
([
"Logits"
],
"Loss"
)
class
TestSoftmaxWithCrossEntropyOp3
(
OpTest
):
"""
Test softmax with cross entropy operator with ignore_index.
"""
def
setUp
(
self
):
self
.
op_type
=
"softmax_with_cross_entropy"
batch_size
=
41
class_num
=
37
logits
=
np
.
random
.
uniform
(
0.1
,
1.0
,
[
batch_size
,
class_num
]).
astype
(
"float64"
)
softmax
=
np
.
apply_along_axis
(
stable_softmax
,
1
,
logits
)
labels
=
np
.
random
.
randint
(
0
,
class_num
,
[
batch_size
,
1
],
dtype
=
"int64"
)
ignore_index
=
7
cross_entropy
=
np
.
asmatrix
(
[[
-
np
.
log
(
softmax
[
i
][
labels
[
i
][
0
]])]
if
labels
[
i
]
!=
ignore_index
else
[
0
]
for
i
in
range
(
softmax
.
shape
[
0
])],
dtype
=
"float64"
)
self
.
inputs
=
{
"Logits"
:
logits
,
"Label"
:
labels
}
self
.
outputs
=
{
"Softmax"
:
softmax
.
astype
(
"float64"
),
"Loss"
:
cross_entropy
.
astype
(
"float64"
)
}
self
.
attrs
=
{
"ignore_index"
:
ignore_index
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"Logits"
],
"Loss"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录