未验证 提交 b421d7a5 编写于 作者: QingshuChen 提交者: GitHub

fix softmax_with_cross_entropy bug for kunlun (#49207)

上级 d4305a26
@@ -603,7 +603,17 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL, phi::DataType::BOOL,
phi::DataType::INT8, phi::DataType::INT8,
phi::DataType::UINT8, phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"unsqueeze_with_xshape",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"warpctc", XPUKernelSet({phi::DataType::FLOAT32})}, {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
{"where_index", {"where_index",
......
@@ -89,10 +89,15 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
static_cast<XPUType>(max_val)); static_cast<XPUType>(max_val));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
if (use_softmax) {
r = xpu::softmax<XPUType>( r = xpu::softmax<XPUType>(
dev_ctx.x_context(), clip_logits_data, softmax_data, logits_dims, axis); dev_ctx.x_context(), clip_logits_data, softmax_data, logits_dims, axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
} else {
r = xpu::copy<XPUType>(
dev_ctx.x_context(), clip_logits_data, softmax_data, softmax->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
}
// cross_entropy // cross_entropy
if (axis != rank - 1) { if (axis != rank - 1) {
XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d); XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册