// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/warpctc_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { template void WarpctcKernel(const Context& dev_ctx, const DenseTensor& logits, const DenseTensor& label, const paddle::optional& logits_length, const paddle::optional& labels_length, int blank, bool norm_by_times, DenseTensor* loss, DenseTensor* warpctcgrad) { bool has_logits_length = logits_length.is_initialized(); if (!has_logits_length) { PADDLE_THROW( phi::errors::External("XPU only support logits_length is_initialized")); } bool has_labels_length = labels_length.is_initialized(); if (!has_labels_length) { PADDLE_THROW( phi::errors::External("XPU only support labels_length is_initialized")); } int max_sequence_length = logits.dims()[0]; int num_sequences = logits.dims()[1]; int sequence_width = logits.dims()[2]; int max_target_seq_length = label.dims()[1]; PADDLE_ENFORCE_GT(max_sequence_length, 0, phi::errors::InvalidArgument( "The first dimension of Input(Logits) should be " "greater than zero " "but received %d. ", max_sequence_length)); PADDLE_ENFORCE_GT(num_sequences, 0, phi::errors::InvalidArgument( "The second dimension of Input(Logits) should be " "greater than zero " "but received %d. ", num_sequences)); PADDLE_ENFORCE_GT(sequence_width, 0, phi::errors::InvalidArgument( "The third dimension of Input(Logits) should be " "greater than zero " "but received %d. ", sequence_width)); int lm_workspace = (max_sequence_length + 1) * (2 * max_target_seq_length + sequence_width + 1) * sizeof(T) + (7 * max_target_seq_length + 3) * sizeof(int); PADDLE_ENFORCE_LE(lm_workspace, 6144, phi::errors::InvalidArgument( "Input size is too large for xpu in warpctc kernel")); loss->Resize(phi::make_ddim({num_sequences, 1})); dev_ctx.template Alloc(loss); warpctcgrad->Resize( phi::make_ddim({max_sequence_length, num_sequences, sequence_width})); dev_ctx.template Alloc(warpctcgrad); const T* logits_data = logits.data(); const int* label_data = label.data(); auto logits_length_data = logits_length.get_ptr()->data(); auto labels_length_data = labels_length.get_ptr()->data(); T* loss_data = loss->data(); T* warpctcgrad_data = warpctcgrad->data(); int r = xpu::ctc_loss(dev_ctx.x_context(), logits_data, label_data, loss_data, warpctcgrad_data, logits_length_data, labels_length_data, max_sequence_length, num_sequences, sequence_width, max_target_seq_length, blank); PADDLE_ENFORCE_XDNN_SUCCESS(r, "ctc_loss"); } } // namespace phi PD_REGISTER_KERNEL(warpctc, XPU, ALL_LAYOUT, phi::WarpctcKernel, float) {}