未验证 提交 1d549400 编写于 作者: L Lucas 提交者: GitHub

[Bug Fixes] fix bugs when using cast<int64_t, int32_t> in xpu/cross_entropy...

[Bug Fixes] fix bugs when using cast<int64_t, int32_t> in xpu/cross_entropy kernels, *test=kunlun (#53325)
上级 3ec12c2b
...@@ -54,21 +54,32 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, ...@@ -54,21 +54,32 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
d); d);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
} else { } else {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); const int* labels_int_ptr = nullptr;
int* labels_int_ptr_l3 = if (labels.dtype() == DataType::INT32) {
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel()); labels_int_ptr = labels.data<int32_t>();
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3); } else if (labels.dtype() == DataType::INT64) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(), int* labels_int_ptr_l3 =
labels.data<int64_t>(), RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
labels_int_ptr_l3, PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
labels_int_ptr = labels_int_ptr_l3;
} else {
// TODO(lilujia): other data types should be handled
errors::Unimplemented(
("cross_entropy does not support data types other than int32 and "
"int64"));
}
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>( r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(loss_grad.data<T>()), reinterpret_cast<const XPUType*>(loss_grad.data<T>()),
labels_int_ptr_l3, labels_int_ptr,
reinterpret_cast<const XPUType*>(softmax.data<T>()), reinterpret_cast<const XPUType*>(softmax.data<T>()),
reinterpret_cast<XPUType*>(logit_grad->data<T>()), reinterpret_cast<XPUType*>(logit_grad->data<T>()),
ignore_index, ignore_index,
...@@ -113,19 +124,31 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, ...@@ -113,19 +124,31 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
t); t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
} else { } else {
int* labels_int_ptr_l3 = const int* labels_int_ptr = nullptr;
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel()); if (labels.dtype() == DataType::INT32) {
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3); labels_int_ptr = labels.data<int32_t>();
} else if (labels.dtype() == DataType::INT64) {
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(), int* labels_int_ptr_l3 =
labels.data<int64_t>(), RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
labels_int_ptr_l3, PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
labels_int_ptr = labels_int_ptr_l3;
} else {
// TODO(lilujia): other data types should be handled
errors::Unimplemented(
("cross_entropy does not support data types other than int32 and "
"int64"));
}
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>( r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(loss_grad.data<T>()), reinterpret_cast<const XPUType*>(loss_grad.data<T>()),
labels_int_ptr_l3, labels_int_ptr,
trans_softmax, trans_softmax,
trans_logit, trans_logit,
ignore_index, ignore_index,
......
...@@ -133,20 +133,32 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx, ...@@ -133,20 +133,32 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx,
axis == rank - 1 ? d : t); axis == rank - 1 ? d : t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
} else { } else {
DenseTensor labels_int32; const int* labels_int_ptr = nullptr;
int* labels_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel()); if (labels.dtype() == DataType::INT32) {
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3); labels_int_ptr = labels.data<int32_t>();
} else if (labels.dtype() == DataType::INT64) {
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(), xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
labels.data<int64_t>(), int* labels_int_ptr_l3 =
labels_int_ptr_l3, RAII_GUARD.alloc_l3_or_gm<int32_t>(labels.numel());
labels.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
r = xpu::cast<int64_t, int32_t>(dev_ctx.x_context(),
labels.data<int64_t>(),
labels_int_ptr_l3,
labels.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
labels_int_ptr = labels_int_ptr_l3;
} else {
// TODO(lilujia): other data types should be handled
errors::Unimplemented(
("cross_entropy does not support data types other than int32 and "
"int64"));
}
r = xpu::hard_cross_entropy<XPUType, int32_t>( r = xpu::hard_cross_entropy<XPUType, int32_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
softmax_data, softmax_data,
labels_int_ptr_l3, labels_int_ptr,
loss_data, loss_data,
nullptr, nullptr,
axis == rank - 1 ? n : n * d / t, axis == rank - 1 ? n : n * d / t,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册