From 649948a6f681d7397b2f2df396fa24ed8e8d8d69 Mon Sep 17 00:00:00 2001 From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com> Date: Tue, 29 Mar 2022 10:46:06 +0800 Subject: [PATCH] softmax_with_cross_entropy support fp16 on xpu, test=kunlun (#40869) --- .../softmax_with_cross_entropy_op_xpu.cc | 53 ++++++++++++------- .../fluid/platform/device/xpu/xpu2_op_list.h | 3 +- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index d9149b85c6..b5514525f5 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -28,6 +28,8 @@ namespace operators { template class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE_EQ( @@ -48,6 +50,10 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { std::vector logits_dims = phi::vectorize(logits->dims()); const bool soft_label = context.Attr("soft_label"); + auto logits_data = reinterpret_cast(logits->data()); + auto softmax_data = reinterpret_cast(softmax->data()); + auto loss_data = reinterpret_cast(loss->data()); + // softmax auto& dev_ctx = context.template device_context(); @@ -55,32 +61,41 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) == phi::backends::xpu::XPUVersion::XPU2 && soft_label) { - r = xpu::soft_softmax_with_cross_entropy( - dev_ctx.x_context(), logits->data(), labels->data(), - softmax->data(), loss->data(), n, d); + auto labels_data = reinterpret_cast(labels->data()); + r = xpu::soft_softmax_with_cross_entropy( + dev_ctx.x_context(), logits_data, labels_data, softmax_data, + loss_data, n, d); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy"); return; } xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); int len = logits->numel(); - T* clip_logits_data = RAII_GUARD.alloc_l3_or_gm(len); - PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits_data); + T* clip_logits = RAII_GUARD.alloc_l3_or_gm(len); + PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits); + XPUType* clip_logits_data = reinterpret_cast(clip_logits); + + float max_val = 1e20; + float min_val = -1e20; + if (std::is_same::value) { + max_val = 65504; + min_val = -65504; + } - r = xpu::clip_v2(dev_ctx.x_context(), logits->data(), - clip_logits_data, len, static_cast(-1e20), - static_cast(1e20)); + r = xpu::clip_v2( + dev_ctx.x_context(), logits_data, clip_logits_data, len, + static_cast(min_val), static_cast(max_val)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2"); - r = xpu::softmax(dev_ctx.x_context(), clip_logits_data, - softmax->data(), logits_dims, axis); + r = xpu::softmax(dev_ctx.x_context(), clip_logits_data, + softmax_data, logits_dims, axis); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); // cross_entropy if (soft_label) { - r = xpu::soft_cross_entropy( - dev_ctx.x_context(), softmax->data(), labels->data(), - loss->data(), n, d); + auto labels_data = reinterpret_cast(labels->data()); + r = xpu::soft_cross_entropy(dev_ctx.x_context(), softmax_data, + labels_data, loss_data, n, d); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy"); } else { auto ignore_index = context.Attr("ignore_index"); @@ -92,10 +107,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { labels_int32.data(), labels->numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2"); - r = xpu::hard_cross_entropy( - dev_ctx.x_context(), softmax->data(), - labels_int32.data(), loss->data(), nullptr, n, d, - ignore_index); + r = xpu::hard_cross_entropy( + dev_ctx.x_context(), softmax_data, labels_int32.data(), + loss_data, nullptr, n, d, ignore_index); PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy"); } } @@ -167,8 +181,9 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(softmax_with_cross_entropy, - ops::SoftmaxWithCrossEntropyXPUKernel); +REGISTER_OP_XPU_KERNEL( + softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyXPUKernel, + ops::SoftmaxWithCrossEntropyXPUKernel); REGISTER_OP_XPU_KERNEL( softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyGradXPUKernel, diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 57d6c5e119..3feb33e4ac 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -321,7 +321,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"softmax_with_cross_entropy", - XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softplus_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, -- GitLab