Unverified commit cb183762, authored by 0x45f and committed by GitHub

[Phi] Move warpctc OP to phi (#40023)

* moving OP

* move forward

* move grad and infershape

* code format

* format code

* fix code

* fix code

* fix CMakeLists.txt

* fix comments

* Refine CMakeLists for rocm ci
Parent e73857a3
@@ -111,10 +111,10 @@ op_library(load_combine_op DEPS string_array)
if (WITH_GPU OR WITH_ROCM)
if(WITH_ROCM)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc)
# warpctc_op needs cudnn 7 or above
elseif(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc)
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace phi {
class DenseTensor;
@@ -136,6 +137,52 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
}
};
template <typename T>
class PaddingLoDTensorFunctor<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& context,
const framework::LoDTensor& seq_tensor,
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
const auto& seq_tensor_dims = seq_tensor.dims();
const auto& pad_tensor_dims = pad_tensor->dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
int step_width = seq_tensor.numel() / seq_tensor_dims[0];
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
PADDLE_ENFORCE_EQ(
pad_value.numel() == 1 || pad_value.numel() == step_width, true,
platform::errors::InvalidArgument(
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width', but got %ld != 1 and %ld. Please check the input "
"value.",
pad_value.numel(), step_width));
// fill padding value
T* pad_data = pad_tensor->data<T>();
const T* pad_value_data = pad_value.data<T>();
if (pad_value.numel() == 1) {
fast_mem_init<T>(pad_data, pad_tensor->numel(), pad_value_data,
sizeof(T));
} else {
for (int i = 0; i < pad_tensor->numel(); i += step_width) {
memcpy(pad_data + i, pad_value_data, step_width * sizeof(T));
}
}
CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kSeqToPad, layout);
}
};
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
@@ -160,6 +207,30 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
}
};
template <typename T>
class UnpaddingLoDTensorFunctor<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& context,
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
if (pad_seq_len == -1) {
pad_seq_len = MaximumSequenceLength(seq_offsets);
}
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
CopyValidData<T>(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len,
step_width, norm_by_times, kPadToSeq, layout);
}
};
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, int>;
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, int64_t>;
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
@@ -170,6 +241,16 @@ template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, double>;
template class PaddingLoDTensorFunctor<phi::CPUContext, int>;
template class PaddingLoDTensorFunctor<phi::CPUContext, int64_t>;
template class PaddingLoDTensorFunctor<phi::CPUContext, float>;
template class PaddingLoDTensorFunctor<phi::CPUContext, double>;
template class UnpaddingLoDTensorFunctor<phi::CPUContext, int>;
template class UnpaddingLoDTensorFunctor<phi::CPUContext, int64_t>;
template class UnpaddingLoDTensorFunctor<phi::CPUContext, float>;
template class UnpaddingLoDTensorFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
@@ -112,6 +113,69 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
class PaddingLoDTensorFunctor<phi::GPUContext, T> {
public:
void operator()(const phi::GPUContext& context,
const framework::LoDTensor& seq_tensor,
framework::LoDTensor* pad_tensor,
const framework::LoDTensor& pad_value, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_lod = seq_tensor.lod();
auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
const auto& seq_tensor_dims = seq_tensor.dims();
const auto& pad_tensor_dims = pad_tensor->dims();
int max_seq_len = MaximumSequenceLength(seq_offsets);
if (pad_seq_len == -1) {
pad_seq_len = max_seq_len;
}
PADDLE_ENFORCE_GE(
pad_seq_len, max_seq_len,
platform::errors::InvalidArgument(
"The pad_seq_len must be equal to or greater than the "
"original max sequence length. Expected %ld >= %ld, but got %ld < "
"%ld. Please check the input value.",
pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));
int step_width = seq_tensor.numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1;
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
PADDLE_ENFORCE_EQ(
pad_value.numel() == 1 || pad_value.numel() == step_width, true,
platform::errors::InvalidArgument(
"The numel of 'pad_value' can only be 1 or be equal to "
"the 'step_width', but got %ld != 1 and %ld. Please check the "
"input value.",
pad_value.numel(), step_width));
const int kBlockSize = 512;
/* Use at least 32 threads to copy sequence_width elements,
 * and assign at least 8 elements to each thread.
*/
size_t block_dim_x =
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = seq_num;
dim3 grid(grid_dim_x, grid_dim_y);
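// Worked example of the launch-configuration arithmetic above (illustrative,
// not part of the original change): for step_width = 100, ceil(100 / 8) = 13
// threads would already give each thread at least 8 elements; rounding 13 up
// to the next multiple of 32 yields block_dim_x = 32, so block_dim_y =
// 512 / 32 = 16 time steps are covered per block and grid_dim_x =
// ceil(pad_seq_len / 16) blocks span one padded sequence.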
const T* seq_data = seq_tensor.data<T>();
T* pad_data = pad_tensor->data<T>();
const T* pad_value_data = pad_value.data<T>();
paddle::framework::MixVector<size_t> mix_vector_seq_offsets(&seq_offsets);
SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
mix_vector_seq_offsets.CUDAData(context.GetPlace()), seq_num,
pad_seq_len, step_width, norm_by_times, layout);
}
};
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
@@ -166,6 +230,60 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
class UnpaddingLoDTensorFunctor<phi::GPUContext, T> {
public:
void operator()(const phi::GPUContext& context,
const framework::LoDTensor& pad_tensor,
framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
int lod_level = 0, bool norm_by_times = false,
const PadLayout layout = kBatchLengthWidth) {
auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
const auto& seq_tensor_dims = seq_tensor->dims();
const auto& pad_tensor_dims = pad_tensor.dims();
int max_seq_len = MaximumSequenceLength(seq_offsets);
if (pad_seq_len == -1) {
pad_seq_len = max_seq_len;
}
int step_width = seq_tensor->numel() / seq_tensor_dims[0];
int seq_num = seq_offsets.size() - 1;
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
/*
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
paddle::framework::TensorCopy(pad_tensor, context.GetPlace(), context,
seq_tensor);
seq_tensor->Resize(seq_tensor_dims);
return;
}
*/
const int kBlockSize = 512;
/* Use at least 32 threads to copy sequence_width elements,
 * and assign at least 8 elements to each thread.
*/
size_t block_dim_x =
std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
size_t block_dim_y = kBlockSize / block_dim_x;
dim3 threads(block_dim_x, block_dim_y);
size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
size_t grid_dim_y = seq_num;
dim3 grid(grid_dim_x, grid_dim_y);
const T* pad_data = pad_tensor.data<T>();
T* seq_data = seq_tensor->data<T>();
paddle::framework::MixVector<size_t> mixv_seq_offsets(&seq_offsets);
SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
seq_data, pad_data, nullptr, false,
mixv_seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
step_width, norm_by_times, layout);
}
};
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
@@ -176,6 +294,16 @@ template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
template class PaddingLoDTensorFunctor<phi::GPUContext, int>;
template class PaddingLoDTensorFunctor<phi::GPUContext, int64_t>;
template class PaddingLoDTensorFunctor<phi::GPUContext, float>;
template class PaddingLoDTensorFunctor<phi::GPUContext, double>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, int>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, int64_t>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, float>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_scale.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
namespace phi {
class DenseTensor;
@@ -43,9 +44,33 @@ class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> {
}
};
template <typename T>
class ScaleLoDTensorFunctor<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& context, const T* scales,
framework::LoDTensor* seq) {
const size_t level = 0;
auto lod = seq->lod();
const size_t num_seq = lod[level].size() - 1;
size_t seq_width = seq->dims()[1];
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
T* seq_data = seq->mutable_data<T>(context.GetPlace());
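// Illustrative example (not from the original source): with lod[0] = {0, 2, 5}
// and seq_width = 3, num_seq is 2; elements 0..5 of seq_data are multiplied by
// scales[0] and elements 6..14 by scales[1].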
for (size_t i = 0; i < num_seq; ++i) {
for (size_t j = lod[level][i] * seq_width;
j < lod[level][i + 1] * seq_width; ++j) {
seq_data[j] *= scales[i];
}
}
}
};
template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, double>;
template class ScaleLoDTensorFunctor<phi::CPUContext, float>;
template class ScaleLoDTensorFunctor<phi::CPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sequence_scale.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
@@ -61,9 +62,41 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
class ScaleLoDTensorFunctor<phi::GPUContext, T> {
public:
void operator()(const phi::GPUContext& context, const T* scales,
framework::LoDTensor* seq) {
const size_t level = 0;
auto lod = seq->lod();
const size_t num_seq = lod[level].size() - 1;
const size_t seq_width = seq->numel() / seq->dims()[0];
auto abs_offset_lod = framework::ToAbsOffset(lod);
T* seq_data = seq->mutable_data<T>(context.GetPlace());
paddle::framework::MixVector<size_t> mix_vector(&(abs_offset_lod[level]));
#ifdef PADDLE_WITH_HIP
hipLaunchKernelGGL(
HIP_KERNEL_NAME(SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS>),
dim3(num_seq), dim3(PADDLE_CUDA_NUM_THREADS), 0, context.stream(),
seq_data, mix_vector.CUDAMutableData(context.GetPlace()), scales,
seq_width);
#else
SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
seq_data, mix_vector.CUDAMutableData(context.GetPlace()), scales,
seq_width);
#endif
mix_vector.CopyToCPU();
}
};
template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, double>;
template class ScaleLoDTensorFunctor<phi::GPUContext, float>;
template class ScaleLoDTensorFunctor<phi::GPUContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
@@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/warpctc_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
@@ -25,40 +27,6 @@ class WarpCTCOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Logits"), "Input", "Logits", "WarpCTC");
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "WarpCTC");
OP_INOUT_CHECK(ctx->HasOutput("WarpCTCGrad"), "Output", "WarpCTCGrad",
"WarpCTC");
OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "WarpCTC");
auto logits_dims = ctx->GetInputDim("Logits");
int blank = ctx->Attrs().Get<int>("blank");
int sequence_width = 0;
if (ctx->HasInput("LogitsLength")) {
sequence_width = logits_dims[2];
} else {
sequence_width =
static_cast<int>(phi::product(logits_dims) / logits_dims[0]);
}
PADDLE_ENFORCE_GE(
blank, 0, platform::errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
blank));
PADDLE_ENFORCE_LT(
blank, sequence_width,
platform::errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
blank));
// TODO(liuyiqun): it is tricky to set the wrong dimension here.
ctx->SetOutputDim("Loss", {-1, 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -189,15 +157,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(WarpCTCGradOpNoNeedBufferVarInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(warpctc, WarpctcInferShapeFunctor,
PD_INFER_META(phi::WarpctcInferMeta));
REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
ops::WarpCTCGradOpMaker<paddle::framework::OpDesc>,
ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>);
ops::WarpCTCGradOpMaker<paddle::imperative::OpBase>,
WarpctcInferShapeFunctor);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp,
ops::WarpCTCGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>,
ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/warpctc_op.h"
namespace ops = paddle::operators;
// The forward and backward CUDA OP kernels must be registered in the same *.cu file.
// Eigen can be used on the GPU device, but only from a *.cu file, not a *.cu.cc file:
// *.cu.cc files are compiled with GCC, while *.cu files are compiled with NVCC.
REGISTER_OP_CUDA_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>,
ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, double>);
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/fluid/operators/math/sequence_scale.h"
#include "paddle/fluid/platform/dynload/warpctc.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class ComputeCtcLossFunctor {
public:
ctcStatus_t operator()(const T* const activations, T* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths, int alphabet_size,
int minibatch, T* costs, void* workspace,
ctcOptions options) {
return CTC_STATUS_EXECUTION_FAILED;
}
};
template <typename DeviceContext>
class ComputeCtcLossFunctor<DeviceContext, float> {
public:
ctcStatus_t operator()(const float* const activations, float* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths, int alphabet_size,
int minibatch, float* costs, void* workspace,
ctcOptions options) {
return platform::dynload::compute_ctc_loss(
activations, gradients, flat_labels, label_lengths, input_lengths,
static_cast<int>(alphabet_size), static_cast<int>(minibatch), costs,
workspace, options);
}
};
template <typename DeviceContext>
class ComputeCtcLossFunctor<DeviceContext, double> {
public:
ctcStatus_t operator()(const double* const activations, double* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths, int alphabet_size,
int minibatch, double* costs, void* workspace,
ctcOptions options) {
return platform::dynload::compute_ctc_loss_double(
activations, gradients, flat_labels, label_lengths, input_lengths,
static_cast<int>(alphabet_size), static_cast<int>(minibatch), costs,
workspace, options);
}
};
template <typename DeviceContext, typename T>
class WarpCTCFunctor {
public:
/*
* \brief Compute the connectionist temporal classification loss,
* and optionally compute the gradient with respect to the inputs.
*
 * If gradient is nullptr, only the ctc loss is computed;
 * otherwise both the ctc loss and the gradient are computed.
*
* \param ctx execution context of this functor
* \param input batch matrix of input probabilities, in
* max_sequence_length x num_sequences x
* sequence_width, (row-major) format
* \param gradient batch matrix of gradient, with the same shape as
* input.
* \param cpu_labels labels always in CPU memory.
* \param cpu_label_lengths length of all labels in CPU memory.
* \param cpu_input_lengths length of all sequences in CPU memory.
* \param sequence_width number of possible output symbols.
 * \param num_sequences number of sequences.
 * \param blank blank label used in the ctc loss function.
 * \param cpu_loss cost of each sequence in CPU memory.
*/
void operator()(const framework::ExecutionContext& ctx, const T* input,
T* gradient, const int* cpu_labels,
const int* cpu_label_lengths, const int* cpu_input_lengths,
const size_t sequence_width, const size_t num_sequences,
const size_t blank, T* cpu_loss) {
// Init warp-ctc options
init(ctx, blank);
// Compute the required workspace size.
// There are no memory allocation operations within warp-ctc itself.
size_t workspace_bytes = 0;
ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR;
if (sizeof(T) == 4) {
status = platform::dynload::get_workspace_size(
cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences),
options_, &workspace_bytes);
} else {
status = platform::dynload::get_workspace_size_double(
cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences),
options_, &workspace_bytes);
}
PADDLE_ENFORCE_EQ(
CTC_STATUS_SUCCESS, status,
platform::errors::PreconditionNotMet(
"warp-ctc [version %d] Error in get_workspace_size: %s",
warpctc_version_, platform::dynload::ctcGetStatusString(status)));
PADDLE_ENFORCE_GT(
workspace_bytes, 0UL,
platform::errors::InvalidArgument(
"Bytes of workspace got by warp-ctc function, "
"get_workspace_size() should be larger than 0, but received %d",
workspace_bytes));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL;
Tensor workspace = ctx.AllocateTmpTensor<T, DeviceContext>(
phi::make_ddim({static_cast<int64_t>(workspace_elements)}), dev_ctx);
T* workspace_data = workspace.data<T>();
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &workspace,
static_cast<T>(0));
// compute loss and gradient
status = ComputeCtcLossFunctor<DeviceContext, T>()(
input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences),
cpu_loss, workspace_data, options_);
PADDLE_ENFORCE_EQ(
CTC_STATUS_SUCCESS, status,
platform::errors::PreconditionNotMet(
"warp-ctc [version %d] Error in get_workspace_size: %s",
warpctc_version_, platform::dynload::ctcGetStatusString(status)));
}
protected:
void init(const framework::ExecutionContext& ctx, const size_t blank) {
warpctc_version_ = platform::dynload::get_warpctc_version();
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
options_.loc = CTC_GPU;
options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"[warpctc init] GPU is not enabled."));
#endif
} else {
options_.loc = CTC_CPU;
options_.num_threads = 1;
}
options_.blank_label = blank;
}
private:
int warpctc_version_;
ctcOptions options_;
};
template <typename DeviceContext, typename T>
class WarpCTCKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* logits = ctx.Input<LoDTensor>("Logits");
auto* label = ctx.Input<LoDTensor>("Label");
auto* warpctc_grad = ctx.Output<Tensor>("WarpCTCGrad");
auto* loss = ctx.Output<Tensor>("Loss");
size_t num_sequences, sequence_width, max_sequence_length;
framework::Vector<size_t> logits_lod;
framework::Vector<size_t> label_lod;
if (ctx.HasInput("LogitsLength") && ctx.HasInput("LabelLength")) {
num_sequences = logits->dims()[1];
sequence_width = logits->dims()[2];
max_sequence_length = logits->dims()[0];
PADDLE_ENFORCE_GT(max_sequence_length, 0,
platform::errors::InvalidArgument(
"The first dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
max_sequence_length));
PADDLE_ENFORCE_GT(num_sequences, 0,
platform::errors::InvalidArgument(
"The second dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
num_sequences));
PADDLE_ENFORCE_GT(sequence_width, 0,
platform::errors::InvalidArgument(
"The third dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
sequence_width));
auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
auto* labels_length = ctx.Input<framework::Tensor>("LabelLength");
framework::Tensor logits_length_cpu;
framework::Tensor labels_length_cpu;
framework::TensorCopy(*logits_length, platform::CPUPlace(),
&logits_length_cpu);
framework::TensorCopy(*labels_length, platform::CPUPlace(),
&labels_length_cpu);
logits_lod.push_back(0);
label_lod.push_back(0);
for (size_t i = 0; i < num_sequences; i++) {
logits_lod.push_back(logits_lod[i] +
logits_length_cpu.data<int64_t>()[i]);
label_lod.push_back(label_lod[i] +
labels_length_cpu.data<int64_t>()[i]);
}
} else {
PADDLE_ENFORCE_GT(logits->NumLevels(), 0UL,
platform::errors::InvalidArgument(
"Input(Logits) Tensor of WarpCTC "
"does not contain LoD information."));
PADDLE_ENFORCE_GT(label->NumLevels(), 0UL,
platform::errors::InvalidArgument(
"Input(Label) Tensor of WarpCTC "
"does not contain LoD information."));
logits_lod = framework::ToAbsOffset(logits->lod())[0];
auto logits_dims = logits->dims();
PADDLE_ENFORCE_GT(logits_dims[0], 0,
platform::errors::InvalidArgument(
"The first dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
logits_dims[0]));
PADDLE_ENFORCE_EQ(
logits_dims[0], static_cast<int64_t>(logits_lod.back()),
platform::errors::InvalidArgument(
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths = %d., but received %d. ",
static_cast<int64_t>(logits_lod.back()), logits_dims[0]));
label_lod = framework::ToAbsOffset(label->lod())[0];
auto label_dims = label->dims();
PADDLE_ENFORCE_EQ(label_dims[1], 1,
platform::errors::InvalidArgument(
"The last dimension of Input(Label) should be 1, "
"but received %d",
label_dims[1]));
num_sequences = logits_lod.size() - 1;
PADDLE_ENFORCE_EQ(
num_sequences, label_lod.size() - 1,
platform::errors::InvalidArgument(
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label) = %d, but received %d",
label_lod.size() - 1, num_sequences));
sequence_width = logits->numel() / logits_dims[0];
max_sequence_length = math::MaximumSequenceLength(logits_lod);
}
auto loss_dims = phi::make_ddim({static_cast<int64_t>(num_sequences), 1});
// warpctc needs sequences data stored in transposed padding format
LoDTensor warpctc_logits;
auto warpctc_logits_dims =
phi::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
auto& dev_ctx = ctx.template device_context<DeviceContext>();
Tensor warpctc_logits_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(warpctc_logits_dims, dev_ctx);
warpctc_logits.ShareDataWith(warpctc_logits_tmp);
if (ctx.HasInput("LogitsLength")) {
paddle::framework::TensorCopySync(*logits, ctx.GetPlace(),
&warpctc_logits);
} else {
LoDTensor cpu_pad_value;
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
LoDTensor pad_value;
if (platform::is_cpu_place(ctx.GetPlace())) {
pad_value = cpu_pad_value;
} else {
paddle::framework::TensorCopySync(cpu_pad_value, ctx.GetPlace(),
&pad_value);
}
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits,
&warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
math::kLengthBatchWidth);
}
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
std::vector<int> warpctc_logits_lengths(num_sequences);
for (size_t i = 0; i < num_sequences; ++i) {
warpctc_label_lengths[i] = label_lod[i + 1] - label_lod[i];
warpctc_logits_lengths[i] = logits_lod[i + 1] - logits_lod[i];
}
// warpctc computes loss and gradient in one call, gradient data also stored
// in batch format
T* warpctc_grad_data =
warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
phi::funcs::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), warpctc_grad,
static_cast<T>(0));
// warpctc accesses labels in CPU memory
LoDTensor warpctc_label;
if (ctx.HasInput("LogitsLength")) {
warpctc_label.mutable_data<int>(
{static_cast<int64_t>(math::TotalSequenceLength(label_lod)), 1},
platform::CPUPlace());
std::vector<framework::Vector<size_t>> lod;
lod.push_back(label_lod);
warpctc_label.set_lod(lod);
if (platform::is_cpu_place(ctx.GetPlace())) {
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
ctx.template device_context<DeviceContext>(), *label,
&warpctc_label, label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
false /*norm_by_times*/, math::kBatchLengthWidth);
} else {
LoDTensor gpu_label;
gpu_label.mutable_data<int>(
{static_cast<int64_t>(math::TotalSequenceLength(label_lod)), 1},
ctx.GetPlace());
gpu_label.set_lod(lod);
math::UnpaddingLoDTensorFunctor<DeviceContext, int>()(
ctx.template device_context<DeviceContext>(), *label, &gpu_label,
label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
false /*norm_by_times*/, math::kBatchLengthWidth);
paddle::framework::TensorCopySync(gpu_label, platform::CPUPlace(),
&warpctc_label);
}
} else {
paddle::framework::TensorCopySync(*label, platform::CPUPlace(),
&warpctc_label);
}
const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory
Tensor warpctc_loss;
T* warpctc_loss_data =
warpctc_loss.mutable_data<T>(loss_dims, platform::CPUPlace());
const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
WarpCTCFunctor<DeviceContext, T>()(
ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
sequence_width, num_sequences, blank, warpctc_loss_data);
// Copy the loss back
paddle::framework::TensorCopy(warpctc_loss, ctx.GetPlace(),
ctx.device_context(), loss);
}
};
template <typename DeviceContext, typename T>
class WarpCTCGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
if (ctx.HasInput("LogitsLength")) {
int max_seq_length = warpctc_grad->dims()[0]; // Tmax
int num_sequences = warpctc_grad->dims()[1]; // B
int seq_width = warpctc_grad->dims()[2]; // D
auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
// B
auto logits_len_e =
framework::EigenTensor<int64_t, 1>::From(*logits_length);
// (B, 1)
auto loss_grad_e = framework::EigenTensor<T, 2>::From(*loss_grad);
// (T, B, D)
auto warpctc_grad_e = framework::EigenTensor<T, 3>::From(*warpctc_grad);
auto logits_grad_e = framework::EigenTensor<T, 3>::From(*logits_grad);
Eigen::DSizes<int, 3> grad_shape(1, num_sequences, 1);
Eigen::DSizes<int, 3> bcast(max_seq_length, 1, seq_width);
auto logits_g = warpctc_grad_e *
loss_grad_e.reshape(grad_shape).broadcast(bcast).eval();
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
if (norm_by_times) {
auto scales = logits_len_e.cast<T>()
.inverse()
.reshape(grad_shape)
.broadcast(bcast)
.eval();
logits_grad_e.device(*place) = logits_g * scales;
} else {
logits_grad_e.device(*place) = logits_g;
}
} else {
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss_grad_data,
logits_grad);
}
}
};
} // namespace operators
} // namespace paddle
@@ -1082,6 +1082,43 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label,
const paddle::optional<const MetaTensor&> logits_length,
const paddle::optional<const MetaTensor&> labels_length,
int blank,
bool norm_by_times,
MetaTensor* warpctc_grad,
MetaTensor* loss) {
auto logits_dims = logits.dims();
int sequence_width = 0;
if (logits_length.is_initialized()) {
sequence_width = logits_dims[2];
} else {
sequence_width =
static_cast<int>(phi::product(logits_dims) / logits_dims[0]);
}
PADDLE_ENFORCE_GE(
blank,
0,
errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
sequence_width,
blank));
PADDLE_ENFORCE_LT(
blank,
sequence_width,
errors::InvalidArgument(
"The value of Attr(blank) should be in interval [0, %d), "
"but received %d",
sequence_width,
blank));
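// The first dimension of Loss is left as -1 (a run-time placeholder): the
// actual number of sequences is only known once the LoD or the length inputs
// are available (descriptive note, not part of the original change).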
loss->set_dims({-1, 1});
loss->set_dtype(logits.dtype());
}
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
@@ -214,6 +214,15 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale,
MetaTensor* out);
void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label,
const paddle::optional<const MetaTensor&> logits_length,
const paddle::optional<const MetaTensor&> labels_length,
int blank,
bool norm_by_times,
MetaTensor* warpctc_grad,
MetaTensor* loss);
void WhereInferMeta(const MetaTensor& condition,
const MetaTensor& x,
const MetaTensor& y,
@@ -32,7 +32,7 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel)
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel warpctc_kernel warpctc_grad_kernel)
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
@@ -58,6 +58,8 @@ kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
# 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warpctc_grad_kernel.h"
#include "paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
warpctc_grad, CPU, ALL_LAYOUT, phi::WarpctcGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warpctc_kernel.h"
#include "paddle/phi/kernels/impl/warpctc_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
warpctc, CPU, ALL_LAYOUT, phi::WarpctcKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warpctc_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/warpctc_grad_kernel_impl.h"
PD_REGISTER_KERNEL(
warpctc_grad, GPU, ALL_LAYOUT, phi::WarpctcGradKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/warpctc_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/warpctc_kernel_impl.h"
PD_REGISTER_KERNEL(
warpctc, GPU, ALL_LAYOUT, phi::WarpctcKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/fluid/operators/math/sequence_scale.h"
#include "paddle/phi/backends/dynload/warpctc.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void WarpctcGradKernel(const Context& dev_ctx,
const DenseTensor& warpctc_grad,
const DenseTensor& logits,
const DenseTensor& loss_grad,
const paddle::optional<const DenseTensor&> logits_length,
int blank,
bool norm_by_times,
DenseTensor* logits_grad) {
dev_ctx.template Alloc<T>(logits_grad);
if (logits_length.is_initialized()) {
int max_seq_length = warpctc_grad.dims()[0]; // Tmax
int num_sequences = warpctc_grad.dims()[1]; // B
int seq_width = warpctc_grad.dims()[2]; // D
// B
auto logits_len_e = EigenTensor<int64_t, 1>::From(*logits_length);
// (B, 1)
auto loss_grad_e = EigenTensor<T, 2>::From(loss_grad);
// (T, B, D)
auto warpctc_grad_e = EigenTensor<T, 3>::From(warpctc_grad);
auto logits_grad_e = EigenTensor<T, 3>::From(*logits_grad);
Eigen::DSizes<int, 3> grad_shape(1, num_sequences, 1);
Eigen::DSizes<int, 3> bcast(max_seq_length, 1, seq_width);
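// Broadcast the per-sequence loss gradient (reshaped to (1, B, 1)) across the
// time and width dimensions to (Tmax, B, D), so every logit gradient of
// sequence b is scaled by d(loss)/d(loss_b) (explanatory note).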
auto logits_g = warpctc_grad_e *
loss_grad_e.reshape(grad_shape).broadcast(bcast).eval();
auto* place = dev_ctx.eigen_device();
if (norm_by_times) {
auto scales = logits_len_e.cast<T>()
.inverse()
.reshape(grad_shape)
.broadcast(bcast)
.eval();
logits_grad_e.device(*place) = logits_g * scales;
} else {
logits_grad_e.device(*place) = logits_g;
}
} else {
paddle::operators::math::UnpaddingLoDTensorFunctor<Context, T>()(
dev_ctx,
warpctc_grad,
logits_grad,
-1,
0,
norm_by_times,
paddle::operators::math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad.data<T>();
paddle::operators::math::ScaleLoDTensorFunctor<Context, T>()(
dev_ctx, loss_grad_data, logits_grad);
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/fluid/operators/math/sequence_scale.h"
#include "paddle/phi/backends/dynload/warpctc.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename Context, typename T>
class ComputeCtcLossFunctor {
public:
ctcStatus_t operator()(const T* const activations,
T* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
T* costs,
void* workspace,
ctcOptions options) {
return CTC_STATUS_EXECUTION_FAILED;
}
};
template <typename Context>
class ComputeCtcLossFunctor<Context, float> {
public:
ctcStatus_t operator()(const float* const activations,
float* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
float* costs,
void* workspace,
ctcOptions options) {
return phi::dynload::compute_ctc_loss(activations,
gradients,
flat_labels,
label_lengths,
input_lengths,
static_cast<int>(alphabet_size),
static_cast<int>(minibatch),
costs,
workspace,
options);
}
};
template <typename Context>
class ComputeCtcLossFunctor<Context, double> {
public:
ctcStatus_t operator()(const double* const activations,
double* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths,
int alphabet_size,
int minibatch,
double* costs,
void* workspace,
ctcOptions options) {
return phi::dynload::compute_ctc_loss_double(
activations,
gradients,
flat_labels,
label_lengths,
input_lengths,
static_cast<int>(alphabet_size),
static_cast<int>(minibatch),
costs,
workspace,
options);
}
};
template <typename Context, typename T>
class WarpCTCFunctor {
public:
/*
* \brief Compute the connectionist temporal classification loss,
* and optionally compute the gradient with respect to the inputs.
*
 * If gradient is nullptr, only the ctc loss is computed;
 * otherwise both the ctc loss and the gradient are computed.
*
* \param ctx execution context of this functor
* \param input batch matrix of input probabilities, in
* max_sequence_length x num_sequences x
* sequence_width, (row-major) format
* \param gradient batch matrix of gradient, with the same shape as
* input.
* \param cpu_labels labels always in CPU memory.
* \param cpu_label_lengths length of all labels in CPU memory.
* \param cpu_input_lengths length of all sequences in CPU memory.
* \param sequence_width number of possible output symbols.
 * \param num_sequences number of sequences.
 * \param blank blank label used in the ctc loss function.
 * \param cpu_loss cost of each sequence in CPU memory.
*/
void operator()(const Context& dev_ctx,
const T* input,
T* gradient,
const int* cpu_labels,
const int* cpu_label_lengths,
const int* cpu_input_lengths,
const size_t sequence_width,
const size_t num_sequences,
const size_t blank,
T* cpu_loss) {
// Init warp-ctc options
init(dev_ctx, blank);
// Compute the required workspace size.
// There are no memory allocation operations within warp-ctc itself.
size_t workspace_bytes = 0;
ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR;
if (sizeof(T) == 4) {
status =
phi::dynload::get_workspace_size(cpu_label_lengths,
cpu_input_lengths,
static_cast<int>(sequence_width),
static_cast<int>(num_sequences),
options_,
&workspace_bytes);
} else {
status = phi::dynload::get_workspace_size_double(
cpu_label_lengths,
cpu_input_lengths,
static_cast<int>(sequence_width),
static_cast<int>(num_sequences),
options_,
&workspace_bytes);
}
PADDLE_ENFORCE_EQ(
CTC_STATUS_SUCCESS,
status,
errors::PreconditionNotMet(
"warp-ctc [version %d] Error in get_workspace_size: %s",
warpctc_version_,
phi::dynload::ctcGetStatusString(status)));
PADDLE_ENFORCE_GT(
workspace_bytes,
0UL,
errors::InvalidArgument(
"Bytes of workspace got by warp-ctc function, "
"get_workspace_size() should be larger than 0, but received %d",
workspace_bytes));
size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL;
DenseTensor workspace = phi::Empty<T, Context>(
dev_ctx, {static_cast<int64_t>(workspace_elements)});
T* workspace_data = workspace.data<T>();
phi::funcs::SetConstant<Context, T>()(
dev_ctx, &workspace, static_cast<T>(0));
// compute loss and gradient
status =
ComputeCtcLossFunctor<Context, T>()(input,
gradient,
cpu_labels,
cpu_label_lengths,
cpu_input_lengths,
static_cast<int>(sequence_width),
static_cast<int>(num_sequences),
cpu_loss,
workspace_data,
options_);
PADDLE_ENFORCE_EQ(
CTC_STATUS_SUCCESS,
status,
errors::PreconditionNotMet(
"warp-ctc [version %d] Error in get_workspace_size: %s",
warpctc_version_,
phi::dynload::ctcGetStatusString(status)));
}
protected:
void init(const Context& dev_ctx, const size_t blank) {
warpctc_version_ = phi::dynload::get_warpctc_version();
if (dev_ctx.GetPlace() == phi::GPUPlace()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
options_.loc = CTC_GPU;
options_.stream =
reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
#else
PADDLE_THROW(
errors::PreconditionNotMet("[warpctc init] GPU is not enabled."));
#endif
} else {
options_.loc = CTC_CPU;
options_.num_threads = 1;
}
options_.blank_label = blank;
}
private:
int warpctc_version_;
ctcOptions options_;
};
template <typename T, typename Context>
void WarpctcKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
const paddle::optional<const DenseTensor&> logits_length,
const paddle::optional<const DenseTensor&> labels_length,
int blank,
bool norm_by_times,
DenseTensor* warpctc_grad,
DenseTensor* loss) {
size_t num_sequences, sequence_width, max_sequence_length;
paddle::framework::Vector<size_t> logits_lod;
paddle::framework::Vector<size_t> label_lod;
if (logits_length.is_initialized() && labels_length.is_initialized()) {
num_sequences = logits.dims()[1];
sequence_width = logits.dims()[2];
max_sequence_length = logits.dims()[0];
PADDLE_ENFORCE_GT(max_sequence_length,
0,
phi::errors::InvalidArgument(
"The first dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
max_sequence_length));
PADDLE_ENFORCE_GT(num_sequences,
0,
phi::errors::InvalidArgument(
"The second dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
num_sequences));
PADDLE_ENFORCE_GT(sequence_width,
0,
phi::errors::InvalidArgument(
"The third dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
sequence_width));
DenseTensor logits_length_cpu;
DenseTensor labels_length_cpu;
phi::Copy(
dev_ctx, *logits_length, phi::CPUPlace(), false, &logits_length_cpu);
phi::Copy(
dev_ctx, *labels_length, phi::CPUPlace(), false, &labels_length_cpu);
logits_lod.push_back(0);
label_lod.push_back(0);
for (size_t i = 0; i < num_sequences; i++) {
logits_lod.push_back(logits_lod[i] +
logits_length_cpu.data<int64_t>()[i]);
label_lod.push_back(label_lod[i] + labels_length_cpu.data<int64_t>()[i]);
}
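// Illustrative example (not part of the original change): for three sequences
// with logit lengths {4, 2, 3}, logits_lod becomes {0, 4, 6, 9}, i.e. the
// absolute start offset of each sequence followed by the total length.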
} else {
PADDLE_ENFORCE_GT(
logits.NumLevels(),
0UL,
phi::errors::InvalidArgument("Input(Logits) Tensor of WarpCTC "
"does not contain LoD information."));
PADDLE_ENFORCE_GT(
label.NumLevels(),
0UL,
phi::errors::InvalidArgument("Input(Label) Tensor of WarpCTC "
"does not contain LoD information."));
logits_lod = paddle::framework::ToAbsOffset(logits.lod())[0];
auto logits_dims = logits.dims();
PADDLE_ENFORCE_GT(logits_dims[0],
0,
phi::errors::InvalidArgument(
"The first dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
logits_dims[0]));
PADDLE_ENFORCE_EQ(
logits_dims[0],
static_cast<int64_t>(logits_lod.back()),
phi::errors::InvalidArgument(
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths = %d., but received %d. ",
static_cast<int64_t>(logits_lod.back()),
logits_dims[0]));
label_lod = paddle::framework::ToAbsOffset(label.lod())[0];
auto label_dims = label.dims();
PADDLE_ENFORCE_EQ(label_dims[1],
1,
phi::errors::InvalidArgument(
"The last dimension of Input(Label) should be 1, "
"but received %d",
label_dims[1]));
num_sequences = logits_lod.size() - 1;
PADDLE_ENFORCE_EQ(num_sequences,
label_lod.size() - 1,
phi::errors::InvalidArgument(
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label) = %d, but received %d",
label_lod.size() - 1,
num_sequences));
sequence_width = logits.numel() / logits_dims[0];
max_sequence_length =
paddle::operators::math::MaximumSequenceLength(logits_lod);
}
auto loss_dims = phi::make_ddim({static_cast<int64_t>(num_sequences), 1});
// warpctc needs sequences data stored in transposed padding format
DenseTensor warpctc_logits_tmp =
phi::Empty<T, Context>(dev_ctx,
{static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
DenseTensor warpctc_logits(warpctc_logits_tmp);
if (logits_length.is_initialized()) {
phi::Copy(dev_ctx, logits, dev_ctx.GetPlace(), true, &warpctc_logits);
} else {
DenseTensor cpu_pad_value;
cpu_pad_value.Resize({1});
T* pad_value_data = dev_ctx.template HostAlloc<T>(&cpu_pad_value);
*pad_value_data = static_cast<T>(0);
DenseTensor pad_value;
if (dev_ctx.GetPlace() == phi::CPUPlace()) {
pad_value = cpu_pad_value;
} else {
phi::Copy(dev_ctx, cpu_pad_value, dev_ctx.GetPlace(), true, &pad_value);
}
paddle::operators::math::PaddingLoDTensorFunctor<Context, T>()(
dev_ctx,
logits,
&warpctc_logits,
pad_value,
-1,
0,
false /* norm_by_times */,
paddle::operators::math::kLengthBatchWidth);
}
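// At this point warpctc_logits holds the logits in the time-major layout that
// warp-ctc expects, [max_sequence_length, num_sequences, sequence_width]; in
// the LoD branch, positions beyond a sequence's length are zero-padded
// (descriptive note).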
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
std::vector<int> warpctc_logits_lengths(num_sequences);
for (size_t i = 0; i < num_sequences; ++i) {
warpctc_label_lengths[i] = label_lod[i + 1] - label_lod[i];
warpctc_logits_lengths[i] = logits_lod[i + 1] - logits_lod[i];
}
// warpctc computes loss and gradient in one call, gradient data also stored
// in batch format
warpctc_grad->Resize(warpctc_logits.dims());
T* warpctc_grad_data = dev_ctx.template Alloc<T>(warpctc_grad);
phi::funcs::SetConstant<Context, T>()(
dev_ctx, warpctc_grad, static_cast<T>(0));
// warpctc accesses labels in CPU memory
DenseTensor warpctc_label;
if (logits_length.is_initialized()) {
warpctc_label.Resize(
{static_cast<int64_t>(
paddle::operators::math::TotalSequenceLength(label_lod)),
1});
dev_ctx.template HostAlloc<int>(&warpctc_label);
std::vector<paddle::framework::Vector<size_t>> lod;
lod.push_back(label_lod);
warpctc_label.set_lod(lod);
if (dev_ctx.GetPlace() == phi::CPUPlace()) {
paddle::operators::math::UnpaddingLoDTensorFunctor<Context, int>()(
dev_ctx,
label,
&warpctc_label,
label.dims()[1] /*pad_seq_len*/,
0 /*lod_level*/,
false /*norm_by_times*/,
paddle::operators::math::kBatchLengthWidth);
} else {
DenseTensor gpu_label;
gpu_label.Resize(
{static_cast<int64_t>(
paddle::operators::math::TotalSequenceLength(label_lod)),
1});
dev_ctx.template Alloc<int>(&gpu_label);
gpu_label.set_lod(lod);
paddle::operators::math::UnpaddingLoDTensorFunctor<Context, int>()(
dev_ctx,
label,
&gpu_label,
label.dims()[1] /*pad_seq_len*/,
0 /*lod_level*/,
false /*norm_by_times*/,
paddle::operators::math::kBatchLengthWidth);
phi::Copy(dev_ctx, gpu_label, phi::CPUPlace(), true, &warpctc_label);
}
} else {
phi::Copy(dev_ctx, label, phi::CPUPlace(), true, &warpctc_label);
}
const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory
DenseTensor warpctc_loss;
warpctc_loss.Resize(loss_dims);
T* warpctc_loss_data = dev_ctx.template HostAlloc<T>(&warpctc_loss);
WarpCTCFunctor<Context, T>()(dev_ctx,
warpctc_logits_data,
warpctc_grad_data,
warpctc_label_data,
warpctc_label_lengths.data(),
warpctc_logits_lengths.data(),
sequence_width,
num_sequences,
blank,
warpctc_loss_data);
// Copy the loss back
phi::Copy(dev_ctx, warpctc_loss, dev_ctx.GetPlace(), false, loss);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void WarpctcGradKernel(const Context& dev_ctx,
const DenseTensor& warpctc_grad,
const DenseTensor& logits,
const DenseTensor& loss_grad,
paddle::optional<const DenseTensor&> logits_length,
int blank,
bool norm_by_times,
DenseTensor* logits_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename T, typename Context>
void WarpctcKernel(const Context& dev_ctx,
const DenseTensor& logits,
const DenseTensor& label,
paddle::optional<const DenseTensor&> logits_length,
paddle::optional<const DenseTensor&> labels_length,
int blank,
bool norm_by_times,
DenseTensor* warpctc_grad,
DenseTensor* loss);
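// A minimal calling sketch (illustrative only; it assumes an initialized CPU
// context, pre-filled logits/label tensors carrying LoD information, and that
// paddle::none is passed for the absent optional length inputs):
//
//   phi::CPUContext ctx;
//   phi::DenseTensor logits, label, warpctc_grad, loss;
//   phi::WarpctcKernel<float, phi::CPUContext>(
//       ctx, logits, label, paddle::none, paddle::none,
//       /*blank=*/0, /*norm_by_times=*/false, &warpctc_grad, &loss);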
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature WarpctcOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("warpctc",
{"Logits", "Label", "LogitsLength", "LabelLength"},
{"blank", "norm_by_times"},
{"WarpCTCGrad", "Loss"});
}
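// The mapping above wires the fluid operator's inputs ("Logits", "Label",
// "LogitsLength", "LabelLength"), attributes ("blank", "norm_by_times") and
// outputs ("WarpCTCGrad", "Loss") to the phi::WarpctcKernel signature declared
// in warpctc_kernel.h (descriptive note).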
KernelSignature WarpctcGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"warpctc_grad",
{"WarpCTCGrad", "Logits", GradVarName("Loss"), "LogitsLength"},
{"blank", "norm_by_times"},
{GradVarName("Logits")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(warpctc, phi::WarpctcOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(warpctc_grad, phi::WarpctcGradOpArgumentMapping);