未验证 提交 689bcad5 编写于 作者: R RuohengMa 提交者: GitHub

[XPU kernel] fix warpctc issue (#55950)

* [XPU kernel] fix warpctc issue

* fix issue

* temporal hack to circumvent depthwise_conv2d precision issue

* reset test case
上级 986ccbca
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -67,42 +67,113 @@ void WarpctcKernel(const Context& dev_ctx, ...@@ -67,42 +67,113 @@ void WarpctcKernel(const Context& dev_ctx,
"greater than zero " "greater than zero "
"but received %d. ", "but received %d. ",
sequence_width)); sequence_width));
PADDLE_ENFORCE_GE(blank,
0,
phi::errors::InvalidArgument("Input(blank) should be "
"equal or greater than zero "
"but received %d. ",
blank));
PADDLE_ENFORCE_LT(blank,
sequence_width,
phi::errors::InvalidArgument("Input(blank) should be "
"less than %d "
"but received %d. ",
sequence_width,
blank));
int lm_workspace = (max_sequence_length + 1) * auto logits_length_dtype = logits_length.get_ptr()->dtype();
(2 * max_target_seq_length + sequence_width + 1) * auto labels_length_dtype = labels_length.get_ptr()->dtype();
sizeof(T) + PADDLE_ENFORCE_EQ(
(7 * max_target_seq_length + 3) * sizeof(int); logits_length_dtype == labels_length_dtype,
PADDLE_ENFORCE_LE(lm_workspace, true,
6144, phi::errors::InvalidArgument("The data type of Input(logits_length) and "
"Input(labels_length) should be equal. "));
PADDLE_ENFORCE_EQ(logits_length_dtype == DataType::INT32 ||
logits_length_dtype == DataType::INT64,
true,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"Input size is too large for xpu in warpctc kernel")); "The data type of Input(logits_length) should be "
"either %s or %s, "
loss->Resize(phi::make_ddim({num_sequences, 1})); "but received %s. ",
dev_ctx.template Alloc<T>(loss); DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64),
DataTypeToString(logits_length_dtype)));
PADDLE_ENFORCE_EQ(labels_length_dtype == DataType::INT32 ||
labels_length_dtype == DataType::INT64,
true,
phi::errors::InvalidArgument(
"The data type of Input(labels_length) should be "
"either %s or %s, "
"but received %s. ",
DataTypeToString(DataType::INT32),
DataTypeToString(DataType::INT64),
DataTypeToString(labels_length_dtype)));
warpctcgrad->Resize( warpctcgrad->Resize(
phi::make_ddim({max_sequence_length, num_sequences, sequence_width})); phi::make_ddim({max_sequence_length, num_sequences, sequence_width}));
dev_ctx.template Alloc<T>(warpctcgrad); dev_ctx.template Alloc<T>(warpctcgrad);
T* warpctcgrad_data = warpctcgrad->data<T>();
const T* logits_data = logits.data<T>(); int sm_workspace, lm_workspace;
const int* label_data = label.data<int>(); int64_t max_S = 2 * max_target_seq_length + 1;
auto logits_length_data = logits_length.get_ptr()->data<int64_t>(); if (warpctcgrad_data == nullptr) {
auto labels_length_data = labels_length.get_ptr()->data<int64_t>(); sm_workspace = sizeof(T) * sequence_width +
sizeof(int) * max_target_seq_length + sizeof(int);
lm_workspace = 2 * max_S * sizeof(T) + 2 * max_S * sizeof(int);
} else {
sm_workspace = sizeof(T) * sequence_width +
sizeof(int) * max_target_seq_length + sizeof(int);
lm_workspace = 4 * max_S * sizeof(T) + 2 * max_S * sizeof(int) +
sequence_width * sizeof(T);
}
PADDLE_ENFORCE_LE(
sm_workspace + lm_workspace,
256 * 1024,
phi::errors::InvalidArgument(
"Input size should be equal or less than %d for xpu warpctc kernel, "
"but size %d is received. ",
256 * 1024,
sm_workspace + lm_workspace));
loss->Resize(phi::make_ddim({num_sequences, 1}));
dev_ctx.template Alloc<T>(loss);
T* loss_data = loss->data<T>(); T* loss_data = loss->data<T>();
T* warpctcgrad_data = warpctcgrad->data<T>();
int r = xpu::ctc_loss<T, int64_t>(dev_ctx.x_context(), const T* logits_data = logits.data<T>();
logits_data, const int* label_data = label.data<int>();
label_data, int r;
loss_data, if (logits_length_dtype == DataType::INT32) {
warpctcgrad_data, auto logits_length_data = logits_length.get_ptr()->data<int>();
logits_length_data, auto labels_length_data = labels_length.get_ptr()->data<int>();
labels_length_data, r = xpu::ctc_loss<T, int>(dev_ctx.x_context(),
max_sequence_length, logits_data,
num_sequences, label_data,
sequence_width, loss_data,
max_target_seq_length, warpctcgrad_data,
blank); logits_length_data,
labels_length_data,
max_sequence_length,
num_sequences,
sequence_width,
max_target_seq_length,
blank);
} else {
auto logits_length_data = logits_length.get_ptr()->data<int64_t>();
auto labels_length_data = labels_length.get_ptr()->data<int64_t>();
r = xpu::ctc_loss<T, int64_t>(dev_ctx.x_context(),
logits_data,
label_data,
loss_data,
warpctcgrad_data,
logits_length_data,
labels_length_data,
max_sequence_length,
num_sequences,
sequence_width,
max_target_seq_length,
blank);
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "ctc_loss"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "ctc_loss");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册