未验证 提交 1d549400 编写于 作者: L Lucas 提交者: GitHub

[Bug Fixs] fix bugs when using cast<int64_t, int32_t> in xpu/cross_entropy...

[Bug Fixs] fix bugs when using cast<int64_t, int32_t> in xpu/cross_entropy kernels, *test=kunlun (#53325)
上级 3ec12c2b
......@@ -54,21 +54,32 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
d);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
} else {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
const int* labels_int_ptr = nullptr;
if (labels.dtype() == DataType::INT32) {
labels_int_ptr = labels.data<int32_t>();
} else if (labels.dtype() == DataType::INT64) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
labels_int_ptr = labels_int_ptr_l3;
} else {
// TODO(lilujia): other data types should be handled
errors::Unimplemented(
("cross_entropy does not support data types other than int32 and "
"int64"));
}
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(loss_grad.data<T>()),
labels_int_ptr_l3,
labels_int_ptr,
reinterpret_cast<const XPUType*>(softmax.data<T>()),
reinterpret_cast<XPUType*>(logit_grad->data<T>()),
ignore_index,
......@@ -113,19 +124,31 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
} else {
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
const int* labels_int_ptr = nullptr;
if (labels.dtype() == DataType::INT32) {
labels_int_ptr = labels.data<int32_t>();
} else if (labels.dtype() == DataType::INT64) {
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
labels_int_ptr = labels_int_ptr_l3;
} else {
// TODO(lilujia): other data types should be handled
errors::Unimplemented(
("cross_entropy does not support data types other than int32 and "
"int64"));
}
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(loss_grad.data<T>()),
labels_int_ptr_l3,
labels_int_ptr,
trans_softmax,
trans_logit,
ignore_index,
......
......@@ -133,20 +133,32 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
axis == rank - 1 ? d : t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
} else {
DenseTensor labels_int32;
int* labels_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
const int* labels_int_ptr = nullptr;
if (labels.dtype() == DataType::INT32) {
labels_int_ptr = labels.data<int32_t>();
} else if (labels.dtype() == DataType::INT64) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
labels_int_ptr = labels_int_ptr_l3;
} else {
// TODO(lilujia): other data types should be handled
errors::Unimplemented(
("cross_entropy does not support data types other than int32 and "
"int64"));
}
r = xpu::hard_cross_entropy<XPUType, int32_t>(
dev_ctx.x_context(),
softmax_data,
labels_int_ptr_l3,
labels_int_ptr,
loss_data,
nullptr,
axis == rank - 1 ? n : n * d / t,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册