diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e77be832c0cc8975c3fc2ebb7fad577cdfe919f5..3901226216f4d2cd05a140e1fa2c841f9e396af8 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -111,10 +111,10 @@ op_library(load_combine_op DEPS string_array) if (WITH_GPU OR WITH_ROCM) if(WITH_ROCM) - op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu) + op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc) # warpctc_op needs cudnn 7 above elseif(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) - op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu) + op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc) else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index c00eb54a2f8599bd8d492edf2f0c2f46a8ddc51c..35ba8c1d118a82ac63b9db91c7e289bb75c80722 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/sequence_padding.h" +#include "paddle/phi/backends/cpu/cpu_context.h" namespace phi { class DenseTensor; @@ -136,6 +137,52 @@ class PaddingLoDTensorFunctor { } }; +template +class PaddingLoDTensorFunctor { + public: + void operator()(const phi::CPUContext& context, + const framework::LoDTensor& seq_tensor, + framework::LoDTensor* pad_tensor, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_lod = seq_tensor.lod(); + const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + + PADDLE_ENFORCE_EQ( + pad_value.numel() == 1 || pad_value.numel() == step_width, true, + platform::errors::InvalidArgument( + "The numel of 'pad_value' can only be 1 or be equal to the " + "'step_width', but got %ld != 1 and %ld. Please check the input " + "value.", + pad_value.numel(), step_width)); + + // fill padding value + T* pad_data = pad_tensor->data(); + const T* pad_value_data = pad_value.data(); + if (pad_value.numel() == 1) { + fast_mem_init(pad_data, pad_tensor->numel(), pad_value_data, + sizeof(T)); + } else { + for (int i = 0; i < pad_tensor->numel(); i += step_width) { + memcpy(pad_data + i, pad_value_data, step_width * sizeof(T)); + } + } + + CopyValidData(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kSeqToPad, layout); + } +}; + template class UnpaddingLoDTensorFunctor { public: @@ -160,6 +207,30 @@ class UnpaddingLoDTensorFunctor { } }; +template +class UnpaddingLoDTensorFunctor { + public: + void operator()(const phi::CPUContext& context, + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); + if (pad_seq_len == -1) { + pad_seq_len = MaximumSequenceLength(seq_offsets); + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + + CopyValidData(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len, + step_width, norm_by_times, kPadToSeq, layout); + } +}; + template class PaddingLoDTensorFunctor; template class PaddingLoDTensorFunctor; template class PaddingLoDTensorFunctor; @@ -170,6 +241,16 @@ template class UnpaddingLoDTensorFunctor; template class UnpaddingLoDTensorFunctor; template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 01fd2d403c4564ba022e3ab9633fa04d998dd662..9aca6ad0f5a2f4d7d2dec23736bf71b6b6667ac8 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/sequence_padding.h" +#include "paddle/phi/backends/gpu/gpu_context.h" namespace paddle { namespace operators { @@ -112,6 +113,69 @@ class PaddingLoDTensorFunctor { } }; +template +class PaddingLoDTensorFunctor { + public: + void operator()(const phi::GPUContext& context, + const framework::LoDTensor& seq_tensor, + framework::LoDTensor* pad_tensor, + const framework::LoDTensor& pad_value, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_lod = seq_tensor.lod(); + auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level]; + const auto& seq_tensor_dims = seq_tensor.dims(); + const auto& pad_tensor_dims = pad_tensor->dims(); + int max_seq_len = MaximumSequenceLength(seq_offsets); + if (pad_seq_len == -1) { + pad_seq_len = max_seq_len; + } + PADDLE_ENFORCE_GE( + pad_seq_len, max_seq_len, + platform::errors::InvalidArgument( + "The pad_seq_len must be equal to or greater than the " + "original max sequence length. Expected %ld >= %ld, but got %ld < " + "%ld. Please check the input value.", + pad_seq_len, max_seq_len, pad_seq_len, max_seq_len)); + int step_width = seq_tensor.numel() / seq_tensor_dims[0]; + int seq_num = seq_offsets.size() - 1; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + PADDLE_ENFORCE_EQ( + pad_value.numel() == 1 || pad_value.numel() == step_width, true, + platform::errors::InvalidArgument( + "The numel of 'pad_value' can only be 1 or be equal to " + "the 'step_width', but got %ld != 1 and %ld. Please check the " + "input value.", + pad_value.numel(), step_width)); + + const int kBlockSize = 512; + + /* At least use 32 threads to copy sequence_width elements, + * and at least 8 elements for each thread. + */ + size_t block_dim_x = + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + size_t block_dim_y = kBlockSize / block_dim_x; + dim3 threads(block_dim_x, block_dim_y); + + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; + dim3 grid(grid_dim_x, grid_dim_y); + + const T* seq_data = seq_tensor.data(); + T* pad_data = pad_tensor->data(); + const T* pad_value_data = pad_value.data(); + + paddle::framework::MixVector mix_vector_seq_offsets(&seq_offsets); + SequencePaddingKernel<<>>( + pad_data, seq_data, pad_value_data, pad_value.numel() == 1, + mix_vector_seq_offsets.CUDAData(context.GetPlace()), seq_num, + pad_seq_len, step_width, norm_by_times, layout); + } +}; + template class UnpaddingLoDTensorFunctor { public: @@ -166,6 +230,60 @@ class UnpaddingLoDTensorFunctor { } }; +template +class UnpaddingLoDTensorFunctor { + public: + void operator()(const phi::GPUContext& context, + const framework::LoDTensor& pad_tensor, + framework::LoDTensor* seq_tensor, int pad_seq_len = -1, + int lod_level = 0, bool norm_by_times = false, + const PadLayout layout = kBatchLengthWidth) { + auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level]; + const auto& seq_tensor_dims = seq_tensor->dims(); + const auto& pad_tensor_dims = pad_tensor.dims(); + int max_seq_len = MaximumSequenceLength(seq_offsets); + if (pad_seq_len == -1) { + pad_seq_len = max_seq_len; + } + int step_width = seq_tensor->numel() / seq_tensor_dims[0]; + int seq_num = seq_offsets.size() - 1; + + CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len, + step_width, layout); + /* + if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) { + paddle::framework::TensorCopy(pad_tensor, context.GetPlace(), context, + seq_tensor); + seq_tensor->Resize(seq_tensor_dims); + return; + } + */ + + const int kBlockSize = 512; + + /* At least use 32 threads to copy sequence_width elements, + * and at least 8 elements for each thread. + */ + size_t block_dim_x = + std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize); + size_t block_dim_y = kBlockSize / block_dim_x; + dim3 threads(block_dim_x, block_dim_y); + + size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y; + size_t grid_dim_y = seq_num; + dim3 grid(grid_dim_x, grid_dim_y); + + const T* pad_data = pad_tensor.data(); + T* seq_data = seq_tensor->data(); + + paddle::framework::MixVector mixv_seq_offsets(&seq_offsets); + SequencePaddingKernel<<>>( + seq_data, pad_data, nullptr, false, + mixv_seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len, + step_width, norm_by_times, layout); + } +}; + template class PaddingLoDTensorFunctor; template class PaddingLoDTensorFunctor; template class PaddingLoDTensorFunctor; @@ -176,6 +294,16 @@ template class UnpaddingLoDTensorFunctor; template class UnpaddingLoDTensorFunctor; template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/sequence_scale.cc b/paddle/fluid/operators/math/sequence_scale.cc index a27792bb7802836dc0a30a75b102702857e6b602..bc8832a1bbc56b57e111496b8dc5364e1cedf218 100644 --- a/paddle/fluid/operators/math/sequence_scale.cc +++ b/paddle/fluid/operators/math/sequence_scale.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/sequence_scale.h" +#include "paddle/phi/backends/cpu/cpu_context.h" namespace phi { class DenseTensor; @@ -43,9 +44,33 @@ class ScaleLoDTensorFunctor { } }; +template +class ScaleLoDTensorFunctor { + public: + void operator()(const phi::CPUContext& context, const T* scales, + framework::LoDTensor* seq) { + const size_t level = 0; + auto lod = seq->lod(); + const size_t num_seq = lod[level].size() - 1; + size_t seq_width = seq->dims()[1]; + framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); + + T* seq_data = seq->mutable_data(context.GetPlace()); + for (size_t i = 0; i < num_seq; ++i) { + for (size_t j = lod[level][i] * seq_width; + j < lod[level][i + 1] * seq_width; ++j) { + seq_data[j] *= scales[i]; + } + } + } +}; + template class ScaleLoDTensorFunctor; template class ScaleLoDTensorFunctor; +template class ScaleLoDTensorFunctor; +template class ScaleLoDTensorFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/sequence_scale.cu b/paddle/fluid/operators/math/sequence_scale.cu index 8e02d1b70ff83b3641d498567a236ffcb41bb988..253a67c2c8cbe5788471f52e233bc9256f973353 100644 --- a/paddle/fluid/operators/math/sequence_scale.cu +++ b/paddle/fluid/operators/math/sequence_scale.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/sequence_scale.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" namespace paddle { namespace operators { @@ -61,9 +62,41 @@ class ScaleLoDTensorFunctor { } }; +template +class ScaleLoDTensorFunctor { + public: + void operator()(const phi::GPUContext& context, const T* scales, + framework::LoDTensor* seq) { + const size_t level = 0; + auto lod = seq->lod(); + const size_t num_seq = lod[level].size() - 1; + const size_t seq_width = seq->numel() / seq->dims()[0]; + auto abs_offset_lod = framework::ToAbsOffset(lod); + T* seq_data = seq->mutable_data(context.GetPlace()); + paddle::framework::MixVector mix_vector(&(abs_offset_lod[level])); + +#ifdef PADDLE_WITH_HIP + hipLaunchKernelGGL( + HIP_KERNEL_NAME(SequenceScaleKernel), + dim3(num_seq), dim3(PADDLE_CUDA_NUM_THREADS), 0, context.stream(), + seq_data, mix_vector.CUDAMutableData(context.GetPlace()), scales, + seq_width); +#else + SequenceScaleKernel<<< + num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>( + seq_data, mix_vector.CUDAMutableData(context.GetPlace()), scales, + seq_width); +#endif + mix_vector.CopyToCPU(); + } +}; + template class ScaleLoDTensorFunctor; template class ScaleLoDTensorFunctor; +template class ScaleLoDTensorFunctor; +template class ScaleLoDTensorFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index 9bb6ccd03df4bbfed372564917e1afc5836442f1..5cd9feee82895d032b91333d27e65401b6edf0ad 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/warpctc_op.h" - #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -25,40 +27,6 @@ class WarpCTCOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Logits"), "Input", "Logits", "WarpCTC"); - OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "WarpCTC"); - OP_INOUT_CHECK(ctx->HasOutput("WarpCTCGrad"), "Output", "WarpCTCGrad", - "WarpCTC"); - OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "WarpCTC"); - - auto logits_dims = ctx->GetInputDim("Logits"); - int blank = ctx->Attrs().Get("blank"); - int sequence_width = 0; - - if (ctx->HasInput("LogitsLength")) { - sequence_width = logits_dims[2]; - } else { - sequence_width = - static_cast(phi::product(logits_dims) / logits_dims[0]); - } - - PADDLE_ENFORCE_GE( - blank, 0, platform::errors::InvalidArgument( - "The value of Attr(blank) should be in interval [0, %d), " - "but received %d", - blank)); - PADDLE_ENFORCE_LT( - blank, sequence_width, - platform::errors::InvalidArgument( - "The value of Attr(blank) should be in interval [0, %d), " - "but received %d", - blank)); - - // TODO(liuyiqun): it is tricky to set the wrong dimension here. - ctx->SetOutputDim("Loss", {-1, 1}); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -189,15 +157,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(WarpCTCGradOpNoNeedBufferVarInferer, } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(warpctc, WarpctcInferShapeFunctor, + PD_INFER_META(phi::WarpctcInferMeta)); REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, ops::WarpCTCGradOpMaker, - ops::WarpCTCGradOpMaker); + ops::WarpCTCGradOpMaker, + WarpctcInferShapeFunctor); REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp, ops::WarpCTCGradOpNoNeedBufferVarInferer); -REGISTER_OP_CPU_KERNEL( - warpctc, ops::WarpCTCKernel, - ops::WarpCTCKernel); -REGISTER_OP_CPU_KERNEL( - warpctc_grad, - ops::WarpCTCGradKernel, - ops::WarpCTCGradKernel); diff --git a/paddle/fluid/operators/warpctc_op.cu b/paddle/fluid/operators/warpctc_op.cu deleted file mode 100644 index fd820805e4d08ad34bf4fceaa2ac586a64fd677c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/warpctc_op.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/warpctc_op.h" - -namespace ops = paddle::operators; -// register forward and backward of CUDA OP must in same *.cu file. -// Eigen can be used on GPU device, but must be in *.cu file not *.cu.cc file. -// *.cu.cc also using GCC compiler. *.cu using NVCC compiler -REGISTER_OP_CUDA_KERNEL( - warpctc, ops::WarpCTCKernel, - ops::WarpCTCKernel); -REGISTER_OP_CUDA_KERNEL( - warpctc_grad, - ops::WarpCTCGradKernel, - ops::WarpCTCGradKernel); \ No newline at end of file diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h deleted file mode 100644 index a6199f090095755f2d1f96099ddffc684c83ad4b..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/warpctc_op.h +++ /dev/null @@ -1,448 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/sequence_padding.h" -#include "paddle/fluid/operators/math/sequence_scale.h" -#include "paddle/fluid/platform/dynload/warpctc.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; - -template -class ComputeCtcLossFunctor { - public: - ctcStatus_t operator()(const T* const activations, T* gradients, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, int alphabet_size, - int minibatch, T* costs, void* workspace, - ctcOptions options) { - return CTC_STATUS_EXECUTION_FAILED; - } -}; - -template -class ComputeCtcLossFunctor { - public: - ctcStatus_t operator()(const float* const activations, float* gradients, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, int alphabet_size, - int minibatch, float* costs, void* workspace, - ctcOptions options) { - return platform::dynload::compute_ctc_loss( - activations, gradients, flat_labels, label_lengths, input_lengths, - static_cast(alphabet_size), static_cast(minibatch), costs, - workspace, options); - } -}; - -template -class ComputeCtcLossFunctor { - public: - ctcStatus_t operator()(const double* const activations, double* gradients, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, int alphabet_size, - int minibatch, double* costs, void* workspace, - ctcOptions options) { - return platform::dynload::compute_ctc_loss_double( - activations, gradients, flat_labels, label_lengths, input_lengths, - static_cast(alphabet_size), static_cast(minibatch), costs, - workspace, options); - } -}; - -template -class WarpCTCFunctor { - public: - /* - * \brief Compute the connectionist temporal classification loss, - * and optionally compute the gradient with respect to the inputs. - * - * If gradient is nullptr, it only computes the ctc loss, - * or computes both ctc loss and gradient. - * - * \param ctx execution context of this functor - * \param input batch matrix of input probabilities, in - * max_sequence_length x num_sequences x - * sequence_width, (row-major) format - * \param gradient batch matrix of gradient, with the same shape as - * input. - * \param cpu_labels labels always in CPU memory. - * \param cpu_label_lengths length of all labels in CPU memory. - * \param cpu_input_lengths length of all sequences in CPU memory. - * \param sequence_width number of possible output symbols. - * \param num_sequences number of sequence. - * \param blank blank label used in ctc loss function. - * \param cpu_losss cost of each sequence in CPU memory. - */ - void operator()(const framework::ExecutionContext& ctx, const T* input, - T* gradient, const int* cpu_labels, - const int* cpu_label_lengths, const int* cpu_input_lengths, - const size_t sequence_width, const size_t num_sequences, - const size_t blank, T* cpu_loss) { - // Init warp-ctc options - init(ctx, blank); - - // Compute the required workspace size. - // There is no memory allocated operations within warp-ctc. - size_t workspace_bytes = 0; - ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR; - if (sizeof(T) == 4) { - status = platform::dynload::get_workspace_size( - cpu_label_lengths, cpu_input_lengths, - static_cast(sequence_width), static_cast(num_sequences), - options_, &workspace_bytes); - } else { - status = platform::dynload::get_workspace_size_double( - cpu_label_lengths, cpu_input_lengths, - static_cast(sequence_width), static_cast(num_sequences), - options_, &workspace_bytes); - } - PADDLE_ENFORCE_EQ( - CTC_STATUS_SUCCESS, status, - platform::errors::PreconditionNotMet( - "warp-ctc [version %d] Error in get_workspace_size: %s", - warpctc_version_, platform::dynload::ctcGetStatusString(status))); - PADDLE_ENFORCE_GT( - workspace_bytes, 0UL, - platform::errors::InvalidArgument( - "Bytes of workspace got by warp-ctc function, " - "get_workspace_size() should be larger than 0, but received %d", - workspace_bytes)); - - auto& dev_ctx = ctx.template device_context(); - size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL; - Tensor workspace = ctx.AllocateTmpTensor( - phi::make_ddim({static_cast(workspace_elements)}), dev_ctx); - T* workspace_data = workspace.data(); - phi::funcs::SetConstant()( - ctx.template device_context(), &workspace, - static_cast(0)); - - // compute loss and gradient - status = ComputeCtcLossFunctor()( - input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths, - static_cast(sequence_width), static_cast(num_sequences), - cpu_loss, workspace_data, options_); - - PADDLE_ENFORCE_EQ( - CTC_STATUS_SUCCESS, status, - platform::errors::PreconditionNotMet( - "warp-ctc [version %d] Error in get_workspace_size: %s", - warpctc_version_, platform::dynload::ctcGetStatusString(status))); - } - - protected: - void init(const framework::ExecutionContext& ctx, const size_t blank) { - warpctc_version_ = platform::dynload::get_warpctc_version(); - - if (platform::is_gpu_place(ctx.GetPlace())) { -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - options_.loc = CTC_GPU; - options_.stream = reinterpret_cast( - ctx.device_context()) - .stream(); -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "[warpctc init] GPU is not enabled.")); -#endif - } else { - options_.loc = CTC_CPU; - options_.num_threads = 1; - } - - options_.blank_label = blank; - } - - private: - int warpctc_version_; - ctcOptions options_; -}; - -template -class WarpCTCKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* logits = ctx.Input("Logits"); - auto* label = ctx.Input("Label"); - auto* warpctc_grad = ctx.Output("WarpCTCGrad"); - auto* loss = ctx.Output("Loss"); - - size_t num_sequences, sequence_width, max_sequence_length; - framework::Vector logits_lod; - framework::Vector label_lod; - - if (ctx.HasInput("LogitsLength") && ctx.HasInput("LabelLength")) { - num_sequences = logits->dims()[1]; - sequence_width = logits->dims()[2]; - max_sequence_length = logits->dims()[0]; - - PADDLE_ENFORCE_GT(max_sequence_length, 0, - platform::errors::InvalidArgument( - "The first dimension of Input(Logits) should be " - "greater than zero " - "but received %d. ", - max_sequence_length)); - - PADDLE_ENFORCE_GT(num_sequences, 0, - platform::errors::InvalidArgument( - "The second dimension of Input(Logits) should be " - "greater than zero " - "but received %d. ", - num_sequences)); - - PADDLE_ENFORCE_GT(sequence_width, 0, - platform::errors::InvalidArgument( - "The third dimension of Input(Logits) should be " - "greater than zero " - "but received %d. ", - sequence_width)); - - auto* logits_length = ctx.Input("LogitsLength"); - auto* labels_length = ctx.Input("LabelLength"); - framework::Tensor logits_length_cpu; - framework::Tensor labels_length_cpu; - framework::TensorCopy(*logits_length, platform::CPUPlace(), - &logits_length_cpu); - framework::TensorCopy(*labels_length, platform::CPUPlace(), - &labels_length_cpu); - - logits_lod.push_back(0); - label_lod.push_back(0); - for (size_t i = 0; i < num_sequences; i++) { - logits_lod.push_back(logits_lod[i] + - logits_length_cpu.data()[i]); - label_lod.push_back(label_lod[i] + - labels_length_cpu.data()[i]); - } - } else { - PADDLE_ENFORCE_GT(logits->NumLevels(), 0UL, - platform::errors::InvalidArgument( - "Input(Logits) Tensor of WarpCTC " - "does not contain LoD information.")); - PADDLE_ENFORCE_GT(label->NumLevels(), 0UL, - platform::errors::InvalidArgument( - "Input(Label) Tensor of WarpCTC " - "does not contain LoD information.")); - - logits_lod = framework::ToAbsOffset(logits->lod())[0]; - auto logits_dims = logits->dims(); - - PADDLE_ENFORCE_GT(logits_dims[0], 0, - platform::errors::InvalidArgument( - "The first dimension of Input(Logits) should be " - "greater than zero " - "but received %d. ", - logits_dims[0])); - - PADDLE_ENFORCE_EQ( - logits_dims[0], static_cast(logits_lod.back()), - platform::errors::InvalidArgument( - "The first dimension of Input(Logits) should be equal to " - "the sum of all sequences' lengths = %d., but received %d. ", - static_cast(logits_lod.back()), logits_dims[0])); - - label_lod = framework::ToAbsOffset(label->lod())[0]; - auto label_dims = label->dims(); - PADDLE_ENFORCE_EQ(label_dims[1], 1, - platform::errors::InvalidArgument( - "The last dimension of Input(Label) should be 1, " - "but received %d", - label_dims[1])); - - num_sequences = logits_lod.size() - 1; - PADDLE_ENFORCE_EQ( - num_sequences, label_lod.size() - 1, - platform::errors::InvalidArgument( - "The number of sequences of Input(Logits) should be " - "equal to that of Input(Label) = %d, but received %d", - label_lod.size() - 1, num_sequences)); - - sequence_width = logits->numel() / logits_dims[0]; - max_sequence_length = math::MaximumSequenceLength(logits_lod); - } - - auto loss_dims = phi::make_ddim({static_cast(num_sequences), 1}); - - // warpctc needs sequences data stored in transposed padding format - LoDTensor warpctc_logits; - auto warpctc_logits_dims = - phi::make_ddim({static_cast(max_sequence_length), - static_cast(num_sequences), - static_cast(sequence_width)}); - auto& dev_ctx = ctx.template device_context(); - Tensor warpctc_logits_tmp = - ctx.AllocateTmpTensor(warpctc_logits_dims, dev_ctx); - warpctc_logits.ShareDataWith(warpctc_logits_tmp); - if (ctx.HasInput("LogitsLength")) { - paddle::framework::TensorCopySync(*logits, ctx.GetPlace(), - &warpctc_logits); - } else { - LoDTensor cpu_pad_value; - T* pad_value_data = - cpu_pad_value.mutable_data({1}, platform::CPUPlace()); - *pad_value_data = static_cast(0); - LoDTensor pad_value; - if (platform::is_cpu_place(ctx.GetPlace())) { - pad_value = cpu_pad_value; - } else { - paddle::framework::TensorCopySync(cpu_pad_value, ctx.GetPlace(), - &pad_value); - } - - math::PaddingLoDTensorFunctor()( - ctx.template device_context(), *logits, - &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */, - math::kLengthBatchWidth); - } - const T* warpctc_logits_data = warpctc_logits.data(); - - std::vector warpctc_label_lengths(num_sequences); - std::vector warpctc_logits_lengths(num_sequences); - - for (size_t i = 0; i < num_sequences; ++i) { - warpctc_label_lengths[i] = label_lod[i + 1] - label_lod[i]; - warpctc_logits_lengths[i] = logits_lod[i + 1] - logits_lod[i]; - } - - // warpctc computes loss and gradient in one call, gradient data also stored - // in batch format - T* warpctc_grad_data = - warpctc_grad->mutable_data(warpctc_logits.dims(), ctx.GetPlace()); - - phi::funcs::SetConstant()( - ctx.template device_context(), warpctc_grad, - static_cast(0)); - - // warpctc accesses labels in CPU memory - LoDTensor warpctc_label; - if (ctx.HasInput("LogitsLength")) { - warpctc_label.mutable_data( - {static_cast(math::TotalSequenceLength(label_lod)), 1}, - platform::CPUPlace()); - std::vector> lod; - lod.push_back(label_lod); - warpctc_label.set_lod(lod); - - if (platform::is_cpu_place(ctx.GetPlace())) { - math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), *label, - &warpctc_label, label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/, - false /*norm_by_times*/, math::kBatchLengthWidth); - } else { - LoDTensor gpu_label; - gpu_label.mutable_data( - {static_cast(math::TotalSequenceLength(label_lod)), 1}, - ctx.GetPlace()); - gpu_label.set_lod(lod); - math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), *label, &gpu_label, - label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/, - false /*norm_by_times*/, math::kBatchLengthWidth); - paddle::framework::TensorCopySync(gpu_label, platform::CPUPlace(), - &warpctc_label); - } - } else { - paddle::framework::TensorCopySync(*label, platform::CPUPlace(), - &warpctc_label); - } - - const int* warpctc_label_data = warpctc_label.data(); - // warpctc stores loss in CPU memory - Tensor warpctc_loss; - T* warpctc_loss_data = - warpctc_loss.mutable_data(loss_dims, platform::CPUPlace()); - - const size_t blank = static_cast(ctx.Attr("blank")); - - WarpCTCFunctor()( - ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data, - warpctc_label_lengths.data(), warpctc_logits_lengths.data(), - sequence_width, num_sequences, blank, warpctc_loss_data); - - // Copy the loss back - paddle::framework::TensorCopy(warpctc_loss, ctx.GetPlace(), - ctx.device_context(), loss); - } -}; - -template -class WarpCTCGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* warpctc_grad = ctx.Input("WarpCTCGrad"); - auto* logits_grad = ctx.Output(framework::GradVarName("Logits")); - const Tensor* loss_grad = ctx.Input(framework::GradVarName("Loss")); - - logits_grad->mutable_data(ctx.GetPlace()); - bool norm_by_times = ctx.Attr("norm_by_times"); - - if (ctx.HasInput("LogitsLength")) { - int max_seq_length = warpctc_grad->dims()[0]; // Tmax - int num_sequences = warpctc_grad->dims()[1]; // B - int seq_width = warpctc_grad->dims()[2]; // D - - auto* logits_length = ctx.Input("LogitsLength"); - // B - auto logits_len_e = - framework::EigenTensor::From(*logits_length); - // (B, 1) - auto loss_grad_e = framework::EigenTensor::From(*loss_grad); - // (T, B, D) - auto warpctc_grad_e = framework::EigenTensor::From(*warpctc_grad); - - auto logits_grad_e = framework::EigenTensor::From(*logits_grad); - - Eigen::DSizes grad_shape(1, num_sequences, 1); - Eigen::DSizes bcast(max_seq_length, 1, seq_width); - auto logits_g = warpctc_grad_e * - loss_grad_e.reshape(grad_shape).broadcast(bcast).eval(); - - auto* place = ctx.template device_context().eigen_device(); - if (norm_by_times) { - auto scales = logits_len_e.cast() - .inverse() - .reshape(grad_shape) - .broadcast(bcast) - .eval(); - logits_grad_e.device(*place) = logits_g * scales; - } else { - logits_grad_e.device(*place) = logits_g; - } - } else { - math::UnpaddingLoDTensorFunctor()( - ctx.template device_context(), *warpctc_grad, - logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth); - - const T* loss_grad_data = loss_grad->data(); - math::ScaleLoDTensorFunctor()( - ctx.template device_context(), loss_grad_data, - logits_grad); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 483c7e2f70285a71480c3511545bf3dab50564f4..bbf278feb4de5731c9826ccb288cac93a9eb971b 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1082,6 +1082,43 @@ void PsroiPoolInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void WarpctcInferMeta(const MetaTensor& logits, + const MetaTensor& label, + const paddle::optional logits_length, + const paddle::optional labels_length, + int blank, + bool norm_by_times, + MetaTensor* warpctc_grad, + MetaTensor* loss) { + auto logits_dims = logits.dims(); + int sequence_width = 0; + + if (logits_length.is_initialized()) { + sequence_width = logits_dims[2]; + } else { + sequence_width = + static_cast(phi::product(logits_dims) / logits_dims[0]); + } + + PADDLE_ENFORCE_GE( + blank, + 0, + errors::InvalidArgument( + "The value of Attr(blank) should be in interval [0, %d), " + "but received %d", + blank)); + PADDLE_ENFORCE_LT( + blank, + sequence_width, + errors::InvalidArgument( + "The value of Attr(blank) should be in interval [0, %d), " + "but received %d", + blank)); + + loss->set_dims({-1, 1}); + loss->set_dtype(logits.dtype()); +} + void WhereInferMeta(const MetaTensor& condition, const MetaTensor& x, const MetaTensor& y, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e0eed1e197756b68ec8b6db697debe1cfab0adba..53eefe8587487819fe0738b3eeb7350811e82cbb 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -214,6 +214,15 @@ void PsroiPoolInferMeta(const MetaTensor& x, float spatial_scale, MetaTensor* out); +void WarpctcInferMeta(const MetaTensor& logits, + const MetaTensor& label, + const paddle::optional logits_length, + const paddle::optional labels_length, + int blank, + bool norm_by_times, + MetaTensor* warpctc_grad, + MetaTensor* loss); + void WhereInferMeta(const MetaTensor& condition, const MetaTensor& x, const MetaTensor& y, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index fc3529f1a44f157edbc12c1b27c1f0e6a4f632af..b652b1c5fae94a6215040c765407d0c9ee8ab9b7 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -32,7 +32,7 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel - triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel) + triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel warpctc_kernel warpctc_grad_kernel) kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper) kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel) kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) @@ -58,6 +58,8 @@ kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_ kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce) kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse) +kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale) +kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale) # 4. auto parse and build kernel targets by cmake register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} ) diff --git a/paddle/phi/kernels/cpu/warpctc_grad_kernel.cc b/paddle/phi/kernels/cpu/warpctc_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b293363354818aefa327efa1c2358cb106b788e --- /dev/null +++ b/paddle/phi/kernels/cpu/warpctc_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/warpctc_grad_kernel.h" +#include "paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + warpctc_grad, CPU, ALL_LAYOUT, phi::WarpctcGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/warpctc_kernel.cc b/paddle/phi/kernels/cpu/warpctc_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b87202c11e926283ea6da8c3751dae228ee5952 --- /dev/null +++ b/paddle/phi/kernels/cpu/warpctc_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/warpctc_kernel.h" +#include "paddle/phi/kernels/impl/warpctc_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + warpctc, CPU, ALL_LAYOUT, phi::WarpctcKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/warpctc_grad_kernel.cu b/paddle/phi/kernels/gpu/warpctc_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..612b03555c6f16956bc350d19b99eca5f79bde32 --- /dev/null +++ b/paddle/phi/kernels/gpu/warpctc_grad_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/warpctc_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + warpctc_grad, GPU, ALL_LAYOUT, phi::WarpctcGradKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/warpctc_kernel.cu b/paddle/phi/kernels/gpu/warpctc_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..3379322f3dfd844c470046dc7c04f3d0624fd01f --- /dev/null +++ b/paddle/phi/kernels/gpu/warpctc_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/warpctc_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/warpctc_kernel_impl.h" + +PD_REGISTER_KERNEL( + warpctc, GPU, ALL_LAYOUT, phi::WarpctcKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h b/paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..b788c966a1af12cec0a6f482a43b894f919792de --- /dev/null +++ b/paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h @@ -0,0 +1,87 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/operators/math/sequence_padding.h" +#include "paddle/fluid/operators/math/sequence_scale.h" +#include "paddle/phi/backends/dynload/warpctc.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void WarpctcGradKernel(const Context& dev_ctx, + const DenseTensor& warpctc_grad, + const DenseTensor& logits, + const DenseTensor& loss_grad, + const paddle::optional logits_length, + int blank, + bool norm_by_times, + DenseTensor* logits_grad) { + dev_ctx.template Alloc(logits_grad); + + if (logits_length.is_initialized()) { + int max_seq_length = warpctc_grad.dims()[0]; // Tmax + int num_sequences = warpctc_grad.dims()[1]; // B + int seq_width = warpctc_grad.dims()[2]; // D + + // B + auto logits_len_e = EigenTensor::From(*logits_length); + // (B, 1) + auto loss_grad_e = EigenTensor::From(loss_grad); + // (T, B, D) + auto warpctc_grad_e = EigenTensor::From(warpctc_grad); + + auto logits_grad_e = EigenTensor::From(*logits_grad); + + Eigen::DSizes grad_shape(1, num_sequences, 1); + Eigen::DSizes bcast(max_seq_length, 1, seq_width); + auto logits_g = warpctc_grad_e * + loss_grad_e.reshape(grad_shape).broadcast(bcast).eval(); + + auto* place = dev_ctx.eigen_device(); + if (norm_by_times) { + auto scales = logits_len_e.cast() + .inverse() + .reshape(grad_shape) + .broadcast(bcast) + .eval(); + logits_grad_e.device(*place) = logits_g * scales; + } else { + logits_grad_e.device(*place) = logits_g; + } + } else { + paddle::operators::math::UnpaddingLoDTensorFunctor()( + dev_ctx, + warpctc_grad, + logits_grad, + -1, + 0, + norm_by_times, + paddle::operators::math::kLengthBatchWidth); + + const T* loss_grad_data = loss_grad.data(); + paddle::operators::math::ScaleLoDTensorFunctor()( + dev_ctx, loss_grad_data, logits_grad); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/warpctc_kernel_impl.h b/paddle/phi/kernels/impl/warpctc_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..8a18f2500a512f19874bfb827aab63c3539b37b5 --- /dev/null +++ b/paddle/phi/kernels/impl/warpctc_kernel_impl.h @@ -0,0 +1,454 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/operators/math/sequence_padding.h" +#include "paddle/fluid/operators/math/sequence_scale.h" +#include "paddle/phi/backends/dynload/warpctc.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +class ComputeCtcLossFunctor { + public: + ctcStatus_t operator()(const T* const activations, + T* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + T* costs, + void* workspace, + ctcOptions options) { + return CTC_STATUS_EXECUTION_FAILED; + } +}; + +template +class ComputeCtcLossFunctor { + public: + ctcStatus_t operator()(const float* const activations, + float* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + float* costs, + void* workspace, + ctcOptions options) { + return phi::dynload::compute_ctc_loss(activations, + gradients, + flat_labels, + label_lengths, + input_lengths, + static_cast(alphabet_size), + static_cast(minibatch), + costs, + workspace, + options); + } +}; + +template +class ComputeCtcLossFunctor { + public: + ctcStatus_t operator()(const double* const activations, + double* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + double* costs, + void* workspace, + ctcOptions options) { + return phi::dynload::compute_ctc_loss_double( + activations, + gradients, + flat_labels, + label_lengths, + input_lengths, + static_cast(alphabet_size), + static_cast(minibatch), + costs, + workspace, + options); + } +}; + +template +class WarpCTCFunctor { + public: + /* + * \brief Compute the connectionist temporal classification loss, + * and optionally compute the gradient with respect to the inputs. + * + * If gradient is nullptr, it only computes the ctc loss, + * or computes both ctc loss and gradient. + * + * \param ctx execution context of this functor + * \param input batch matrix of input probabilities, in + * max_sequence_length x num_sequences x + * sequence_width, (row-major) format + * \param gradient batch matrix of gradient, with the same shape as + * input. + * \param cpu_labels labels always in CPU memory. + * \param cpu_label_lengths length of all labels in CPU memory. + * \param cpu_input_lengths length of all sequences in CPU memory. + * \param sequence_width number of possible output symbols. + * \param num_sequences number of sequence. + * \param blank blank label used in ctc loss function. + * \param cpu_losss cost of each sequence in CPU memory. + */ + void operator()(const Context& dev_ctx, + const T* input, + T* gradient, + const int* cpu_labels, + const int* cpu_label_lengths, + const int* cpu_input_lengths, + const size_t sequence_width, + const size_t num_sequences, + const size_t blank, + T* cpu_loss) { + // Init warp-ctc options + init(dev_ctx, blank); + + // Compute the required workspace size. + // There is no memory allocated operations within warp-ctc. + size_t workspace_bytes = 0; + ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR; + if (sizeof(T) == 4) { + status = + phi::dynload::get_workspace_size(cpu_label_lengths, + cpu_input_lengths, + static_cast(sequence_width), + static_cast(num_sequences), + options_, + &workspace_bytes); + } else { + status = phi::dynload::get_workspace_size_double( + cpu_label_lengths, + cpu_input_lengths, + static_cast(sequence_width), + static_cast(num_sequences), + options_, + &workspace_bytes); + } + PADDLE_ENFORCE_EQ( + CTC_STATUS_SUCCESS, + status, + errors::PreconditionNotMet( + "warp-ctc [version %d] Error in get_workspace_size: %s", + warpctc_version_, + phi::dynload::ctcGetStatusString(status))); + PADDLE_ENFORCE_GT( + workspace_bytes, + 0UL, + errors::InvalidArgument( + "Bytes of workspace got by warp-ctc function, " + "get_workspace_size() should be larger than 0, but received %d", + workspace_bytes)); + + size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL; + DenseTensor workspace = phi::Empty( + dev_ctx, {static_cast(workspace_elements)}); + T* workspace_data = workspace.data(); + phi::funcs::SetConstant()( + dev_ctx, &workspace, static_cast(0)); + + // compute loss and gradient + status = + ComputeCtcLossFunctor()(input, + gradient, + cpu_labels, + cpu_label_lengths, + cpu_input_lengths, + static_cast(sequence_width), + static_cast(num_sequences), + cpu_loss, + workspace_data, + options_); + + PADDLE_ENFORCE_EQ( + CTC_STATUS_SUCCESS, + status, + errors::PreconditionNotMet( + "warp-ctc [version %d] Error in get_workspace_size: %s", + warpctc_version_, + phi::dynload::ctcGetStatusString(status))); + } + + protected: + void init(const Context& dev_ctx, const size_t blank) { + warpctc_version_ = phi::dynload::get_warpctc_version(); + + if (dev_ctx.GetPlace() == phi::GPUPlace()) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + options_.loc = CTC_GPU; + options_.stream = + reinterpret_cast(dev_ctx).stream(); +#else + PADDLE_THROW( + errors::PreconditionNotMet("[warpctc init] GPU is not enabled.")); +#endif + } else { + options_.loc = CTC_CPU; + options_.num_threads = 1; + } + + options_.blank_label = blank; + } + + private: + int warpctc_version_; + ctcOptions options_; +}; + +template +void WarpctcKernel(const Context& dev_ctx, + const DenseTensor& logits, + const DenseTensor& label, + const paddle::optional logits_length, + const paddle::optional labels_length, + int blank, + bool norm_by_times, + DenseTensor* warpctc_grad, + DenseTensor* loss) { + size_t num_sequences, sequence_width, max_sequence_length; + paddle::framework::Vector logits_lod; + paddle::framework::Vector label_lod; + if (logits_length.is_initialized() && labels_length.is_initialized()) { + num_sequences = logits.dims()[1]; + sequence_width = logits.dims()[2]; + max_sequence_length = logits.dims()[0]; + + PADDLE_ENFORCE_GT(max_sequence_length, + 0, + phi::errors::InvalidArgument( + "The first dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + max_sequence_length)); + + PADDLE_ENFORCE_GT(num_sequences, + 0, + phi::errors::InvalidArgument( + "The second dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + num_sequences)); + + PADDLE_ENFORCE_GT(sequence_width, + 0, + phi::errors::InvalidArgument( + "The third dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + sequence_width)); + + DenseTensor logits_length_cpu; + DenseTensor labels_length_cpu; + phi::Copy( + dev_ctx, *logits_length, phi::CPUPlace(), false, &logits_length_cpu); + phi::Copy( + dev_ctx, *labels_length, phi::CPUPlace(), false, &labels_length_cpu); + + logits_lod.push_back(0); + label_lod.push_back(0); + for (size_t i = 0; i < num_sequences; i++) { + logits_lod.push_back(logits_lod[i] + + logits_length_cpu.data()[i]); + label_lod.push_back(label_lod[i] + labels_length_cpu.data()[i]); + } + } else { + PADDLE_ENFORCE_GT( + logits.NumLevels(), + 0UL, + phi::errors::InvalidArgument("Input(Logits) Tensor of WarpCTC " + "does not contain LoD information.")); + PADDLE_ENFORCE_GT( + label.NumLevels(), + 0UL, + phi::errors::InvalidArgument("Input(Label) Tensor of WarpCTC " + "does not contain LoD information.")); + + logits_lod = paddle::framework::ToAbsOffset(logits.lod())[0]; + auto logits_dims = logits.dims(); + + PADDLE_ENFORCE_GT(logits_dims[0], + 0, + phi::errors::InvalidArgument( + "The first dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + logits_dims[0])); + + PADDLE_ENFORCE_EQ( + logits_dims[0], + static_cast(logits_lod.back()), + phi::errors::InvalidArgument( + "The first dimension of Input(Logits) should be equal to " + "the sum of all sequences' lengths = %d., but received %d. ", + static_cast(logits_lod.back()), + logits_dims[0])); + + label_lod = paddle::framework::ToAbsOffset(label.lod())[0]; + auto label_dims = label.dims(); + PADDLE_ENFORCE_EQ(label_dims[1], + 1, + phi::errors::InvalidArgument( + "The last dimension of Input(Label) should be 1, " + "but received %d", + label_dims[1])); + + num_sequences = logits_lod.size() - 1; + PADDLE_ENFORCE_EQ(num_sequences, + label_lod.size() - 1, + phi::errors::InvalidArgument( + "The number of sequences of Input(Logits) should be " + "equal to that of Input(Label) = %d, but received %d", + label_lod.size() - 1, + num_sequences)); + + sequence_width = logits.numel() / logits_dims[0]; + max_sequence_length = + paddle::operators::math::MaximumSequenceLength(logits_lod); + } + + auto loss_dims = phi::make_ddim({static_cast(num_sequences), 1}); + + // warpctc needs sequences data stored in transposed padding format + DenseTensor warpctc_logits_tmp = + phi::Empty(dev_ctx, + {static_cast(max_sequence_length), + static_cast(num_sequences), + static_cast(sequence_width)}); + DenseTensor warpctc_logits(warpctc_logits_tmp); + + if (logits_length.is_initialized()) { + phi::Copy(dev_ctx, logits, dev_ctx.GetPlace(), true, &warpctc_logits); + } else { + DenseTensor cpu_pad_value; + cpu_pad_value.Resize({1}); + T* pad_value_data = dev_ctx.template HostAlloc(&cpu_pad_value); + *pad_value_data = static_cast(0); + DenseTensor pad_value; + if (dev_ctx.GetPlace() == phi::CPUPlace()) { + pad_value = cpu_pad_value; + } else { + phi::Copy(dev_ctx, cpu_pad_value, dev_ctx.GetPlace(), true, &pad_value); + } + + paddle::operators::math::PaddingLoDTensorFunctor()( + dev_ctx, + logits, + &warpctc_logits, + pad_value, + -1, + 0, + false /* norm_by_times */, + paddle::operators::math::kLengthBatchWidth); + } + + const T* warpctc_logits_data = warpctc_logits.data(); + + std::vector warpctc_label_lengths(num_sequences); + std::vector warpctc_logits_lengths(num_sequences); + + for (size_t i = 0; i < num_sequences; ++i) { + warpctc_label_lengths[i] = label_lod[i + 1] - label_lod[i]; + warpctc_logits_lengths[i] = logits_lod[i + 1] - logits_lod[i]; + } + + // warpctc computes loss and gradient in one call, gradient data also stored + // in batch format + warpctc_grad->Resize(warpctc_logits.dims()); + T* warpctc_grad_data = dev_ctx.template Alloc(warpctc_grad); + + phi::funcs::SetConstant()( + dev_ctx, warpctc_grad, static_cast(0)); + + // warpctc accesses labels in CPU memory + DenseTensor warpctc_label; + if (logits_length.is_initialized()) { + warpctc_label.Resize( + {static_cast( + paddle::operators::math::TotalSequenceLength(label_lod)), + 1}); + dev_ctx.template HostAlloc(&warpctc_label); + std::vector> lod; + lod.push_back(label_lod); + warpctc_label.set_lod(lod); + + if (dev_ctx.GetPlace() == phi::CPUPlace()) { + paddle::operators::math::UnpaddingLoDTensorFunctor()( + dev_ctx, + label, + &warpctc_label, + label.dims()[1] /*pad_seq_len*/, + 0 /*lod_level*/, + false /*norm_by_times*/, + paddle::operators::math::kBatchLengthWidth); + } else { + DenseTensor gpu_label; + gpu_label.Resize( + {static_cast( + paddle::operators::math::TotalSequenceLength(label_lod)), + 1}); + dev_ctx.template Alloc(&gpu_label); + gpu_label.set_lod(lod); + paddle::operators::math::UnpaddingLoDTensorFunctor()( + dev_ctx, + label, + &gpu_label, + label.dims()[1] /*pad_seq_len*/, + 0 /*lod_level*/, + false /*norm_by_times*/, + paddle::operators::math::kBatchLengthWidth); + phi::Copy(dev_ctx, gpu_label, phi::CPUPlace(), true, &warpctc_label); + } + } else { + phi::Copy(dev_ctx, label, phi::CPUPlace(), true, &warpctc_label); + } + + const int* warpctc_label_data = warpctc_label.data(); + // warpctc stores loss in CPU memory + DenseTensor warpctc_loss; + warpctc_loss.Resize(loss_dims); + T* warpctc_loss_data = dev_ctx.template HostAlloc(&warpctc_loss); + WarpCTCFunctor()(dev_ctx, + warpctc_logits_data, + warpctc_grad_data, + warpctc_label_data, + warpctc_label_lengths.data(), + warpctc_logits_lengths.data(), + sequence_width, + num_sequences, + blank, + warpctc_loss_data); + // Copy the loss back + phi::Copy(dev_ctx, warpctc_loss, dev_ctx.GetPlace(), false, loss); +} + +} // namespace phi diff --git a/paddle/phi/kernels/warpctc_grad_kernel.h b/paddle/phi/kernels/warpctc_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8e1ab43324a50a575b2b292049d3f1ce029a5de9 --- /dev/null +++ b/paddle/phi/kernels/warpctc_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void WarpctcGradKernel(const Context& dev_ctx, + const DenseTensor& warpctc_grad, + const DenseTensor& logits, + const DenseTensor& loss_grad, + paddle::optional logits_length, + int blank, + bool norm_by_times, + DenseTensor* logits_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/warpctc_kernel.h b/paddle/phi/kernels/warpctc_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..4baa49064775ec15d5e104a830b3e7a82ff10819 --- /dev/null +++ b/paddle/phi/kernels/warpctc_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void WarpctcKernel(const Context& dev_ctx, + const DenseTensor& logits, + const DenseTensor& label, + paddle::optional logits_length, + paddle::optional labels_length, + int blank, + bool norm_by_times, + DenseTensor* warpctc_grad, + DenseTensor* loss); + +} // namespace phi diff --git a/paddle/phi/ops/compat/warpctc_sig.cc b/paddle/phi/ops/compat/warpctc_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..75f440de7f2dbbc8bb07f940dd48d2d902950297 --- /dev/null +++ b/paddle/phi/ops/compat/warpctc_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature WarpctcOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("warpctc", + {"Logits", "Label", "LogitsLength", "LabelLength"}, + {"blank", "norm_by_times"}, + {"WarpCTCGrad", "Loss"}); +} + +KernelSignature WarpctcGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "warpctc_grad", + {"WarpCTCGrad", "Logits", GradVarName("Loss"), "LogitsLength"}, + {"blank", "norm_by_times"}, + {GradVarName("Logits")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(warpctc, phi::WarpctcOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(warpctc_grad, phi::WarpctcGradOpArgumentMapping);