未验证 提交 689bcad5 编写于 作者: R RuohengMa 提交者: GitHub

[XPU kernel] fix warpctc issue (#55950)

* [XPU kernel] fix warpctc issue

* fix issue

* temporary hack to circumvent depthwise_conv2d precision issue

* reset test case
上级 986ccbca
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -67,42 +67,113 @@ void WarpctcKernel(const Context& dev_ctx,
"greater than zero "
"but received %d. ",
sequence_width));
PADDLE_ENFORCE_GE(blank,
0,
phi::errors::InvalidArgument("Input(blank) should be "
"equal or greater than zero "
"but received %d. ",
blank));
PADDLE_ENFORCE_LT(blank,
sequence_width,
phi::errors::InvalidArgument("Input(blank) should be "
"less than %d "
"but received %d. ",
sequence_width,
blank));
int lm_workspace = (max_sequence_length + 1) *
(2 * max_target_seq_length + sequence_width + 1) *
sizeof(T) +
(7 * max_target_seq_length + 3) * sizeof(int);
PADDLE_ENFORCE_LE(lm_workspace,
6144,
auto logits_length_dtype = logits_length.get_ptr()->dtype();
auto labels_length_dtype = labels_length.get_ptr()->dtype();
PADDLE_ENFORCE_EQ(
logits_length_dtype == labels_length_dtype,
true,
phi::errors::InvalidArgument("The data type of Input(logits_length) and "
"Input(labels_length) should be equal. "));
PADDLE_ENFORCE_EQ(logits_length_dtype == DataType::INT32 ||
logits_length_dtype == DataType::INT64,
true,
phi::errors::InvalidArgument(
"Input size is too large for xpu in warpctc kernel"));
loss->Resize(phi::make_ddim({num_sequences, 1}));
dev_ctx.template Alloc<T>(loss);
"The data type of Input(logits_length) should be "
"either %s or %s, "
"but received %s. ",
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64),
DataTypeToString(logits_length_dtype)));
PADDLE_ENFORCE_EQ(labels_length_dtype == DataType::INT32 ||
labels_length_dtype == DataType::INT64,
true,
phi::errors::InvalidArgument(
"The data type of Input(labels_length) should be "
"either %s or %s, "
"but received %s. ",
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64),
DataTypeToString(labels_length_dtype)));
warpctcgrad->Resize(
phi::make_ddim({max_sequence_length, num_sequences, sequence_width}));
dev_ctx.template Alloc<T>(warpctcgrad);
T* warpctcgrad_data = warpctcgrad->data<T>();
const T* logits_data = logits.data<T>();
const int* label_data = label.data<int>();
auto logits_length_data = logits_length.get_ptr()->data<int64_t>();
auto labels_length_data = labels_length.get_ptr()->data<int64_t>();
int sm_workspace, lm_workspace;
int64_t max_S = 2 * max_target_seq_length + 1;
if (warpctcgrad_data == nullptr) {
sm_workspace = sizeof(T) * sequence_width +
sizeof(int) * max_target_seq_length + sizeof(int);
lm_workspace = 2 * max_S * sizeof(T) + 2 * max_S * sizeof(int);
} else {
sm_workspace = sizeof(T) * sequence_width +
sizeof(int) * max_target_seq_length + sizeof(int);
lm_workspace = 4 * max_S * sizeof(T) + 2 * max_S * sizeof(int) +
sequence_width * sizeof(T);
}
PADDLE_ENFORCE_LE(
sm_workspace + lm_workspace,
256 * 1024,
phi::errors::InvalidArgument(
"Input size should be equal or less than %d for xpu warpctc kernel, "
"but size %d is received. ",
256 * 1024,
sm_workspace + lm_workspace));
loss->Resize(phi::make_ddim({num_sequences, 1}));
dev_ctx.template Alloc<T>(loss);
T* loss_data = loss->data<T>();
T* warpctcgrad_data = warpctcgrad->data<T>();
int r = xpu::ctc_loss<T, int64_t>(dev_ctx.x_context(),
logits_data,
label_data,
loss_data,
warpctcgrad_data,
logits_length_data,
labels_length_data,
max_sequence_length,
num_sequences,
sequence_width,
max_target_seq_length,
blank);
const T* logits_data = logits.data<T>();
const int* label_data = label.data<int>();
int r;
if (logits_length_dtype == DataType::INT32) {
auto logits_length_data = logits_length.get_ptr()->data<int>();
auto labels_length_data = labels_length.get_ptr()->data<int>();
r = xpu::ctc_loss<T, int>(dev_ctx.x_context(),
logits_data,
label_data,
loss_data,
warpctcgrad_data,
logits_length_data,
labels_length_data,
max_sequence_length,
num_sequences,
sequence_width,
max_target_seq_length,
blank);
} else {
auto logits_length_data = logits_length.get_ptr()->data<int64_t>();
auto labels_length_data = labels_length.get_ptr()->data<int64_t>();
r = xpu::ctc_loss<T, int64_t>(dev_ctx.x_context(),
logits_data,
label_data,
loss_data,
warpctcgrad_data,
logits_length_data,
labels_length_data,
max_sequence_length,
num_sequences,
sequence_width,
max_target_seq_length,
blank);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "ctc_loss");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册