未验证 提交 ff825238 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] use SparseSoftmaxCrossEntropyWithLogits in npu kernel of softmax_with_cross_entropy (#32858)

* use SparseSoftmaxCrossEntropyWithLogits

* fix

* test_slice

* revert test_slice

* add backprob for npu kernel

* fix typo

* fix ut

* fix ut

* refine comments

* return softmax
上级 28521e0f
......@@ -44,6 +44,19 @@ class SoftmaxWithCrossEntropyOpMaker
"The outputs value of softmax activation by given the input batch, "
"which will be used in backward calculation.")
.AsIntermediate();
#ifdef PADDLE_WITH_ASCEND_CL
AddOutput(
"Backprop",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits). "
"The intermediate value used for backward calculation. The calculation "
"is :"
"exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
"where labels is ont-hot."
"Currently, the tensor is generated and used in npu kernel only. ")
.AsIntermediate()
.AsDispensable();
#endif
AddOutput("Loss",
"(Tensor, default: Tensor<float>), A tensor in same shape with "
"Input(Logits) "
......@@ -181,7 +194,10 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim("Softmax", logits_dims);
#ifdef PADDLE_WITH_ASCEND_CL
ctx->SetOutputDim("Backprop", logits_dims);
ctx->ShareLoD("Logits", /*->*/ "Backprop");
#endif
logits_dims[axis] = 1;
ctx->SetOutputDim("Loss", logits_dims);
......@@ -285,6 +301,9 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetType("softmax_with_cross_entropy_grad");
grad_op->SetInput("Label", this->Input("Label"));
grad_op->SetInput("Softmax", this->Output("Softmax"));
#ifdef PADDLE_WITH_ASCEND_CL
grad_op->SetInput("Backprop", this->Output("Backprop"));
#endif
grad_op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
grad_op->SetOutput(framework::GradVarName("Logits"),
this->InputGrad("Logits"));
......@@ -317,9 +336,29 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradKernel<float>,
ops::SoftmaxWithCrossEntropyGradKernel<double>);
REGISTER_OP_VERSION(softmax_with_cross_entropy)
#ifdef PADDLE_WITH_ASCEND_CL
.AddCheckpoint(
R"ROC(
Add a new attribute [use_softmax] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_softmax", "A flag to indicate whether to do softmax", true))
.AddCheckpoint(
R"ROC(
Add a new dispensable/intermediate output [backprop] )ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"Backprop",
"The intermediate value used for backward calculation. The "
"calculation is :"
"exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, "
"where labels is ont-hot."
"Currently, the tensor is generated and used in npu kernel "
"only. "));
#else
.AddCheckpoint(
R"ROC(
Add a new attribute [use_softmax] )ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"use_softmax", "A flag to indicate whether to do softmax", true));
#endif
......@@ -32,81 +32,53 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel<T> {
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Output<Tensor>("Softmax");
auto* loss = ctx.Output<Tensor>("Loss");
auto* backprop = ctx.Output<Tensor>("Backprop");
auto soft_label = ctx.Attr<bool>("soft_label");
PADDLE_ENFORCE_EQ(soft_label, false,
platform::errors::Unimplemented(
"soft_label=True is not supported in "
"the npu kernel of softmax_with_cross_entropy."));
int cls_num = logits->dims()[1];
const int rank = logits->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
std::vector<int> axes;
for (auto i = axis; i < logits->dims().size(); ++i) {
axes.push_back(i);
}
const int n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims());
PADDLE_ENFORCE_EQ(
labels->numel(), n,
platform::errors::Unimplemented(
"The size of labels should be equal to SizeToAxis of logits,"
"but got size of labels is %d and SizeToAxis is %d.",
labels->numel(), n));
loss->mutable_data<T>(ctx.GetPlace());
backprop->mutable_data<T>(ctx.GetPlace());
softmax->mutable_data<T>(ctx.GetPlace());
Tensor logits_2d, labels_1d, loss_1d, backprop_2d, softmax_2d;
logits_2d.ShareDataWith(*logits).Resize({n, d});
labels_1d.ShareDataWith(*labels).Resize({n});
loss_1d.ShareDataWith(*loss).Resize({n});
backprop_2d.ShareDataWith(*backprop).Resize({n, d});
softmax_2d.ShareDataWith(*softmax).Resize({n, d});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// softmax
softmax->mutable_data<T>(ctx.GetPlace());
std::vector<int> axes;
for (auto i = axis; i < logits->dims().size(); ++i) {
axes.push_back(i);
}
const auto& runner_softmax =
NpuOpRunner("SoftmaxV2", {*logits}, {*softmax}, {{"axes", axes}});
runner_softmax.Run(stream);
// cast label from int64/int32 to int32
Tensor tmp_labels(framework::proto::VarType::INT32);
if (labels->type() != framework::proto::VarType::INT32) {
tmp_labels.Resize(labels->dims());
tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32);
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
const auto& runner_cast_label =
NpuOpRunner("Cast", {*labels}, {tmp_labels},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
labels = &tmp_labels;
}
// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&on_tensor, static_cast<int>(1));
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&off_tensor, static_cast<int>(0));
// one_hot
Tensor tmp_onehot(on_tensor.type());
tmp_onehot.Resize(logits->dims());
tmp_onehot.mutable_data<int>(ctx.GetPlace());
const auto& runner_onehot =
NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot},
{{"axis", -1}, {"depth", cls_num}});
runner_onehot.Run(stream);
// cast one_hot from int32 to T
Tensor cast_onehot(logits->type());
cast_onehot.Resize(tmp_onehot.dims());
cast_onehot.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(logits->type());
const auto& runner_cast_onehot =
NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_onehot.Run(stream);
// SoftmaxCrossEntropyWithLogits
Tensor backprop(logits->type());
backprop.Resize(logits->dims());
backprop.mutable_data<T>(ctx.GetPlace());
loss->mutable_data<T>(ctx.GetPlace());
// SoftmaxCrossEntropyWithLogits requires loss to be of shape [batch_size]
auto loss_dims = loss->dims();
loss->Resize({loss_dims[0]});
// SparseSoftmaxCrossEntropyWithLogits
const auto& runner_s =
NpuOpRunner("SoftmaxCrossEntropyWithLogits", {*logits, cast_onehot},
{*loss, backprop}, {});
NpuOpRunner("SparseSoftmaxCrossEntropyWithLogits",
{logits_2d, labels_1d}, {loss_1d, backprop_2d}, {});
runner_s.Run(stream);
loss->Resize(loss_dims);
}
};
......@@ -114,70 +86,32 @@ template <typename DeviceContext, typename T>
class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* labels = ctx.Input<Tensor>("Label");
auto* softmax = ctx.Input<Tensor>("Softmax");
auto* backprop = ctx.Input<Tensor>("Backprop");
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* logits_grad = ctx.Output<Tensor>(framework::GradVarName("Logits"));
int cls_num = softmax->dims()[1];
PADDLE_ENFORCE_NOT_NULL(backprop,
platform::errors::PreconditionNotMet(
"backprop should not be null in NPU kernel of "
"softmax_with_cross_entropy_grad."));
logits_grad->mutable_data<T>(ctx.GetPlace());
const int rank = logits_grad->dims().size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int n = SizeToAxis(axis, logits_grad->dims());
const int d = SizeFromAxis(axis, logits_grad->dims());
Tensor logits_grad_2d, loss_grad_1d, backprop_2d;
logits_grad_2d.ShareDataWith(*logits_grad).Resize({n, d});
loss_grad_1d.ShareDataWith(*loss_grad).Resize({n});
backprop_2d.ShareDataWith(*backprop).Resize({n, d});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
// cast label from int64/int32 to int32
Tensor tmp_labels(framework::proto::VarType::INT32);
if (labels->type() != framework::proto::VarType::INT32) {
tmp_labels.Resize(labels->dims());
tmp_labels.mutable_data(ctx.GetPlace(), framework::proto::VarType::INT32);
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
const auto& runner_cast_label =
NpuOpRunner("Cast", {*labels}, {tmp_labels},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
labels = &tmp_labels;
}
// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&on_tensor, static_cast<int>(1));
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&off_tensor, static_cast<int>(0));
// one_hot
Tensor tmp_onehot(on_tensor.type());
tmp_onehot.Resize(softmax->dims());
tmp_onehot.mutable_data<int>(ctx.GetPlace());
const auto& runner_onehot =
NpuOpRunner("OneHotD", {*labels, on_tensor, off_tensor}, {tmp_onehot},
{{"axis", -1}, {"depth", cls_num}});
runner_onehot.Run(stream);
// cast one_hot from int32 to T
Tensor cast_onehot(softmax->type());
cast_onehot.Resize(tmp_onehot.dims());
cast_onehot.mutable_data<T>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(softmax->type());
const auto& runner_cast_onehot =
NpuOpRunner("Cast", {tmp_onehot}, {cast_onehot},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_onehot.Run(stream);
// sub
Tensor tmp_sub(softmax->type());
tmp_sub.Resize(softmax->dims());
tmp_sub.mutable_data<T>(ctx.GetPlace());
const auto& runner_sub =
NpuOpRunner("Sub", {*softmax, cast_onehot}, {tmp_sub}, {});
runner_sub.Run(stream);
// mul
logits_grad->mutable_data<T>(ctx.GetPlace());
const auto& runner_mul =
NpuOpRunner("Mul", {*loss_grad, tmp_sub}, {*logits_grad}, {});
NpuOpRunner("Mul", {*loss_grad, *backprop}, {*logits_grad}, {});
runner_mul.Run(stream);
}
};
......
......@@ -26,6 +26,7 @@ from ..data_feeder import check_variable_and_dtype, check_type
from ..param_attr import ParamAttr
from ..initializer import NumpyArrayInitializer, Constant
from .. import core
import warnings
__all__ = [
'center_loss',
......@@ -1258,10 +1259,16 @@ def softmax_with_cross_entropy(logits,
print(out)
"""
if in_dygraph_mode():
softmax, loss = core.ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode, 'axis',
axis)
if core.is_compiled_with_npu():
softmax, backprop, loss = core.ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
else:
softmax, loss = core.ops.softmax_with_cross_entropy(
logits, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', numeric_stable_mode,
'axis', axis)
if not return_softmax:
return loss
else:
......@@ -1276,12 +1283,16 @@ def softmax_with_cross_entropy(logits,
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs = {'Softmax': softmax, 'Loss': loss}
if core.is_compiled_with_npu():
backprop = helper.create_variable_for_type_inference(dtype=logits.dtype)
outputs['Backprop'] = backprop
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs={'Softmax': softmax,
'Loss': loss},
outputs=outputs,
attrs=attrs)
if return_softmax:
......
......@@ -68,8 +68,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
one_hot_label = np.eye(axis_dim)[labels.reshape(-1)]
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Backprop": (softmax - one_hot_label).astype(self.dtype),
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
}
......@@ -85,12 +88,16 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place, check_dygraph=False)
# TODO(ascendrc): Add grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
def test_check_grad(self):
if self.dtype == np.float16:
return
# fp32 has low precision, cpu and npu both need to relax the max_relative_error if using fp32
self.check_grad_with_place(
self.place, ['Logits'],
'Loss',
check_dygraph=False,
numeric_grad_delta=0.001,
max_relative_error=0.5)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
......
......@@ -30,4 +30,5 @@ no_check_set_white_list = [
'cudnn_lstm',
'rnn',
'fusion_lstm',
'softmax_with_cross_entropy',
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册