From b421d7a5ed33ad6f8818f1fd70a6663728712bc5 Mon Sep 17 00:00:00 2001
From: QingshuChen
Date: Thu, 22 Dec 2022 10:29:29 +0800
Subject: [PATCH] fix softmax_with_cross_entropy bug for kunlun (#49207)

---
 paddle/phi/backends/xpu/xpu2_op_list.cc        | 10 ++++++++++
 paddle/phi/kernels/xpu/cross_entropy_kernel.cc | 13 +++++++++----
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index aa9714ee36..3a7a0f2fd6 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -603,7 +603,17 @@ XPUOpMap& get_kl2_ops() {
                      phi::DataType::BOOL,
                      phi::DataType::INT8,
                      phi::DataType::UINT8,
+                     phi::DataType::FLOAT16,
                      phi::DataType::FLOAT32})},
+      {"unsqueeze_with_xshape",
+       XPUKernelSet({phi::DataType::FLOAT64,
+                     phi::DataType::INT64,
+                     phi::DataType::INT32,
+                     phi::DataType::BOOL,
+                     phi::DataType::INT8,
+                     phi::DataType::UINT8,
+                     phi::DataType::FLOAT32,
+                     phi::DataType::FLOAT16})},
       {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
       {"where_index",
diff --git a/paddle/phi/kernels/xpu/cross_entropy_kernel.cc b/paddle/phi/kernels/xpu/cross_entropy_kernel.cc
index f054c6c445..f1b2257427 100644
--- a/paddle/phi/kernels/xpu/cross_entropy_kernel.cc
+++ b/paddle/phi/kernels/xpu/cross_entropy_kernel.cc
@@ -89,10 +89,15 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
                    static_cast<XPUType>(max_val));
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
 
-  r = xpu::softmax(
-      dev_ctx.x_context(), clip_logits_data, softmax_data, logits_dims, axis);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
-
+  if (use_softmax) {
+    r = xpu::softmax(
+        dev_ctx.x_context(), clip_logits_data, softmax_data, logits_dims, axis);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
+  } else {
+    r = xpu::copy(
+        dev_ctx.x_context(), clip_logits_data, softmax_data, softmax->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
+  }
   // cross_entropy
   if (axis != rank - 1) {
     XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
-- 
GitLab