/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/sequence_padding.h" #include "paddle/operators/math/sequence_scale.h" #include "paddle/platform/dynload/warpctc.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template class WarpCTCFunctor { public: /* * \brief Compute the connectionist temporal classification loss, * and optionally compute the gradient with respect to the inputs. * * If gradient is nullptr, it only computes the ctc loss, * or computes both ctc loss and gradient. * * \param ctx execution context of this functor * \param input batch matrix of input probabilities, in * max_sequence_length x num_sequences x * sequence_width, (row-major) format * \param gradient batch matrix of gradient, with the same shape as * input. * \param cpu_labels labels always in CPU memory. * \param cpu_label_lengths length of all labels in CPU memory. * \param cpu_input_lengths length of all sequences in CPU memory. * \param sequence_width number of possible output symbols. * \param num_sequences number of sequence. * \param blank blank label used in ctc loss function. * \param cpu_losss cost of each sequence in CPU memory. */ void operator()(const framework::ExecutionContext& ctx, const float* input, float* gradient, const int* cpu_labels, const int* cpu_label_lengths, const int* cpu_input_lengths, const size_t sequence_width, const size_t num_sequences, const size_t blank, float* cpu_loss) { // Init warp-ctc options init(ctx, blank); // Compute the required workspace size. // There is no memory allocated operations within warp-ctc. size_t workspace_bytes = 0; ctcStatus_t status = platform::dynload::get_workspace_size( cpu_label_lengths, cpu_input_lengths, static_cast(sequence_width), static_cast(num_sequences), options_, &workspace_bytes); PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status, "warp-ctc [version %d] Error in get_workspace_size: ", warpctc_version_, platform::dynload::ctcGetStatusString(status)); PADDLE_ENFORCE_GT(workspace_bytes, 0UL, "Bytes of workspace got by warp-ctc function, " "get_workspace_size(), should be larger than 0."); Tensor workspace; size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL; float* workspace_data = workspace.mutable_data( framework::make_ddim({static_cast(workspace_elements)}), ctx.GetPlace()); math::SetConstant()( ctx.template device_context(), &workspace, static_cast(0)); // compute loss and gradient status = platform::dynload::compute_ctc_loss( input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths, static_cast(sequence_width), static_cast(num_sequences), cpu_loss, workspace_data, options_); PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status, "warp-ctc [version %d] Error in compute_ctc_loss: ", warpctc_version_, platform::dynload::ctcGetStatusString(status)); } protected: void init(const framework::ExecutionContext& ctx, const size_t blank) { warpctc_version_ = platform::dynload::get_warpctc_version(); if (platform::is_gpu_place(ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA options_.loc = CTC_GPU; options_.stream = reinterpret_cast( ctx.device_context()) .stream(); #else PADDLE_THROW("[warpctc init] GPU is not enabled."); #endif } else { options_.loc = CTC_CPU; options_.num_threads = 1; } options_.blank_label = blank; } private: int warpctc_version_; ctcOptions options_; }; template class WarpCTCKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* logits = ctx.Input("Logits"); auto* label = ctx.Input("Label"); auto* warpctc_grad = ctx.Output("WarpCTCGrad"); auto* loss = ctx.Output("Loss"); const size_t level = 0; auto logits_lod = framework::ToAbsOffset(logits->lod()); auto logits_dims = logits->dims(); PADDLE_ENFORCE_EQ(logits_dims[0], static_cast(logits_lod[level].back()), "The first dimension of Input(Logits) should be equal to " "the sum of all sequences' lengths."); auto label_lod = framework::ToAbsOffset(label->lod()); auto label_dims = label->dims(); PADDLE_ENFORCE_EQ( label_dims[0], label->numel(), "The width of each timestep in Input(Label) should be 1."); const size_t num_sequences = logits_lod[level].size() - 1; PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1, "The number of sequences of Input(Logits) should be " "equal to that of Input(Label)."); const size_t sequence_width = logits->numel() / logits_dims[0]; auto loss_dims = framework::make_ddim({static_cast(num_sequences), 1}); // warpctc needs sequences data stored in transposed padding format Tensor warpctc_logits; const size_t max_sequence_length = math::MaximumSequenceLength(logits_lod, level); auto warpctc_logits_dims = framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, warpctc_logits, false); const T* warpctc_logits_data = warpctc_logits.data(); std::vector warpctc_label_lengths(num_sequences); std::vector warpctc_logits_lengths(num_sequences); for (size_t i = 0; i < num_sequences; ++i) { warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i]; warpctc_logits_lengths[i] = logits_lod[level][i + 1] - logits_lod[level][i]; } // warpctc computes loss and gradient in one call, gradient data also stored // in batch format T* warpctc_grad_data = warpctc_grad->mutable_data(warpctc_logits.dims(), ctx.GetPlace()); // warpctc accesses labels in CPU memory Tensor warpctc_label; Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label); const int* warpctc_label_data = warpctc_label.data(); // warpctc stores loss in CPU memory Tensor warpctc_loss; T* warpctc_loss_data = warpctc_loss.mutable_data(loss_dims, platform::CPUPlace()); const size_t blank = static_cast(ctx.Attr("blank")); WarpCTCFunctor()( ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data, warpctc_label_lengths.data(), warpctc_logits_lengths.data(), sequence_width, num_sequences, blank, warpctc_loss_data); // Copy the loss back Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss); } }; template class WarpCTCGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* warpctc_grad = ctx.Input("WarpCTCGrad"); auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); const Tensor* loss_grad = ctx.Input(framework::GradVarName("Loss")); // LOG(ERROR) << "loss_grad_dims: " << loss_grad_dims; // for (int i=0; inumel();i++) { // LOG(ERROR) << "loss_grad: " << loss_grad_data[i]; //} // T* logits_grad_data = logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); math::UnpaddingLoDTensorFunctor()( ctx.template device_context(), *logits_grad, *warpctc_grad, norm_by_times); const T* loss_grad_data = loss_grad->data(); const size_t num_seq = loss_grad->dims()[0]; math::ScaleLoDTensorFunctor()( ctx.template device_context(), *logits_grad, loss_grad_data, num_seq); /* int level = 0; auto logits_grad_lod = framework::ToAbsOffset(logits_grad->lod()); const size_t num_sequences = logits_grad_lod[level].size() - 1; for (int seq_index = 0; seq_index < num_sequences; ++seq_index) { for (int token_index = logits_grad_lod[level][seq_index]; token_index < logits_grad_lod[level][seq_index + 1]; ++token_index) { logits_grad_data[token_index] *= loss_grad_data[seq_index]; } } */ } }; } // namespace operators } // namespace paddle