Unverified commit b421d7a5, authored by QingshuChen, committed by GitHub

fix softmax_with_cross_entropy bug for kunlun (#49207)

Parent d4305a26
@@ -603,7 +603,17 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"unsqueeze_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
{"where_index",
......
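For context, this first hunk only extends the KL2 op map: the preceding entry gains FLOAT16, and a new `unsqueeze_with_xshape` entry is registered with its supported data types. Below is a minimal sketch of how such a name-to-dtype-set registry can be queried, using illustrative standard C++ containers (`DataType`, `KernelSet`, `kl2_ops` here are stand-ins, not Paddle's `XPUOpMap`/`XPUKernelSet` types):

```cpp
#include <iostream>
#include <map>
#include <set>
#include <string>

// Illustrative stand-ins for phi::DataType and XPUKernelSet; the real
// registration lives in get_kl2_ops() shown in the diff above.
enum class DataType { FLOAT16, FLOAT32, FLOAT64, INT32, INT64, BOOL, INT8, UINT8 };
using KernelSet = std::set<DataType>;

int main() {
  // A tiny excerpt of the registry, mirroring the entries in the hunk.
  std::map<std::string, KernelSet> kl2_ops = {
      {"unsqueeze_with_xshape",
       {DataType::FLOAT64, DataType::INT64, DataType::INT32, DataType::BOOL,
        DataType::INT8, DataType::UINT8, DataType::FLOAT32, DataType::FLOAT16}},
      {"warpctc", {DataType::FLOAT32}},
  };

  // Query whether an op supports a given dtype on the KL2 (Kunlun) device.
  bool fp16_ok =
      kl2_ops.at("unsqueeze_with_xshape").count(DataType::FLOAT16) > 0;
  std::cout << "unsqueeze_with_xshape supports FLOAT16: " << fp16_ok << "\n";
  return 0;
}
```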
@@ -89,10 +89,15 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
static_cast<XPUType>(max_val));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
if (use_softmax) {
r = xpu::softmax<XPUType>(
dev_ctx.x_context(), clip_logits_data, softmax_data, logits_dims, axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
} else {
r = xpu::copy<XPUType>(
dev_ctx.x_context(), clip_logits_data, softmax_data, softmax->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
// cross_entropy
if (axis != rank - 1) {
XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
......
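The core of the kernel change is the new `use_softmax` branch: when the caller indicates that the logits are already normalized probabilities (`use_softmax == false`), the kernel now copies the clipped input into the softmax buffer via `xpu::copy` instead of unconditionally running `xpu::softmax`. The following is a minimal host-side sketch of that control flow in standard C++; the function name `SoftmaxOrCopy`, the `rows`/`cols` layout, and the use of `std::vector` are illustrative assumptions, not the Paddle or XPU API:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of the branch added in the diff: if use_softmax is false, the
// incoming values are treated as probabilities and passed through unchanged
// (the xpu::copy path); otherwise a numerically stable row-wise softmax is
// applied over the last axis (the xpu::softmax path).
void SoftmaxOrCopy(const std::vector<float>& clipped_logits,
                   int rows, int cols, bool use_softmax,
                   std::vector<float>* softmax_out) {
  softmax_out->resize(static_cast<size_t>(rows) * cols);
  if (!use_softmax) {
    // Mirror of the xpu::copy branch: no renormalization.
    std::copy(clipped_logits.begin(), clipped_logits.end(),
              softmax_out->begin());
    return;
  }
  for (int i = 0; i < rows; ++i) {
    const float* row = clipped_logits.data() + static_cast<size_t>(i) * cols;
    float* out = softmax_out->data() + static_cast<size_t>(i) * cols;
    float max_val = *std::max_element(row, row + cols);
    float sum = 0.f;
    for (int j = 0; j < cols; ++j) {
      out[j] = std::exp(row[j] - max_val);  // subtract max for stability
      sum += out[j];
    }
    for (int j = 0; j < cols; ++j) out[j] /= sum;
  }
}
```

Without this branch, passing pre-softmaxed inputs would have them normalized a second time before the cross-entropy step, which is the bug the commit title refers to.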