未验证 提交 a71d8fdb 编写于 作者: K Kaipeng Deng 提交者: GitHub

Softmax_cross_entropy op add axis (#16806)

* add attr axis infershape. test=develop

* add CUDA kernel. test=develop

* fix unittest. test=develop

* fix unittest for soft_label. test=develop

* fix fp16 unittest. test=develop

* remove comment code. test=develop

* refine test for axis. test=develop

* add python api. test=develop

* fix doc. test=develop

* fix fp16 unittest. test=develop

* fix ngraph test. test=develop

* fix ENFORCE for test_imperative_transformer. test=develop

* fit for ngraph test. test=develop

* fix after rebase develop. test=develop

* fix doc. test=develop

* fix API.spec. test=develop

* fix test_layers. test=develop

* fix format. test=develop
上级 c2e20e2a
...@@ -133,7 +133,7 @@ paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, k ...@@ -133,7 +133,7 @@ paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, k
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e')) paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b')) paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b')) paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3')) paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '8b074f9c56b4233a2b65d03254eb309e'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88')) paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '960fc799549c202da1e85d626cb2c962')) paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '960fc799549c202da1e85d626cb2c962'))
paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '67afefa80b6cc38801bd5b631fed8a4a')) paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '67afefa80b6cc38801bd5b631fed8a4a'))
......
...@@ -39,9 +39,10 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> { ...@@ -39,9 +39,10 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1); Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
int axis_dim = x->dims()[rank - 1];
math::CrossEntropyFunctor<DeviceContext, T>()( math::CrossEntropyFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d, ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
ctx.Attr<bool>("soft_label"), ctx.Attr<int>("ignore_index")); ctx.Attr<bool>("soft_label"), ctx.Attr<int>("ignore_index"), axis_dim);
} }
}; };
......
...@@ -29,8 +29,13 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> { ...@@ -29,8 +29,13 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out, void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel, const framework::Tensor* labels, const bool softLabel,
const int ignore_index) { const int ignore_index, const int axis_dim) {
const int batch_size = prob->dims()[0]; const int batch_size = prob->dims()[0];
const int num_classes = prob->dims()[1];
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
if (softLabel) { if (softLabel) {
auto in = EigenMatrix<T>::From(*prob); auto in = EigenMatrix<T>::From(*prob);
auto lbl = EigenMatrix<T>::From(*labels); auto lbl = EigenMatrix<T>::From(*labels);
...@@ -38,24 +43,24 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> { ...@@ -38,24 +43,24 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
loss.device(*ctx.eigen_device()) = loss.device(*ctx.eigen_device()) =
-((lbl * in.log().unaryExpr(math::TolerableValue<T>())) -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1)) .reshape(batch_axis_remain)
.reshape(Eigen::DSizes<int, 2>(batch_size, 1))); .sum(Eigen::DSizes<int, 1>(1)));
} else { } else {
const int class_num = prob->dims()[1];
const T* prob_data = prob->data<T>(); const T* prob_data = prob->data<T>();
T* loss_data = out->data<T>(); T* loss_data = out->data<T>();
const int64_t* label_data = labels->data<int64_t>(); const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int lbl = label_data[i]; for (int j = 0; j < num_remain; j++) {
PADDLE_ENFORCE_GE(lbl, 0); int lbl = label_data[i * num_remain + j];
PADDLE_ENFORCE_LT(lbl, class_num); PADDLE_ENFORCE((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index);
PADDLE_ENFORCE((lbl >= 0 && lbl < class_num) || lbl == ignore_index); int index = i * num_classes + lbl * num_remain + j;
int index = i * class_num + lbl; int loss_idx = i * num_remain + j;
loss_data[i] = loss_data[loss_idx] =
lbl == ignore_index lbl == ignore_index
? 0 ? 0
: -math::TolerableValue<T>()(std::log(prob_data[index])); : -math::TolerableValue<T>()(std::log(prob_data[index]));
}
} }
} }
} }
......
...@@ -57,8 +57,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> { ...@@ -57,8 +57,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext& ctx,
framework::Tensor* out, const framework::Tensor* prob, framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, bool softLabel, const framework::Tensor* labels, const bool softLabel,
const int ignore_index) { const int ignore_index, const int axis_dim) {
const T* prob_data = prob->data<T>(); const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace()); T* loss_data = out->mutable_data<T>(ctx.GetPlace());
......
...@@ -60,7 +60,7 @@ class CrossEntropyFunctor { ...@@ -60,7 +60,7 @@ class CrossEntropyFunctor {
void operator()(const DeviceContext& context, framework::Tensor* out, void operator()(const DeviceContext& context, framework::Tensor* out,
const framework::Tensor* prob, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel, const framework::Tensor* labels, const bool softLabel,
const int ignore_index); const int ignore_index, const int axis_dim);
}; };
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -26,23 +26,28 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -26,23 +26,28 @@ class SoftmaxWithCrossEntropyOpMaker
public: public:
void Make() override { void Make() override {
AddInput("Logits", AddInput("Logits",
"(Tensor, default: Tensor<float>), The unscaled log probabilities " "(Tensor, default: Tensor<float>), The input tensor of unscaled "
"which is a 2-D tensor with shape [N x K]. N is the batch_size, " "log probabilities, whose dimension :attr:`axis` should be scaled "
"and K is the class number."); "by softmax.");
AddInput("Label", AddInput(
"(Tensor) The ground truth which is a 2-D tensor. If soft_label " "Label",
"is set to false, Label is a Tensor<int64> with shape [N x 1]. If " "(Tensor) The input tesnor of groud truth label. If :attr:`soft_label` "
"soft_label is set to true, Label is a Tensor<float/double> with " "is set to false, Label is a Tensor<int64> in same shape with "
"shape [N x K]."); "Input(Logits) except the shape in dimension :attr:`axis` as 1. If "
"soft_label is set to true, Label is a Tensor<float/double> in same "
"shape with Input(Logits).");
AddOutput( AddOutput(
"Softmax", "Softmax",
"(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. " "(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits). "
"The outputs value of softmax activation by given the input batch, " "The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.") "which will be used in backward calculation.")
.AsIntermediate(); .AsIntermediate();
AddOutput("Loss", AddOutput("Loss",
"(Tensor, default: Tensor<float>), A 2-D tensor. The cross " "(Tensor, default: Tensor<float>), A tensor in same shape with "
"entropy loss with shape [N x 1]."); "Input(Logits) "
"except the shape in dimension :attr:`axis` as 1. The cross "
"entropy loss.");
AddAttr<bool>( AddAttr<bool>(
"soft_label", "soft_label",
"(bool, default: false), A flag to indicate whether to interpretate " "(bool, default: false), A flag to indicate whether to interpretate "
...@@ -60,6 +65,10 @@ class SoftmaxWithCrossEntropyOpMaker ...@@ -60,6 +65,10 @@ class SoftmaxWithCrossEntropyOpMaker
"does not contribute to the input gradient. Only valid if soft_label" "does not contribute to the input gradient. Only valid if soft_label"
"is set to False") "is set to False")
.SetDefault(-100); .SetDefault(-100);
AddAttr<int>("axis",
"The dimension index of Input(Logits) to perform softmax,"
"default -1 for last dimension")
.SetDefault(-1);
AddComment(R"DOC( AddComment(R"DOC(
Softmax With Cross Entropy Operator. Softmax With Cross Entropy Operator.
...@@ -107,38 +116,53 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -107,38 +116,53 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
"Output(Softmax) should be not null."); "Output(Softmax) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null.");
auto axis = ctx->Attrs().Get<int>("axis");
auto logits_dims = ctx->GetInputDim("Logits"); auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label"); auto labels_dims = ctx->GetInputDim("Label");
auto logits_rank = logits_dims.size();
PADDLE_ENFORCE(axis >= -logits_rank && axis < logits_rank,
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(Logits).");
axis = CanonicalAxis(axis, logits_rank);
for (int i = 0; i < logits_rank; i++) {
if (i != axis) {
if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) {
PADDLE_ENFORCE_EQ(
logits_dims[i], labels_dims[i],
"Input(Logits) and Input(Label) should in same shape in "
"dimensions except axis.");
}
}
}
int rank = logits_dims.size(); auto numeric_stable_mode = ctx->Attrs().Get<bool>("numeric_stable_mode");
PADDLE_ENFORCE_EQ( if (axis != logits_rank - 1) {
rank, labels_dims.size(), PADDLE_ENFORCE(
"Input(logits) and Input(Label) shall have the same rank."); numeric_stable_mode,
bool check = ctx->IsRuntime() || (framework::product(logits_dims) > 0 && "Attr(axis) can only be -1 when not in numeric_stable_mode.");
framework::product(labels_dims) > 0);
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
} }
if (ctx->Attrs().Get<bool>("soft_label")) { bool soft_label = ctx->Attrs().Get<bool>("soft_label");
if (check) { if (soft_label) {
PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1], if (ctx->IsRuntime() ||
"If Attr(soft_label) == true, the last dimension of " (logits_dims[axis] > 0 && labels_dims[axis] > 0)) {
PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis],
"If Attr(soft_label) == true, the axis dimension of "
"Input(X) and Input(Label) should be equal."); "Input(X) and Input(Label) should be equal.");
} }
} else { } else {
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL, if (ctx->IsRuntime() || labels_dims[axis] > 0) {
"If Attr(softLabel) == false, the last dimension of " PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
"Input(Label) should be 1."); "If Attr(soft_label) == false, the axis dimension of "
"Input(Label) should be 1.");
}
} }
ctx->SetOutputDim("Softmax", logits_dims); ctx->SetOutputDim("Softmax", logits_dims);
auto loss_dims = logits_dims;
loss_dims[rank - 1] = 1; logits_dims[axis] = 1;
ctx->SetOutputDim("Loss", loss_dims); ctx->SetOutputDim("Loss", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax"); ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss"); ctx->ShareLoD("Logits", /*->*/ "Loss");
...@@ -165,36 +189,40 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -165,36 +189,40 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
"Output(Logits@Grad) should be not null."); "Output(Logits@Grad) should be not null.");
auto axis = ctx->Attrs().Get<int>("axis");
auto softmax_dims = ctx->GetInputDim("Softmax"); auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label"); auto labels_dims = ctx->GetInputDim("Label");
auto softmax_rank = softmax_dims.size();
int rank = softmax_dims.size(); PADDLE_ENFORCE(axis >= -softmax_rank && axis < softmax_rank,
PADDLE_ENFORCE_EQ( "Attr(axis) value should be in range [-R, R-1], "
rank, labels_dims.size(), "R is the rank of Input(Logits).");
"Input(logits) and Input(Label) shall have the same rank.");
bool check = true; axis = CanonicalAxis(axis, softmax_rank);
if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 || for (int i = 0; i < softmax_rank; i++) {
framework::product(labels_dims) <= 0)) { if (i != axis) {
check = false; if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) {
} PADDLE_ENFORCE_EQ(
if (check) { softmax_dims[i], labels_dims[i],
PADDLE_ENFORCE_EQ( "Input(Logits) and Input(Label) should in same shape in "
framework::slice_ddim(softmax_dims, 0, rank - 1), "dimensions except axis.");
framework::slice_ddim(labels_dims, 0, rank - 1), }
"Input(Softmax) and Input(Label) shall have the same shape " }
"except the last dimension.");
} }
if (ctx->Attrs().Get<bool>("soft_label")) { bool soft_label = ctx->Attrs().Get<bool>("soft_label");
if (check) { if (soft_label) {
PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1], if (ctx->IsRuntime() ||
"If Attr(soft_label) == true, the last dimension of " (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) {
"Input( Softmax) and Input(Label) should be equal."); PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis],
"If Attr(soft_label) == true, the axis dimension of "
"Input(X) and Input(Label) should be equal.");
} }
} else { } else {
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL, if (ctx->IsRuntime() || labels_dims[axis] > 0) {
"If Attr(softLabel) == false, the last dimension of " PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL,
"Input(Label) should be 1."); "If Attr(soft_label) == false, the axis dimension of "
"Input(Label) should be 1.");
}
} }
ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->SetOutputDim(framework::GradVarName("Logits"),
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,26 +37,30 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> { ...@@ -36,26 +37,30 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
const Tensor* labels = context.Input<Tensor>("Label"); const Tensor* labels = context.Input<Tensor>("Label");
Tensor* softmax = context.Output<Tensor>("Softmax"); Tensor* softmax = context.Output<Tensor>("Softmax");
Tensor* loss = context.Output<Tensor>("Loss"); Tensor* loss = context.Output<Tensor>("Loss");
const bool soft_label = context.Attr<bool>("soft_label");
const int rank = logits->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis];
softmax->mutable_data<T>(context.GetPlace()); softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace()); loss->mutable_data<T>(context.GetPlace());
// reshape to 2D tensor const int n = SizeToAxis(axis, logits->dims());
int rank = logits->dims().size(); const int d = SizeFromAxis(axis, logits->dims());
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1); Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); logits_2d.ShareDataWith(*logits).Resize({n, d});
Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1); softmax_2d.ShareDataWith(*softmax).Resize({n, d});
Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1); labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim});
int axis_dim = logits->dims()[rank - 1];
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>(); context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()( math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, axis_dim, &logits_2d, &softmax_2d); dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()( math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, &loss_2d, &softmax_2d, &labels_2d, dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label,
context.Attr<bool>("soft_label"), context.Attr<int>("ignore_index")); context.Attr<int>("ignore_index"), axis_dim);
} }
}; };
...@@ -75,34 +80,43 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> { ...@@ -75,34 +80,43 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.device_context(), logit_grad); context.device_context(), logit_grad);
} }
int rank = logit_grad->dims().size(); const bool soft_label = context.Attr<bool>("soft_label");
const int class_num = logit_grad->dims()[rank - 1];
// reshape to 2d const int rank = logit_grad->dims().size();
Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1); const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1); int axis_dim = logit_grad->dims()[axis];
const int n = SizeToAxis(axis, logit_grad->dims());
const int d = SizeFromAxis(axis, logit_grad->dims());
Tensor logit_grad_2d, labels_2d, out_grad_2d;
logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d); auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d); auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>() auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device(); .eigen_device();
if (context.Attr<bool>("soft_label")) { if (soft_label) {
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
auto lbl_mat = EigenMatrix<T>::From(labels_2d); auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) = logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) * out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim)) *
(logit_grad_mat - lbl_mat); (logit_grad_mat - lbl_mat);
} else { } else {
logit_grad_mat.device(place) = logit_grad_mat.device(place) =
logit_grad_mat * logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)); out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
const int batch_size = logit_grad_2d.dims()[0];
const int64_t* label_data = labels->data<int64_t>(); const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>(); T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
for (int i = 0; i < batch_size; ++i) { const int remain = d / axis_dim;
logit_grad_data[i * class_num + label_data[i]] -= out_grad_data[i]; for (int i = 0; i < n; ++i) {
for (int j = 0; j < remain; j++) {
int idx = i * remain + j;
logit_grad_data[i * d + label_data[idx] * remain + j] -=
out_grad_data[idx];
}
} }
} }
} }
......
...@@ -6099,22 +6099,24 @@ def softmax_with_cross_entropy(logits, ...@@ -6099,22 +6099,24 @@ def softmax_with_cross_entropy(logits,
soft_label=False, soft_label=False,
ignore_index=kIgnoreIndex, ignore_index=kIgnoreIndex,
numeric_stable_mode=True, numeric_stable_mode=True,
return_softmax=False): return_softmax=False,
axis=-1):
""" """
**Softmax With Cross Entropy Operator.** **Softmax With Cross Entropy Operator.**
Cross entropy loss with softmax is used as the output layer extensively. This Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input operator computes the softmax normalized values for dimension :attr:`axis` of
tensor, after which cross-entropy loss is computed. This provides a more the input tensor, after which cross-entropy loss is computed. This provides
numerically stable gradient. a more numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results. softmax operator since that would produce incorrect results.
When the attribute soft_label is set false, this operators expects mutually When the attribute :attr:`soft_label` is set :attr:`False`, this operators
exclusive hard labels, each sample in a batch is in exactly one class with a expects mutually exclusive hard labels, each sample in a batch is in exactly
probability of 1.0. Each sample in the batch will have a single label. one class with a probability of 1.0. Each sample in the batch will have a
single label.
The equation is as follows: The equation is as follows:
...@@ -6133,7 +6135,8 @@ def softmax_with_cross_entropy(logits, ...@@ -6133,7 +6135,8 @@ def softmax_with_cross_entropy(logits,
\\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K} \\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K \\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K
3) If numeric_stable_mode is True, softmax is calculated first by: 3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated
first by:
.. math:: .. math::
...@@ -6146,32 +6149,39 @@ def softmax_with_cross_entropy(logits, ...@@ -6146,32 +6149,39 @@ def softmax_with_cross_entropy(logits,
and then cross entropy loss is calculated by softmax and label. and then cross entropy loss is calculated by softmax and label.
Args: Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor logits (Variable): The input tensor of unscaled log probabilities.
with shape [N x K]. N is the batch_size, and K is the class number. label (Variable): The ground truth tensor. If :attr:`soft_label`
label (Variable): The ground truth which is a 2-D tensor. If soft_label is set to :attr:`True`, Label is a Tensor<float/double> in the
is set to false, Label is a Tensor<int64> with shape [N x 1]. If same shape with :attr:`logits`. If :attr:`soft_label` is set to
soft_label is set to true, Label is a Tensor<float/double> with :attr:`True`, Label is a Tensor<int64> in the same shape with
:attr:`logits` expect shape in dimension :attr:`axis` as 1.
soft_label (bool): A flag to indicate whether to interpretate the given soft_label (bool): A flag to indicate whether to interpretate the given
labels as soft labels. By default, `soft_label` is set to False. labels as soft labels. Default False.
ignore_index (int): Specifies a target value that is ignored and does ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid not contribute to the input gradient. Only valid
if soft_label is set to False. Default: kIgnoreIndex if :attr:`soft_label` is set to :attr:`False`.
Default: kIgnoreIndex
numeric_stable_mode (bool): A flag to indicate whether to use a more numeric_stable_mode (bool): A flag to indicate whether to use a more
numerically stable algorithm. Only valid numerically stable algorithm. Only valid
when soft_label is False and GPU is used. when :attr:`soft_label` is :attr:`False`
When soft_label is True or CPU is used, and GPU is used. When :attr:`soft_label`
the algorithm is always numerically stable. is :attr:`True` or CPU is used, the
algorithm is always numerically stable.
Note that the speed may be slower when use Note that the speed may be slower when use
stable algorithm. Default: True stable algorithm. Default: True
return_softmax (bool): A flag indicating whether to return the softmax return_softmax (bool): A flag indicating whether to return the softmax
along with the cross entropy loss. Default: False along with the cross entropy loss. Default: False
axis (int): The index of dimension to perform softmax calculations. It
should be in range :math:`[-1, rank - 1]`, while :math:`rank`
is the rank of input :attr:`logits`. Default: -1.
Returns: Returns:
Variable or Tuple of two Variables: Return the cross entropy loss if \ Variable or Tuple of two Variables: Return the cross entropy loss if \
`return_softmax` is False, otherwise the tuple \ `return_softmax` is False, otherwise the tuple \
(loss, softmax), where the cross entropy loss is \ (loss, softmax), softmax is in the same shape \
a 2-D tensor with shape [N x 1], and softmax is a \ with input logits and cross entropy loss is in \
2-D tensor with shape [N x K]. the same shape with input logits except shape \
in dimension :attr:`axis` as 1.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -6194,7 +6204,8 @@ def softmax_with_cross_entropy(logits, ...@@ -6194,7 +6204,8 @@ def softmax_with_cross_entropy(logits,
attrs={ attrs={
'soft_label': soft_label, 'soft_label': soft_label,
'ignore_index': ignore_index, 'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode 'numeric_stable_mode': numeric_stable_mode,
'axis': axis
}) })
if return_softmax: if return_softmax:
......
...@@ -1221,10 +1221,25 @@ class TestBook(LayerTest): ...@@ -1221,10 +1221,25 @@ class TestBook(LayerTest):
y = self._get_data(name='label', shape=[1], dtype='int64') y = self._get_data(name='label', shape=[1], dtype='int64')
loss, softmax = layers.softmax_with_cross_entropy( loss, softmax = layers.softmax_with_cross_entropy(
x, y, return_softmax=True) x, y, return_softmax=True)
return (loss) self.assertIsNotNone(loss)
return (softmax) self.assertIsNotNone(softmax)
loss = layers.softmax_with_cross_entropy(x, y) loss = layers.softmax_with_cross_entropy(x, y)
return (loss) self.assertIsNotNone(loss)
x1 = self._get_data(name='x1', shape=[16, 32, 64], dtype='float32')
y1 = self._get_data(name='label1', shape=[1, 32, 64], dtype='int64')
y2 = self._get_data(name='label2', shape=[16, 1, 64], dtype='int64')
y3 = self._get_data(name='label3', shape=[16, 32, 1], dtype='int64')
loss1 = layers.softmax_with_cross_entropy(x1, y1, axis=1)
loss2 = layers.softmax_with_cross_entropy(x1, y2, axis=2)
loss3 = layers.softmax_with_cross_entropy(x1, y3, axis=3)
loss4 = layers.softmax_with_cross_entropy(x1, y3, axis=-1)
self.assertIsNotNone(loss1)
self.assertIsNotNone(loss2)
self.assertIsNotNone(loss3)
self.assertIsNotNone(loss4)
return (loss4)
def make_smooth_l1(self): def make_smooth_l1(self):
with program_guard(fluid.default_main_program(), with program_guard(fluid.default_main_program(),
......
...@@ -21,37 +21,70 @@ from op_test import OpTest ...@@ -21,37 +21,70 @@ from op_test import OpTest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
if soft_label:
return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)
shape = softmax.shape
axis %= len(shape)
n = int(np.prod(shape[:axis]))
axis_dim = shape[axis]
remain = int(np.prod(shape[axis + 1:]))
softmax_reshape = softmax.reshape((n, axis_dim, remain))
label_reshape = label.reshape((n, 1, remain))
result = np.zeros_like(label_reshape, dtype=softmax.dtype)
for i in range(n):
for j in range(remain):
lbl = label_reshape[i, 0, j]
if lbl != ignore_index:
result[i, 0, j] -= np.log(softmax_reshape[i, lbl, j])
return result.reshape(label.shape)
class TestSoftmaxWithCrossEntropyOp(OpTest): class TestSoftmaxWithCrossEntropyOp(OpTest):
""" """
Test softmax with cross entropy operator with discreate one-hot labels. Test softmax with cross entropy operator with discreate one-hot labels.
""" """
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False self.numeric_stable_mode = False
self.soft_label = False
self.dtype = np.float64 self.dtype = np.float64
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
def setUp(self): def setUp(self):
self.initParams() self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
logits = np.random.uniform(0.1, 1.0, logits = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
[batch_size, class_num]).astype(self.dtype) softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64") if self.soft_label:
labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
cross_entropy = np.asmatrix( loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
[[-np.log(softmax[i][labels[i][0]])] self.ignore_index)
for i in range(softmax.shape[0])],
dtype=self.dtype)
self.inputs = {"Logits": logits, "Label": labels} self.inputs = {"Logits": logits, "Label": labels}
self.outputs = { self.outputs = {
"Softmax": softmax.astype(self.dtype), "Softmax": softmax.astype(self.dtype),
"Loss": cross_entropy.astype(self.dtype) "Loss": loss.astype(self.dtype)
} }
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode} self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
}
if self.ignore_index >= 0:
self.attrs['ignore_index'] = self.ignore_index
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -62,30 +95,38 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): ...@@ -62,30 +95,38 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float64
class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False self.numeric_stable_mode = False
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float16 self.dtype = np.float16
def setUp(self): def setUp(self):
self.initParams() self.initParams()
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check. # NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
logits = np.random.uniform(0.1, 1.0, logits = np.random.uniform(0.1, 1.0, self.shape).astype(np.float32)
[batch_size, class_num]).astype(np.float32) softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64")
cross_entropy = np.asmatrix( axis_dim = self.shape[self.axis]
[[-np.log(softmax[i][labels[i][0]])] self.shape[self.axis] = 1
for i in range(softmax.shape[0])], labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
dtype=np.float32)
loss = cross_entropy(softmax, labels, self.soft_label, self.axis)
self.inputs = { self.inputs = {
"Logits": logits.astype(self.dtype).view(np.uint16), "Logits": logits.astype(self.dtype).view(np.uint16),
...@@ -93,9 +134,14 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): ...@@ -93,9 +134,14 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
} }
self.outputs = { self.outputs = {
"Softmax": softmax.astype(self.dtype), "Softmax": softmax.astype(self.dtype),
"Loss": cross_entropy.astype(self.dtype) "Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
} }
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode} if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-2) self.check_output(atol=1e-2)
...@@ -107,39 +153,31 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): ...@@ -107,39 +153,31 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16( class TestSoftmaxWithCrossEntropyOpNoCudnnFp16(
TestSoftmaxWithCrossEntropyOpFp16): TestSoftmaxWithCrossEntropyOpFp16):
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float16 self.dtype = np.float16
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.1) self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
class TestSoftmaxWithCrossEntropyOp2(OpTest): class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
""" """
Test softmax with cross entropy operator with soft labels. Test softmax with cross entropy operator with soft labels.
""" """
def setUp(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = 41 self.numeric_stable_mode = True
class_num = 37 self.soft_label = True
self.dtype = np.float64
logits = np.random.uniform(0.1, 1.0, self.axis = -1
[batch_size, class_num]).astype("float64") self.ignore_index = -1
softmax = np.apply_along_axis(stable_softmax, 1, logits) self.shape = [41, 37]
labels = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float64")
labels /= np.sum(labels, axis=1, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum(
axis=1, keepdims=True).astype("float64")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"soft_label": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -148,190 +186,226 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest): ...@@ -148,190 +186,226 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
self.check_grad(["Logits"], "Loss") self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOp3(OpTest): class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
""" """
Test softmax with cross entropy operator with ignore_index. Test softmax with cross entropy operator with ignore_index.
""" """
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False self.numeric_stable_mode = False
self.soft_label = False
self.shape = [41, 37]
self.ignore_index = 5
self.axis = -1
self.dtype = np.float64
def setUp(self):
self.initParams() class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = 41 self.numeric_stable_mode = True
class_num = 37 self.soft_label = False
self.shape = [3, 5, 7, 11]
logits = np.random.uniform(0.1, 1.0, self.ignore_index = 4
[batch_size, class_num]).astype("float64") self.axis = -1
softmax = np.apply_along_axis(stable_softmax, 1, logits) self.dtype = np.float64
labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64")
ignore_index = 7
cross_entropy = np.asmatrix(
[[-np.log(softmax[i][labels[i][0]])]
if labels[i] != ignore_index else [0]
for i in range(softmax.shape[0])],
dtype="float64")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {
"ignore_index": ignore_index,
"numeric_stable_mode": self.numeric_stable_mode
}
def test_check_output(self): class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
self.check_output() """
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def test_check_grad(self): def initParams(self):
self.check_grad(["Logits"], "Loss") self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 0
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 1
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
class TestSoftmaxWithCrossEntropyOp5(OpTest): class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
""" """
Test softmax with cross entropy operator with ignore_index. Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
""" """
def initParams(self): def initParams(self):
self.numeric_stable_mode = False self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float64
self.axis = 2
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
def setUp(self):
self.initParams() class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = [6, 10] self.numeric_stable_mode = True
class_num = 47 self.soft_label = False
self.dtype = np.float64
logits = np.random.uniform( self.axis = 3
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64") self.ignore_index = -1
softmax = np.apply_along_axis(stable_softmax, 2, logits) self.shape = [3, 5, 7, 11]
labels = np.random.randint(
0, class_num, tuple(batch_size + [1]), dtype="int64")
ignore_index = 7
softmax_2d = np.reshape(softmax, [-1, class_num])
labels_2d = np.reshape(labels, [-1, 1])
cross_entropy = np.asmatrix(
[[-np.log(softmax_2d[i][labels_2d[i][0]])]
if labels_2d[i] != ignore_index else [0]
for i in range(softmax_2d.shape[0])],
dtype="float64")
cross_entropy = np.reshape(cross_entropy, batch_size)
output_shape = tuple(batch_size + [1])
output_res = cross_entropy.astype("float64")
output_res = np.expand_dims(output_res, axis=2)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": output_res,
}
self.attrs = {
"ignore_index": ignore_index,
"numeric_stable_mode": self.numeric_stable_mode
}
def test_check_output(self):
self.check_output()
def test_check_grad(self): class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
self.check_grad(["Logits"], "Loss") TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = 0
self.ignore_index = -1
self.dtype = np.float16
class TestSoftmaxWithCrossEntropyOp5NoCudnn(TestSoftmaxWithCrossEntropyOp5): class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
def initParams(self): def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = 1
self.ignore_index = -1
self.dtype = np.float16
class TestSoftmaxWithCrossEntropyOp6(OpTest): class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
""" TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
Test softmax with cross entropy operator with soft labels. def initParams(self):
""" self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = 2
self.ignore_index = -1
self.dtype = np.float16
def setUp(self):
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = [6, 10] self.numeric_stable_mode = True
class_num = 37 self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 0
self.ignore_index = -1
self.dtype = np.float64
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
labels /= np.sum(labels, axis=2, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum( class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
axis=2, keepdims=True).astype("float64") TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 1
self.ignore_index = -1
self.dtype = np.float64
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"soft_label": True}
def test_check_output(self): class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
self.check_output() TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 2
self.ignore_index = -1
self.dtype = np.float64
def test_check_grad(self):
self.check_grad(["Logits"], "Loss") class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
TestSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 3
self.ignore_index = -1
self.dtype = np.float64
class TestSoftmaxWithCrossEntropyOpFp16_2(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
TestSoftmaxWithCrossEntropyOp3):
def initParams(self): def initParams(self):
self.numeric_stable_mode = False self.op_type = "softmax_with_cross_entropy"
self.dtype = np.float16 self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 1
self.axis = 0
self.dtype = np.float64
def setUp(self):
self.initParams() class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = [64, 10] self.numeric_stable_mode = True
class_num = 37 self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 0
self.axis = 1
self.dtype = np.float64
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype(np.float32)
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.randint(
0, class_num, tuple(batch_size + [1]), dtype="int64")
softmax_2d = np.reshape(softmax, [-1, class_num])
labels_2d = np.reshape(labels, [-1, 1])
cross_entropy = np.asmatrix(
[[-np.log(softmax_2d[i][labels_2d[i][0]])]
for i in range(softmax_2d.shape[0])],
dtype=np.float32)
cross_entropy = np.reshape(cross_entropy, batch_size)
output_shape = tuple(batch_size + [1])
output_res = cross_entropy.astype(self.dtype)
output_res = np.expand_dims(output_res, axis=2)
self.inputs = {"Logits": logits, "Label": labels}
self.inputs = { class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
"Logits": logits.astype(self.dtype).view(np.uint16), TestSoftmaxWithCrossEntropyOp3):
"Label": labels def initParams(self):
} self.op_type = "softmax_with_cross_entropy"
self.outputs = { self.numeric_stable_mode = True
"Softmax": softmax.astype(self.dtype), self.soft_label = False
"Loss": output_res, self.shape = [3, 5, 7, 11]
} self.ignore_index = 3
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode} self.axis = 2
self.dtype = np.float64
def test_check_output(self):
self.check_output(atol=1e-2)
def test_check_grad(self): class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
self.check_grad(["Logits"], "Loss", max_relative_error=0.1) TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 3
self.axis = 3
self.dtype = np.float64
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册