未验证 提交 a71d8fdb 编写于 作者: K Kaipeng Deng 提交者: GitHub

Softmax_cross_entropy op add axis (#16806)

* add attr axis infershape. test=develop

* add CUDA kernel. test=develop

* fix unittest. test=develop

* fix unittest for soft_label. test=develop

* fix fp16 unittest. test=develop

* remove comment code. test=develop

* refine test for axis. test=develop

* add python api. test=develop

* fix doc. test=develop

* fix fp16 unittest. test=develop

* fix ngraph test. test=develop

* fix ENFORCE for test_imperative_transformer. test=develop

* fit for ngraph test. test=develop

* fix after rebase develop. test=develop

* fix doc. test=develop

* fix API.spec. test=develop

* fix test_layers. test=develop

* fix format. test=develop
上级 c2e20e2a
......@@ -133,7 +133,7 @@ paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, k
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '8b074f9c56b4233a2b65d03254eb309e'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '960fc799549c202da1e85d626cb2c962'))
paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '67afefa80b6cc38801bd5b631fed8a4a'))
......
......@@ -39,9 +39,10 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
int axis_dim = x->dims()[rank - 1];
math::CrossEntropyFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
ctx.Attr<bool>("soft_label"), ctx.Attr<int>("ignore_index"));
ctx.Attr<bool>("soft_label"), ctx.Attr<int>("ignore_index"), axis_dim);
}
};
......
......@@ -29,8 +29,13 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel,
const int ignore_index) {
const int ignore_index, const int axis_dim) {
const int batch_size = prob->dims()[0];
const int num_classes = prob->dims()[1];
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
if (softLabel) {
auto in = EigenMatrix<T>::From(*prob);
auto lbl = EigenMatrix<T>::From(*labels);
......@@ -38,24 +43,24 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
loss.device(*ctx.eigen_device()) =
-((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
.reshape(batch_axis_remain)
.sum(Eigen::DSizes<int, 1>(1)));
} else {
const int class_num = prob->dims()[1];
const T* prob_data = prob->data<T>();
T* loss_data = out->data<T>();
const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < batch_size; ++i) {
int lbl = label_data[i];
PADDLE_ENFORCE_GE(lbl, 0);
PADDLE_ENFORCE_LT(lbl, class_num);
PADDLE_ENFORCE((lbl >= 0 && lbl < class_num) || lbl == ignore_index);
int index = i * class_num + lbl;
loss_data[i] =
lbl == ignore_index
? 0
: -math::TolerableValue<T>()(std::log(prob_data[index]));
for (int j = 0; j < num_remain; j++) {
int lbl = label_data[i * num_remain + j];
PADDLE_ENFORCE((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index);
int index = i * num_classes + lbl * num_remain + j;
int loss_idx = i * num_remain + j;
loss_data[loss_idx] =
lbl == ignore_index
? 0
: -math::TolerableValue<T>()(std::log(prob_data[index]));
}
}
}
}
......
......@@ -57,8 +57,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& ctx,
framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, bool softLabel,
const int ignore_index) {
const framework::Tensor* labels, const bool softLabel,
const int ignore_index, const int axis_dim) {
const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace());
......
......@@ -60,7 +60,7 @@ class CrossEntropyFunctor {
void operator()(const DeviceContext& context, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel,
const int ignore_index);
const int ignore_index, const int axis_dim);
};
} // namespace math
} // namespace operators
......
......@@ -26,23 +26,28 @@ class SoftmaxWithCrossEntropyOpMaker
public:
void Make() override {
AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, "
"and K is the class number.");
AddInput("Label",
"(Tensor) The ground truth which is a 2-D tensor. If soft_label "
"is set to false, Label is a Tensor<int64> with shape [N x 1]. If "
"soft_label is set to true, Label is a Tensor<float/double> with "
"shape [N x K].");
"(Tensor, default: Tensor<float>), The input tensor of unscaled "
"log probabilities, whose dimension :attr:`axis` should be scaled "
"by softmax.");
AddInput(
"Label",
"(Tensor) The input tesnor of groud truth label. If :attr:`soft_label` "
"is set to false, Label is a Tensor<int64> in same shape with "
"Input(Logits) except the shape in dimension :attr:`axis` as 1. If "
"soft_label is set to true, Label is a Tensor<float/double> in same "
"shape with Input(Logits).");
AddOutput(
"Softmax",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits). "
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.")
.AsIntermediate();
AddOutput("Loss",
"(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
"entropy loss with shape [N x 1].");
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits) "
"except the shape in dimension :attr:`axis` as 1. The cross "
"entropy loss.");
AddAttr<bool>(
"soft_label",
"(bool, default: false), A flag to indicate whether to interpretate "
......@@ -60,6 +65,10 @@ class SoftmaxWithCrossEntropyOpMaker
"does not contribute to the input gradient. Only valid if soft_label"
"is set to False")
.SetDefault(-100);
AddAttr<int>("axis",
"The dimension index of Input(Logits) to perform softmax,"
"default -1 for last dimension")
.SetDefault(-1);
AddComment(R"DOC(
Softmax With Cross Entropy Operator.
......@@ -107,38 +116,53 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
"Output(Softmax) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");
auto axis = ctx->Attrs().Get<int>("axis");
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
auto logits_rank = logits_dims.size();
PADDLE_ENFORCE(axis >= -logits_rank && axis < logits_rank,
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits).");
axis = CanonicalAxis(axis, logits_rank);
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(
logits_dims[i], labels_dims[i],
"Input(Logits) and Input(Label) should in same shape in "
"dimensions except axis.");
}
}
}
int rank = logits_dims.size();
PADDLE_ENFORCE_EQ(
rank, labels_dims.size(),
"Input(logits) and Input(Label) shall have the same rank.");
bool check = ctx->IsRuntime() || (framework::product(logits_dims) > 0 &&
framework::product(labels_dims) > 0);
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
if (axis != logits_rank - 1) {
PADDLE_ENFORCE(
numeric_stable_mode,
"Attr(axis) can only be -1 when not in numeric_stable_mode.");
}
if (ctx->Attrs().Get<bool>("soft_label")) {
if (check) {
PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
bool soft_label = ctx->Attrs().Get<bool>("soft_label");
if (soft_label) {
if (ctx->IsRuntime() ||
(logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
"If Attr(soft_label) == true, the axis dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
if (ctx->IsRuntime() || labels_dims[axis] > 0) {
PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
"If Attr(soft_label) == false, the axis dimension of "
"Input(Label) should be 1.");
}
}
ctx->SetOutputDim("Softmax", logits_dims);
auto loss_dims = logits_dims;
loss_dims[rank - 1] = 1;
ctx->SetOutputDim("Loss", loss_dims);
logits_dims[axis] = 1;
ctx->SetOutputDim("Loss", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
......@@ -165,36 +189,40 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null.");
auto axis = ctx->Attrs().Get<int>("axis");
auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label");
int rank = softmax_dims.size();
PADDLE_ENFORCE_EQ(
rank, labels_dims.size(),
"Input(logits) and Input(Label) shall have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
framework::slice_ddim(softmax_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(Softmax) and Input(Label) shall have the same shape "
"except the last dimension.");
auto softmax_rank = softmax_dims.size();
PADDLE_ENFORCE(axis >= -softmax_rank && axis < softmax_rank,
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits).");
axis = CanonicalAxis(axis, softmax_rank);
for (int i = 0; i < softmax_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(
softmax_dims[i], labels_dims[i],
"Input(Logits) and Input(Label) should in same shape in "
"dimensions except axis.");
}
}
}
if (ctx->Attrs().Get<bool>("soft_label")) {
if (check) {
PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
"Input( Softmax) and Input(Label) should be equal.");
bool soft_label = ctx->Attrs().Get<bool>("soft_label");
if (soft_label) {
if (ctx->IsRuntime() ||
(softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
"If Attr(soft_label) == true, the axis dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
if (ctx->IsRuntime() || labels_dims[axis] > 0) {
PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
"If Attr(soft_label) == false, the axis dimension of "
"Input(Label) should be 1.");
}
}
ctx->SetOutputDim(framework::GradVarName("Logits"),
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"
namespace paddle {
namespace operators {
......@@ -36,26 +37,30 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss");
const bool soft_label = context.Attr<bool>("soft_label");
const int rank = logits->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis];
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
// reshape to 2D tensor
int rank = logits->dims().size();
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
int axis_dim = logits->dims()[rank - 1];
const int n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims());
Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
logits_2d.ShareDataWith(*logits).Resize({n, d});
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d,
context.Attr<bool>("soft_label"), context.Attr<int>("ignore_index"));
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<int>("ignore_index"), axis_dim);
}
};
......@@ -75,34 +80,43 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.device_context(), logit_grad);
}
int rank = logit_grad->dims().size();
const int class_num = logit_grad->dims()[rank - 1];
// reshape to 2d
Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1);
Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1);
const bool soft_label = context.Attr<bool>("soft_label");
const int rank = logit_grad->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logit_grad->dims()[axis];
const int n = SizeToAxis(axis, logit_grad->dims());
const int d = SizeFromAxis(axis, logit_grad->dims());
Tensor logit_grad_2d, labels_2d, out_grad_2d;
logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (context.Attr<bool>("soft_label")) {
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
if (soft_label) {
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat);
} else {
logit_grad_mat.device(place) =
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num));
const int batch_size = logit_grad_2d.dims()[0];
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
for (int i = 0; i < batch_size; ++i) {
logit_grad_data[i * class_num + label_data[i]] -= out_grad_data[i];
const int remain = d / axis_dim;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < remain; j++) {
int idx = i * remain + j;
logit_grad_data[i * d + label_data[idx] * remain + j] -=
out_grad_data[idx];
}
}
}
}
......
......@@ -6099,22 +6099,24 @@ def softmax_with_cross_entropy(logits,
soft_label=False,
ignore_index=kIgnoreIndex,
numeric_stable_mode=True,
return_softmax=False):
return_softmax=False,
axis=-1):
"""
**Softmax With Cross Entropy Operator.**
Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.
operator computes the softmax normalized values for dimension :attr:`axis` of
the input tensor, after which cross-entropy loss is computed. This provides
a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.
When the attribute :attr:`soft_label` is set :attr:`False`, this operators
expects mutually exclusive hard labels, each sample in a batch is in exactly
one class with a probability of 1.0. Each sample in the batch will have a
single label.
The equation is as follows:
......@@ -6133,7 +6135,8 @@ def softmax_with_cross_entropy(logits,
\\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K
3) If numeric_stable_mode is True, softmax is calculated first by:
3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated
first by:
.. math::
......@@ -6146,32 +6149,39 @@ def softmax_with_cross_entropy(logits,
and then cross entropy loss is calculated by softmax and label.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
label (Variable): The ground truth which is a 2-D tensor. If soft_label
is set to false, Label is a Tensor<int64> with shape [N x 1]. If
soft_label is set to true, Label is a Tensor<float/double> with
logits (Variable): The input tensor of unscaled log probabilities.
label (Variable): The ground truth tensor. If :attr:`soft_label`
is set to :attr:`True`, Label is a Tensor<float/double> in the
same shape with :attr:`logits`. If :attr:`soft_label` is set to
:attr:`True`, Label is a Tensor<int64> in the same shape with
:attr:`logits` expect shape in dimension :attr:`axis` as 1.
soft_label (bool): A flag to indicate whether to interpretate the given
labels as soft labels. By default, `soft_label` is set to False.
labels as soft labels. Default False.
ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if soft_label is set to False. Default: kIgnoreIndex
if :attr:`soft_label` is set to :attr:`False`.
Default: kIgnoreIndex
numeric_stable_mode (bool): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when soft_label is False and GPU is used.
When soft_label is True or CPU is used,
the algorithm is always numerically stable.
when :attr:`soft_label` is :attr:`False`
and GPU is used. When :attr:`soft_label`
is :attr:`True` or CPU is used, the
algorithm is always numerically stable.
Note that the speed may be slower when use
stable algorithm. Default: True
return_softmax (bool): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False
axis (int): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns:
Variable or Tuple of two Variables: Return the cross entropy loss if \
`return_softmax` is False, otherwise the tuple \
(loss, softmax), where the cross entropy loss is \
a 2-D tensor with shape [N x 1], and softmax is a \
2-D tensor with shape [N x K].
(loss, softmax), softmax is in the same shape \
with input logits and cross entropy loss is in \
the same shape with input logits except shape \
in dimension :attr:`axis` as 1.
Examples:
.. code-block:: python
......@@ -6194,7 +6204,8 @@ def softmax_with_cross_entropy(logits,
attrs={
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode
'numeric_stable_mode': numeric_stable_mode,
'axis': axis
})
if return_softmax:
......
......@@ -1221,10 +1221,25 @@ class TestBook(LayerTest):
y = self._get_data(name='label', shape=[1], dtype='int64')
loss, softmax = layers.softmax_with_cross_entropy(
x, y, return_softmax=True)
return (loss)
return (softmax)
self.assertIsNotNone(loss)
self.assertIsNotNone(softmax)
loss = layers.softmax_with_cross_entropy(x, y)
return (loss)
self.assertIsNotNone(loss)
x1 = self._get_data(name='x1', shape=[16, 32, 64], dtype='float32')
y1 = self._get_data(name='label1', shape=[1, 32, 64], dtype='int64')
y2 = self._get_data(name='label2', shape=[16, 1, 64], dtype='int64')
y3 = self._get_data(name='label3', shape=[16, 32, 1], dtype='int64')
loss1 = layers.softmax_with_cross_entropy(x1, y1, axis=1)
loss2 = layers.softmax_with_cross_entropy(x1, y2, axis=2)
loss3 = layers.softmax_with_cross_entropy(x1, y3, axis=3)
loss4 = layers.softmax_with_cross_entropy(x1, y3, axis=-1)
self.assertIsNotNone(loss1)
self.assertIsNotNone(loss2)
self.assertIsNotNone(loss3)
self.assertIsNotNone(loss4)
return (loss4)
def make_smooth_l1(self):
with program_guard(fluid.default_main_program(),
......
......@@ -21,37 +21,70 @@ from op_test import OpTest
from test_softmax_op import stable_softmax
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
if soft_label:
return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
axis_dim = shape[axis]
remain = int(np.prod(shape[axis + 1:]))
softmax_reshape = softmax.reshape((n, axis_dim, remain))
label_reshape = label.reshape((n, 1, remain))
result = np.zeros_like(label_reshape, dtype=softmax.dtype)
for i in range(n):
for j in range(remain):
lbl = label_reshape[i, 0, j]
if lbl != ignore_index:
result[i, 0, j] -= np.log(softmax_reshape[i, lbl, j])
return result.reshape(label.shape)
class TestSoftmaxWithCrossEntropyOp(OpTest):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.soft_label = False
self.dtype = np.float64
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
def setUp(self):
self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype(self.dtype)
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64")
logits = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
if self.soft_label:
labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i][0]])]
for i in range(softmax.shape[0])],
dtype=self.dtype)
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": cross_entropy.astype(self.dtype)
"Loss": loss.astype(self.dtype)
}
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
}
if self.ignore_index >= 0:
self.attrs['ignore_index'] = self.ignore_index
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self):
self.check_output()
......@@ -62,30 +95,38 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float16
def setUp(self):
self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype(np.float32)
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64")
logits = np.random.uniform(0.1, 1.0, self.shape).astype(np.float32)
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i][0]])]
for i in range(softmax.shape[0])],
dtype=np.float32)
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
loss = cross_entropy(softmax, labels, self.soft_label, self.axis)
self.inputs = {
"Logits": logits.astype(self.dtype).view(np.uint16),
......@@ -93,9 +134,14 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": cross_entropy.astype(self.dtype)
"Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
}
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self):
self.check_output(atol=1e-2)
......@@ -107,39 +153,31 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16(
TestSoftmaxWithCrossEntropyOpFp16):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float16
def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
class TestSoftmaxWithCrossEntropyOp2(OpTest):
class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with soft labels.
"""
def setUp(self):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float64")
labels /= np.sum(labels, axis=1, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum(
axis=1, keepdims=True).astype("float64")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"soft_label": True}
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float64
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
def test_check_output(self):
self.check_output()
......@@ -148,190 +186,226 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOp3(OpTest):
class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with ignore_index.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.soft_label = False
self.shape = [41, 37]
self.ignore_index = 5
self.axis = -1
self.dtype = np.float64
def setUp(self):
self.initParams()
class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64")
ignore_index = 7
cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i][0]])]
if labels[i] != ignore_index else [0]
for i in range(softmax.shape[0])],
dtype="float64")
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 4
self.axis = -1
self.dtype = np.float64
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {
"ignore_index": ignore_index,
"numeric_stable_mode": self.numeric_stable_mode
}
def test_check_output(self):
self.check_output()
class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def test_check_grad(self):
self.check_grad(["Logits"], "Loss")
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 0
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
class TestSoftmaxWithCrossEntropyOp5(OpTest):
class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with ignore_index.
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self):
self.numeric_stable_mode = False
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
def setUp(self):
self.initParams()
class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = [6, 10]
class_num = 47
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.randint(
0, class_num, tuple(batch_size + [1]), dtype="int64")
ignore_index = 7
softmax_2d = np.reshape(softmax, [-1, class_num])
labels_2d = np.reshape(labels, [-1, 1])
cross_entropy = np.asmatrix(
[[-np.log(softmax_2d[i][labels_2d[i][0]])]
if labels_2d[i] != ignore_index else [0]
for i in range(softmax_2d.shape[0])],
dtype="float64")
cross_entropy = np.reshape(cross_entropy, batch_size)
output_shape = tuple(batch_size + [1])
output_res = cross_entropy.astype("float64")
output_res = np.expand_dims(output_res, axis=2)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": output_res,
}
self.attrs = {
"ignore_index": ignore_index,
"numeric_stable_mode": self.numeric_stable_mode
}
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = 0
self.ignore_index = -1
self.dtype = np.float16
class TestSoftmaxWithCrossEntropyOp5NoCudnn(TestSoftmaxWithCrossEntropyOp5):
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = 1
self.ignore_index = -1
self.dtype = np.float16
class TestSoftmaxWithCrossEntropyOp6(OpTest):
"""
Test softmax with cross entropy operator with soft labels.
"""
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = 2
self.ignore_index = -1
self.dtype = np.float16
def setUp(self):
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = [6, 10]
class_num = 37
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 0
self.ignore_index = -1
self.dtype = np.float64
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
labels /= np.sum(labels, axis=2, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum(
axis=2, keepdims=True).astype("float64")
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 1
self.ignore_index = -1
self.dtype = np.float64
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"soft_label": True}
def test_check_output(self):
self.check_output()
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 2
self.ignore_index = -1
self.dtype = np.float64
def test_check_grad(self):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 3
self.ignore_index = -1
self.dtype = np.float64
class TestSoftmaxWithCrossEntropyOpFp16_2(TestSoftmaxWithCrossEntropyOp):
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.numeric_stable_mode = False
self.dtype = np.float16
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 1
self.axis = 0
self.dtype = np.float64
def setUp(self):
self.initParams()
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = [64, 10]
class_num = 37
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 0
self.axis = 1
self.dtype = np.float64
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype(np.float32)
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.randint(
0, class_num, tuple(batch_size + [1]), dtype="int64")
softmax_2d = np.reshape(softmax, [-1, class_num])
labels_2d = np.reshape(labels, [-1, 1])
cross_entropy = np.asmatrix(
[[-np.log(softmax_2d[i][labels_2d[i][0]])]
for i in range(softmax_2d.shape[0])],
dtype=np.float32)
cross_entropy = np.reshape(cross_entropy, batch_size)
output_shape = tuple(batch_size + [1])
output_res = cross_entropy.astype(self.dtype)
output_res = np.expand_dims(output_res, axis=2)
self.inputs = {"Logits": logits, "Label": labels}
self.inputs = {
"Logits": logits.astype(self.dtype).view(np.uint16),
"Label": labels
}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": output_res,
}
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 3
self.axis = 2
self.dtype = np.float64
def test_check_output(self):
self.check_output(atol=1e-2)
def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 3
self.axis = 3
self.dtype = np.float64
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册