diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index fbaf76d4e7cd89ea75a271dc4c5c658ea910808a..0c2d39e7519ef473f01de5671f0035d7acde6dd4 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -44,6 +44,19 @@ class SoftmaxWithCrossEntropyOpMaker "The outputs value of softmax activation by given the input batch, " "which will be used in backward calculation.") .AsIntermediate(); +#ifdef PADDLE_WITH_ASCEND_CL + AddOutput( + "Backprop", + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits). " + "The intermediate value used for backward calculation. The calculation " + "is :" + "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, " + "where labels is ont-hot." + "Currently, the tensor is generated and used in npu kernel only. ") + .AsIntermediate() + .AsDispensable(); +#endif AddOutput("Loss", "(Tensor, default: Tensor), A tensor in same shape with " "Input(Logits) " @@ -181,7 +194,10 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Softmax", logits_dims); - +#ifdef PADDLE_WITH_ASCEND_CL + ctx->SetOutputDim("Backprop", logits_dims); + ctx->ShareLoD("Logits", /*->*/ "Backprop"); +#endif logits_dims[axis] = 1; ctx->SetOutputDim("Loss", logits_dims); @@ -285,6 +301,9 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker { grad_op->SetType("softmax_with_cross_entropy_grad"); grad_op->SetInput("Label", this->Input("Label")); grad_op->SetInput("Softmax", this->Output("Softmax")); +#ifdef PADDLE_WITH_ASCEND_CL + grad_op->SetInput("Backprop", this->Output("Backprop")); +#endif grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss")); grad_op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits")); @@ -317,9 +336,29 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyGradKernel, ops::SoftmaxWithCrossEntropyGradKernel); + REGISTER_OP_VERSION(softmax_with_cross_entropy) +#ifdef PADDLE_WITH_ASCEND_CL + .AddCheckpoint( + R"ROC( + Add a new attribute [use_softmax] )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_softmax", "A flag to indicate whether to do softmax", true)) + .AddCheckpoint( + R"ROC( + Add a new dispensable/intermediate output [backprop] )ROC", + paddle::framework::compatible::OpVersionDesc().NewOutput( + "Backprop", + "The intermediate value used for backward calculation. The " + "calculation is :" + "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, " + "where labels is ont-hot." + "Currently, the tensor is generated and used in npu kernel " + "only. ")); +#else .AddCheckpoint( R"ROC( Add a new attribute [use_softmax] )ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( "use_softmax", "A flag to indicate whether to do softmax", true)); +#endif diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc index 9921248d1ca1d652cd7505a50b7a2ec4c46afc9e..639fc6fcc2e79b265e6fda48303db6603ef12401 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc @@ -32,81 +32,53 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel { auto* labels = ctx.Input("Label"); auto* softmax = ctx.Output("Softmax"); auto* loss = ctx.Output("Loss"); + auto* backprop = ctx.Output("Backprop"); + auto soft_label = ctx.Attr("soft_label"); + PADDLE_ENFORCE_EQ(soft_label, false, + platform::errors::Unimplemented( + "soft_label=True is not supported in " + "the npu kernel of softmax_with_cross_entropy.")); - int cls_num = logits->dims()[1]; const int rank = logits->dims().size(); const int axis = CanonicalAxis(ctx.Attr("axis"), rank); - std::vector axes; - for (auto i = axis; i < logits->dims().size(); ++i) { - axes.push_back(i); - } + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + + PADDLE_ENFORCE_EQ( + labels->numel(), n, + platform::errors::Unimplemented( + "The size of labels should be equal to SizeToAxis of logits," + "but got size of labels is %d and SizeToAxis is %d.", + labels->numel(), n)); + + loss->mutable_data(ctx.GetPlace()); + backprop->mutable_data(ctx.GetPlace()); + softmax->mutable_data(ctx.GetPlace()); + + Tensor logits_2d, labels_1d, loss_1d, backprop_2d, softmax_2d; + logits_2d.ShareDataWith(*logits).Resize({n, d}); + labels_1d.ShareDataWith(*labels).Resize({n}); + loss_1d.ShareDataWith(*loss).Resize({n}); + backprop_2d.ShareDataWith(*backprop).Resize({n, d}); + softmax_2d.ShareDataWith(*softmax).Resize({n, d}); auto stream = ctx.template device_context() .stream(); - // softmax - softmax->mutable_data(ctx.GetPlace()); + std::vector axes; + for (auto i = axis; i < logits->dims().size(); ++i) { + axes.push_back(i); + } const auto& runner_softmax = NpuOpRunner("SoftmaxV2", {*logits}, {*softmax}, {{"axes", axes}}); runner_softmax.Run(stream); - // cast label from int64/int32 to int32 - Tensor tmp_labels(framework::proto::VarType::INT32); - if (labels->type() != framework::proto::VarType::INT32) { - tmp_labels.Resize(labels->dims()); - tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32); - auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); - const auto& runner_cast_label = - NpuOpRunner("Cast", {*labels}, {tmp_labels}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_label.Run(stream); - labels = &tmp_labels; - } - - // on and off - Tensor on_tensor(framework::proto::VarType::INT32); - on_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&on_tensor, static_cast(1)); - Tensor off_tensor(framework::proto::VarType::INT32); - off_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&off_tensor, static_cast(0)); - - // one_hot - Tensor tmp_onehot(on_tensor.type()); - tmp_onehot.Resize(logits->dims()); - tmp_onehot.mutable_data(ctx.GetPlace()); - - const auto& runner_onehot = - NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot}, - {{"axis", -1}, {"depth", cls_num}}); - runner_onehot.Run(stream); - - // cast one_hot from int32 to T - Tensor cast_onehot(logits->type()); - cast_onehot.Resize(tmp_onehot.dims()); - cast_onehot.mutable_data(ctx.GetPlace()); - auto dst_dtype = ConvertToNpuDtype(logits->type()); - const auto& runner_cast_onehot = - NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_onehot.Run(stream); - - // SoftmaxCrossEntropyWithLogits - Tensor backprop(logits->type()); - backprop.Resize(logits->dims()); - backprop.mutable_data(ctx.GetPlace()); - - loss->mutable_data(ctx.GetPlace()); - - // SoftmaxCrossEntropyWithLogits requires loss to be of shape [batch_size] - auto loss_dims = loss->dims(); - loss->Resize({loss_dims[0]}); + // SparseSoftmaxCrossEntropyWithLogits const auto& runner_s = - NpuOpRunner("SoftmaxCrossEntropyWithLogits", {*logits, cast_onehot}, - {*loss, backprop}, {}); + NpuOpRunner("SparseSoftmaxCrossEntropyWithLogits", + {logits_2d, labels_1d}, {loss_1d, backprop_2d}, {}); runner_s.Run(stream); - loss->Resize(loss_dims); } }; @@ -114,70 +86,32 @@ template class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* labels = ctx.Input("Label"); - auto* softmax = ctx.Input("Softmax"); + auto* backprop = ctx.Input("Backprop"); auto* loss_grad = ctx.Input(framework::GradVarName("Loss")); auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); - int cls_num = softmax->dims()[1]; + PADDLE_ENFORCE_NOT_NULL(backprop, + platform::errors::PreconditionNotMet( + "backprop should not be null in NPU kernel of " + "softmax_with_cross_entropy_grad.")); + logits_grad->mutable_data(ctx.GetPlace()); + + const int rank = logits_grad->dims().size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int n = SizeToAxis(axis, logits_grad->dims()); + const int d = SizeFromAxis(axis, logits_grad->dims()); + + Tensor logits_grad_2d, loss_grad_1d, backprop_2d; + + logits_grad_2d.ShareDataWith(*logits_grad).Resize({n, d}); + loss_grad_1d.ShareDataWith(*loss_grad).Resize({n}); + backprop_2d.ShareDataWith(*backprop).Resize({n, d}); auto stream = ctx.template device_context() .stream(); - - // cast label from int64/int32 to int32 - Tensor tmp_labels(framework::proto::VarType::INT32); - if (labels->type() != framework::proto::VarType::INT32) { - tmp_labels.Resize(labels->dims()); - tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32); - auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); - const auto& runner_cast_label = - NpuOpRunner("Cast", {*labels}, {tmp_labels}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_label.Run(stream); - labels = &tmp_labels; - } - - // on and off - Tensor on_tensor(framework::proto::VarType::INT32); - on_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&on_tensor, static_cast(1)); - Tensor off_tensor(framework::proto::VarType::INT32); - off_tensor.mutable_data({1}, ctx.GetPlace()); - FillNpuTensorWithConstant(&off_tensor, static_cast(0)); - - // one_hot - Tensor tmp_onehot(on_tensor.type()); - tmp_onehot.Resize(softmax->dims()); - tmp_onehot.mutable_data(ctx.GetPlace()); - - const auto& runner_onehot = - NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot}, - {{"axis", -1}, {"depth", cls_num}}); - runner_onehot.Run(stream); - - // cast one_hot from int32 to T - Tensor cast_onehot(softmax->type()); - cast_onehot.Resize(tmp_onehot.dims()); - cast_onehot.mutable_data(ctx.GetPlace()); - auto dst_dtype = ConvertToNpuDtype(softmax->type()); - const auto& runner_cast_onehot = - NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot}, - {{"dst_type", static_cast(dst_dtype)}}); - runner_cast_onehot.Run(stream); - - // sub - Tensor tmp_sub(softmax->type()); - tmp_sub.Resize(softmax->dims()); - tmp_sub.mutable_data(ctx.GetPlace()); - const auto& runner_sub = - NpuOpRunner("Sub", {*softmax, cast_onehot}, {tmp_sub}, {}); - - runner_sub.Run(stream); - // mul - logits_grad->mutable_data(ctx.GetPlace()); const auto& runner_mul = - NpuOpRunner("Mul", {*loss_grad, tmp_sub}, {*logits_grad}, {}); + NpuOpRunner("Mul", {*loss_grad, *backprop}, {*logits_grad}, {}); runner_mul.Run(stream); } }; diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py index c3f25dc53c12c42f882187267c7860e6727c1f51..d150cc7a9aee9960068738bc0ba98a444eba1d6e 100644 --- a/python/paddle/fluid/layers/loss.py +++ b/python/paddle/fluid/layers/loss.py @@ -26,6 +26,7 @@ from ..data_feeder import check_variable_and_dtype, check_type from ..param_attr import ParamAttr from ..initializer import NumpyArrayInitializer, Constant from .. import core +import warnings __all__ = [ 'center_loss', @@ -1258,10 +1259,16 @@ def softmax_with_cross_entropy(logits, print(out) """ if in_dygraph_mode(): - softmax, loss = core.ops.softmax_with_cross_entropy( - logits, label, 'soft_label', soft_label, 'ignore_index', - ignore_index, 'numeric_stable_mode', numeric_stable_mode, 'axis', - axis) + if core.is_compiled_with_npu(): + softmax, backprop, loss = core.ops.softmax_with_cross_entropy( + logits, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', numeric_stable_mode, + 'axis', axis) + else: + softmax, loss = core.ops.softmax_with_cross_entropy( + logits, label, 'soft_label', soft_label, 'ignore_index', + ignore_index, 'numeric_stable_mode', numeric_stable_mode, + 'axis', axis) if not return_softmax: return loss else: @@ -1276,12 +1283,16 @@ def softmax_with_cross_entropy(logits, helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=logits.dtype) loss = helper.create_variable_for_type_inference(dtype=logits.dtype) + + outputs = {'Softmax': softmax, 'Loss': loss} + if core.is_compiled_with_npu(): + backprop = helper.create_variable_for_type_inference(dtype=logits.dtype) + outputs['Backprop'] = backprop helper.append_op( type='softmax_with_cross_entropy', inputs={'Logits': logits, 'Label': label}, - outputs={'Softmax': softmax, - 'Loss': loss}, + outputs=outputs, attrs=attrs) if return_softmax: diff --git a/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py index 1b48268b0e77e6804d3a26bd58918a4c484d3732..2ee089360e6dd2f62aabdd25179e7e9410b365e4 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_softmax_with_cross_entropy_op_npu.py @@ -68,8 +68,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): loss = cross_entropy(softmax, labels, self.soft_label, self.axis, self.ignore_index) + one_hot_label = np.eye(axis_dim)[labels.reshape(-1)] + self.inputs = {"Logits": logits, "Label": labels} self.outputs = { + "Backprop": (softmax - one_hot_label).astype(self.dtype), "Softmax": softmax.astype(self.dtype), "Loss": loss.astype(self.dtype) } @@ -85,12 +88,16 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): def test_check_output(self): self.check_output_with_place(self.place, check_dygraph=False) - # TODO(ascendrc): Add grad test - # def test_check_grad(self): - # if self.dtype == np.float16: - # return - # self.check_grad(['X'], 'Out') - # + def test_check_grad(self): + if self.dtype == np.float16: + return + # fp32 has low precision, cpu and npu both need to relax the max_relative_error if using fp32 + self.check_grad_with_place( + self.place, ['Logits'], + 'Loss', + check_dygraph=False, + numeric_grad_delta=0.001, + max_relative_error=0.5) @unittest.skipIf(not paddle.is_compiled_with_npu(), diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index f81011717040a35375bdb5bed87392c997f5ab29..32ac4f412a8f5af26cd77114ecb229226ae2ac63 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -30,4 +30,5 @@ no_check_set_white_list = [ 'cudnn_lstm', 'rnn', 'fusion_lstm', + 'softmax_with_cross_entropy', ]