From 8ede9e6f59027bc9a2e8a21fb42b5d3746dda27a Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Thu, 28 Oct 2021 01:42:29 -0500
Subject: [PATCH] [Cherry-pick] Enable CTC grad compute on GPU (#36780)

* Revert "Align CTC grad scale same with ESPNet (#34729)"

This reverts commit 10f9644cc4cb4eb23807007d678df880db4b0336.

* ctc grad compute on gpu
---
 .../fluid/operators/math/sequence_padding.cc  |  31 +--
 .../fluid/operators/math/sequence_padding.cu  |  24 +-
 .../fluid/operators/math/sequence_padding.h   |   4 -
 .../operators/math/sequence_padding_test.cc   |   4 +-
 .../operators/sequence_ops/sequence_pad_op.h  |   4 +-
 .../sequence_ops/sequence_unpad_op.h          |   5 +-
 paddle/fluid/operators/warpctc_op.cc          |  29 ---
 paddle/fluid/operators/warpctc_op.cu          | 180 +--------------
 paddle/fluid/operators/warpctc_op.h           |  42 +---
 python/paddle/fluid/layers/loss.py            |  25 +-
 .../fluid/tests/unittests/test_warpctc_op.py  | 215 ------------------
 python/paddle/nn/functional/loss.py           |  16 +-
 python/paddle/nn/layer/loss.py                |   8 +-
 13 files changed, 40 insertions(+), 547 deletions(-)

diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc
index dca58f796a7..e29313e9f74 100644
--- a/paddle/fluid/operators/math/sequence_padding.cc
+++ b/paddle/fluid/operators/math/sequence_padding.cc
@@ -33,8 +33,7 @@ void CopyValidData(framework::Tensor* dst_tensor,
                    const framework::Tensor* src_tensor,
                    const framework::Vector<size_t>& seq_offsets,
                    int pad_seq_len, int step_width, bool norm_by_len,
-                   bool norm_by_batchsize, bool norm_by_total_logits_len,
-                   int total_logits_len, CopyType type, PadLayout layout) {
+                   CopyType type, PadLayout layout) {
   int seq_num = seq_offsets.size() - 1;
   const T* src_data = src_tensor->data<T>();
   T* dst_data = dst_tensor->data<T>();
@@ -55,21 +54,7 @@
     int pad_data_offset = layout == kBatchLengthWidth
                               ? seq_idx * pad_seq_len * step_width
                               : seq_idx * step_width;
-
-    float scale = 1.0f;
-    if (norm_by_total_logits_len) {
-      scale = 1.0f / static_cast<float>(total_logits_len);
-      VLOG(3) << "[warpctc grad][norm_by_total_logits_len]: scale " << scale
-              << "total_logits_len " << total_logits_len;
-    } else if (norm_by_batchsize) {
-      scale = 1.0f / static_cast<float>(seq_num);
-      VLOG(3) << "[warpctc grad][norm_by_batchsize]: scale " << scale << "B "
-              << seq_num;
-    } else if (norm_by_len) {
-      scale = 1.0f / static_cast<float>(valid_seq_len);
-      VLOG(3) << "[warpctc grad][norm_by_len]: scale " << scale << "T "
-              << valid_seq_len;
-    }
+    float scale = 1.0f / static_cast<float>(valid_seq_len);
 
     for (int step_idx = 0; step_idx < valid_seq_len; ++step_idx) {
       const T* src =
@@ -112,8 +97,6 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                   framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  bool norm_by_batchsize = false,
-                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
     const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
@@ -148,8 +131,7 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
     }
 
     CopyValidData<T>(pad_tensor, &seq_tensor, seq_offsets, pad_seq_len,
-                     step_width, norm_by_times, false, false, 0, kSeqToPad,
-                     layout);
+                     step_width, norm_by_times, kSeqToPad, layout);
   }
 };
 
@@ -160,8 +142,6 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  bool norm_by_batchsize = false,
-                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
     const auto& seq_tensor_dims = seq_tensor->dims();
@@ -169,16 +149,13 @@
     if (pad_seq_len == -1) {
       pad_seq_len = MaximumSequenceLength(seq_offsets);
     }
-    int total_logits_len = TotalSequenceLength(seq_offsets);
     int step_width = seq_tensor->numel() / seq_tensor_dims[0];
 
     CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
               step_width, layout);
 
     CopyValidData<T>(seq_tensor, &pad_tensor, seq_offsets, pad_seq_len,
-                     step_width, norm_by_times, norm_by_batchsize,
-                     norm_by_total_logits_len, total_logits_len, kPadToSeq,
-                     layout);
+                     step_width, norm_by_times, kPadToSeq, layout);
   }
 };
 
diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu
index 3578d7e91fd..19c3af03411 100644
--- a/paddle/fluid/operators/math/sequence_padding.cu
+++ b/paddle/fluid/operators/math/sequence_padding.cu
@@ -23,9 +23,7 @@ template <typename T, CopyType Type>
 __global__ void SequencePaddingKernel(
     T* dst, const T* src, const T* pad_value, bool is_constant_pad,
     const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
-    const size_t step_width, bool norm_by_len, bool norm_by_batchsize,
-    bool norm_by_total_logits_len, int total_logits_len,
-    const PadLayout layout) {
+    const size_t step_width, bool norm_by_len, const PadLayout layout) {
   size_t seq_idx = blockIdx.y;
   size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
 
@@ -40,15 +38,7 @@ __global__ void SequencePaddingKernel(
       src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
 
   if (step_idx < seq_len) {
-    float scale = 1.0f;
-    if (norm_by_total_logits_len) {
-      scale = 1.0f / static_cast<float>(total_logits_len);
-    } else if (norm_by_batchsize) {
-      scale = 1.0f / static_cast<float>(seq_num);
-    } else if (norm_by_len) {
-      scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
-    }
-
+    float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
     for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
       dst_data[i] = scale * src_data[i];
     }
@@ -67,8 +57,6 @@
                   framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  bool norm_by_batchsize = false,
-                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
     const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
@@ -119,7 +107,7 @@
     SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
         pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
         seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
-        step_width, norm_by_times, false, false, 0, layout);
+        step_width, norm_by_times, layout);
   }
 };
 
@@ -130,8 +118,6 @@
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  bool norm_by_batchsize = false,
-                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
     const auto& seq_tensor_dims = seq_tensor->dims();
@@ -140,7 +126,6 @@
     if (pad_seq_len == -1) {
       pad_seq_len = max_seq_len;
     }
-    int total_logits_len = TotalSequenceLength(seq_offsets);
     int step_width = seq_tensor->numel() / seq_tensor_dims[0];
     int seq_num = seq_offsets.size() - 1;
 
@@ -174,8 +159,7 @@
     SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
         seq_data, pad_data, nullptr, false,
         seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
-        step_width, norm_by_times, norm_by_batchsize, norm_by_total_logits_len,
-        total_logits_len, layout);
+        step_width, norm_by_times, layout);
   }
 };
 
diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h
index 308e1eedebd..956a4ff6a2d 100644
--- a/paddle/fluid/operators/math/sequence_padding.h
+++ b/paddle/fluid/operators/math/sequence_padding.h
@@ -107,8 +107,6 @@ class PaddingLoDTensorFunctor {
                   framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  bool norm_by_batchsize = false,
-                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth);
 };
 
@@ -119,8 +117,6 @@ class UnpaddingLoDTensorFunctor {
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
-                  bool norm_by_batchsize = false,
-                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth);
 };
 
diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc
index 590d1d6191d..ea31b10c555 100644
--- a/paddle/fluid/operators/math/sequence_padding_test.cc
+++ b/paddle/fluid/operators/math/sequence_padding_test.cc
@@ -66,13 +66,13 @@ void TestSequencePadding(const DeviceContext &context,
   }
 
   paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
-      context, seq, &padding, pad_value, -1, 0, false, false, false,
+      context, seq, &padding, pad_value, -1, 0, false,
       paddle::operators::math::kLengthBatchWidth);
 
   seq_back.set_lod(lod);
   seq_back.mutable_data<T>(seq_dims, place);
   paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-      context, padding, &seq_back, -1, 0, false, false, false,
+      context, padding, &seq_back, -1, 0, false,
       paddle::operators::math::kLengthBatchWidth);
 
   if (paddle::platform::is_cpu_place(place)) {
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.h b/paddle/fluid/operators/sequence_ops/sequence_pad_op.h
index d8ae0b200df..a9660f05c3c 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.h
@@ -46,7 +46,7 @@ class SequencePadOpKernel : public framework::OpKernel<T> {
 
     math::PaddingLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *x, out, *pad_value,
-        padded_length, 0, false, false, false, math::kBatchLengthWidth);
+        padded_length, 0, false, math::kBatchLengthWidth);
 
     LoDTensor seq_len;
     seq_len.Resize(len_t->dims());
@@ -72,7 +72,7 @@ class SequencePadGradOpKernel : public framework::OpKernel<T> {
 
       math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
           ctx.template device_context<DeviceContext>(), *d_out, d_x,
-          padded_length, 0, false, false, false, math::kBatchLengthWidth);
+          padded_length, 0, false, math::kBatchLengthWidth);
     }
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
index 398c3bba075..60ba4797db1 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
@@ -69,8 +69,7 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
     int64_t padded_length = x_t->dims()[1];
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-        dev_ctx, *x_t, out_t, padded_length, 0, false, false, false,
-        math::kBatchLengthWidth);
+        dev_ctx, *x_t, out_t, padded_length, 0, false, math::kBatchLengthWidth);
   }
 };
 
@@ -94,7 +93,7 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {
 
       math::PaddingLoDTensorFunctor<DeviceContext, T>()(
          ctx.template device_context<DeviceContext>(), *d_out, d_x, zero_pads,
-          padded_length, 0, false, false, false, math::kBatchLengthWidth);
+          padded_length, 0, false, math::kBatchLengthWidth);
     }
   }
 };
diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc
index 92862929159..f38f5d9f723 100644
--- a/paddle/fluid/operators/warpctc_op.cc
+++ b/paddle/fluid/operators/warpctc_op.cc
@@ -125,17 +125,6 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
              "normalize the gradients by the number of time-step, "
             "which is also the sequence's length.")
         .SetDefault(false);
-    AddAttr<bool>(
-        "norm_by_batchsize",
-        "(bool, default: false), normalize the loss by the batch size."
-        "If True, supersedes norm_by_times")
-        .SetDefault(false);
-    AddAttr<bool>(
-        "norm_by_total_logits_len",
-        "(bool, default: false), normalize the loss by the total number of "
-        "frames"
-        "in the batch. If True, supersedes norm_by_batchsize and norm_by_times")
-        .SetDefault(false);
     AddComment(R"DOC(
 An operator integrating the open-source
 [warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
@@ -217,21 +206,3 @@ REGISTER_OP_CPU_KERNEL(
     warpctc_grad,
     ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, double>);
-
-REGISTER_OP_VERSION(warpctc)
-    .AddCheckpoint(
-        R"ROC(
-              Upgrade warpctc add a new attribute [norm_by_batchsize] and [norm_by_total_logits_len])ROC",
-        paddle::framework::compatible::OpVersionDesc()
-            .NewAttr(
-                "norm_by_batchsize",
-                "(bool, default: false), normalize the loss by the batch size."
-                "If True, supersedes norm_by_times",
-                false)
-            .NewAttr("norm_by_total_logits_len",
-                     "(bool, default: false), normalize the loss by the total "
-                     "number of "
-                     "frames"
-                     "in the batch. If True, supersedes norm_by_batchsize and "
-                     "norm_by_times",
-                     false));
\ No newline at end of file
diff --git a/paddle/fluid/operators/warpctc_op.cu b/paddle/fluid/operators/warpctc_op.cu
index 27c17eb6de8..fd820805e4d 100644
--- a/paddle/fluid/operators/warpctc_op.cu
+++ b/paddle/fluid/operators/warpctc_op.cu
@@ -12,185 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/warpctc_op.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
-#include "paddle/fluid/platform/gpu_info.h"
-
-namespace paddle {
-namespace operators {
-
-using platform::PADDLE_CUDA_NUM_THREADS;
-
-template <typename T>
-void PrintTensor(const framework::LoDTensor& src,
-                 const framework::ExecutionContext& ctx) {
-  std::vector<T> vec(src.numel());
-  TensorToVector(src, ctx.device_context(), &vec);
-  for (int i = 0; i < static_cast<int>(vec.size()); ++i) {
-    VLOG(3) << "vec[" << i << "] : " << vec[i];
-  }
-}
-
-template <typename T>
-__global__ void ReduceSumKernel(const T* d_in, T* d_out) {
-  // Allocate shared memory
-  extern __shared__ int partial_sum[];
-
-  // Calculate thread ID
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-  // Load elements into shared memory
-  partial_sum[threadIdx.x] = d_in[tid];
-  __syncthreads();
-
-  // Start at 1/2 block stride and divide by two each iteration
-  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
-    // Each thread does work unless it is further than the stride
-    if (threadIdx.x < s) {
-      partial_sum[threadIdx.x] += partial_sum[threadIdx.x + s];
-    }
-    __syncthreads();
-  }
-
-  // Let the thread 0 for this block write it's result to main memory
-  // Result is inexed by this block
-  if (threadIdx.x == 0) {
-    d_out[blockIdx.x] = partial_sum[0];
-  }
-}
-
-template <typename T>
-__global__ void CTCGradScaleKernel(T* d_out, const T* d_ctc, const T* d_loss,
-                                   int scale, int Tmax, int B, int D) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int n_elems = Tmax * B * D;
-  int b_idx = (tid / D) % B;
-  for (; tid < n_elems; tid += gridDim.x * blockDim.x) {
-    d_out[tid] = d_ctc[tid] * d_loss[b_idx] / static_cast<T>(scale);
-  }
-}
-
-template <typename T>
-__global__ void CTCGradScaleKernel(T* d_out, const T* d_ctc, const T* d_loss,
-                                   int64_t* scale, int Tmax, int B, int D) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int n_elems = Tmax * B * D;
-  int b_idx = (tid / D) % B;
-  for (; tid < n_elems; tid += gridDim.x * blockDim.x) {
-    d_out[tid] = d_ctc[tid] * d_loss[b_idx] / static_cast<T>(scale[0]);
-  }
-}
-
-template <typename T>
-__global__ void CTCGradBatchScaleKernel(T* d_out, const T* d_ctc,
-                                        const T* d_loss, const int64_t* scales,
-                                        int Tmax, int B, int D) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int n_elems = Tmax * B * D;
-  int b_idx = (tid / D) % B;
-  // scale is vector, (B)
-  for (; tid < n_elems; tid += gridDim.x * blockDim.x) {
-    d_out[tid] = d_ctc[tid] * d_loss[b_idx] / scales[b_idx];
-  }
-}
-
-template <typename DeviceContext, typename T>
-class WarpCTCGradCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
-    auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
-    const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
-
-    logits_grad->mutable_data<T>(ctx.GetPlace());
-    bool norm_by_times = ctx.Attr<bool>("norm_by_times");
-    bool norm_by_batchsize = ctx.Attr<bool>("norm_by_batchsize");
-    bool norm_by_total_logits_len = ctx.Attr<bool>("norm_by_total_logits_len");
-
-    if ((norm_by_times && norm_by_batchsize) ||
-        (norm_by_times && norm_by_total_logits_len) ||
-        (norm_by_batchsize && norm_by_total_logits_len)) {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "[warpctc grad] norm_by_times, norm_by_batchsize and "
-          "norm_by_total_logits_len "
-          "should one be true."));
-    }
-
-    if (ctx.HasInput("LogitsLength")) {
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      auto stream = dev_ctx.stream();
-      int max_seq_length = warpctc_grad->dims()[0];  // Tmax
-      int num_sequences = warpctc_grad->dims()[1];   // B
-      int seq_width = warpctc_grad->dims()[2];       // D
-
-      auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
-      const int64_t* logits_length_ptr = logits_length->data<int64_t>();
-
-      int n_elems = max_seq_length * num_sequences * seq_width;
-      int num_blocks =
-          (n_elems + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
-      int shm_bytes = PADDLE_CUDA_NUM_THREADS * sizeof(T);
-
-      auto logits_grad_ptr =
-          logits_grad->mutable_data<T>(ctx.GetPlace());  // (Tmax, B, D)
-      auto warpctc_grad_ptr = warpctc_grad->data<T>();   // (Tmax, B, D)
-      auto loss_grad_ptr = loss_grad->data<T>();         // (B, 1)
-
-      if (norm_by_total_logits_len) {
-        VLOG(3) << "norm_by_total_logits_len no impl ";
-        // total length
-        Tensor total_length;
-        int64_t* total_length_ptr =
-            total_length.mutable_data<int64_t>({1}, ctx.GetPlace());
-        int bytes = num_sequences * sizeof(int64_t);
-        ReduceSumKernel<int64_t><<<1, num_sequences, bytes, stream>>>(
-            logits_length_ptr, total_length_ptr);
-
-        CTCGradScaleKernel<
-            T><<<num_blocks, PADDLE_CUDA_NUM_THREADS, shm_bytes, stream>>>(
-            logits_grad_ptr, warpctc_grad_ptr, loss_grad_ptr, total_length_ptr,
-            max_seq_length, num_sequences, seq_width);
-
-      } else if (norm_by_batchsize) {
-        VLOG(3) << "norm_by_batchsize ";
-        CTCGradScaleKernel<
-            T><<<num_blocks, PADDLE_CUDA_NUM_THREADS, shm_bytes, stream>>>(
-            logits_grad_ptr, warpctc_grad_ptr, loss_grad_ptr, num_sequences,
-            max_seq_length, num_sequences, seq_width);
-      } else if (norm_by_times) {
-        VLOG(3) << "norm_by_times ";
-        CTCGradBatchScaleKernel<
-            T><<<num_blocks, PADDLE_CUDA_NUM_THREADS, shm_bytes, stream>>>(
-            logits_grad_ptr, warpctc_grad_ptr, loss_grad_ptr, logits_length_ptr,
-            max_seq_length, num_sequences, seq_width);
-      } else {
-        VLOG(3) << "default ";
-        CTCGradScaleKernel<
-            T><<<num_blocks, PADDLE_CUDA_NUM_THREADS, shm_bytes, stream>>>(
-            logits_grad_ptr, warpctc_grad_ptr, loss_grad_ptr, 1, max_seq_length,
-            num_sequences, seq_width);
-      }
-    } else {
-      math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
-          ctx.template device_context<DeviceContext>(), *warpctc_grad,
-          logits_grad, -1, 0, norm_by_times, norm_by_batchsize,
-          norm_by_total_logits_len, math::kLengthBatchWidth);
-
-      const T* loss_grad_data = loss_grad->data<T>();
-      math::ScaleLoDTensorFunctor<DeviceContext, T>()(
-          ctx.template device_context<DeviceContext>(), loss_grad_data,
-          logits_grad);
-    }
-  }
-};
-
-}  // operators
-}  // paddle
 
 namespace ops = paddle::operators;
-
 // register forward and backward of CUDA OP must in same *.cu file.
 // Eigen can be used on GPU device, but must be in *.cu file not *.cu.cc file.
 // *.cu.cc also using GCC compiler. *.cu using NVCC compiler
@@ -199,5 +23,5 @@ REGISTER_OP_CUDA_KERNEL(
     warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>);
 REGISTER_OP_CUDA_KERNEL(
     warpctc_grad,
-    ops::WarpCTCGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::WarpCTCGradCUDAKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, double>);
\ No newline at end of file
diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h
index b515adc43fd..4cce33c3f52 100644
--- a/paddle/fluid/operators/warpctc_op.h
+++ b/paddle/fluid/operators/warpctc_op.h
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/sequence_padding.h"
 #include "paddle/fluid/operators/math/sequence_scale.h"
@@ -152,7 +151,7 @@ class WarpCTCFunctor {
     PADDLE_ENFORCE_EQ(
         CTC_STATUS_SUCCESS, status,
         platform::errors::PreconditionNotMet(
-            "warp-ctc [version %d] Error in ComputeCtcLossFunctor: %s",
+            "warp-ctc [version %d] Error in get_workspace_size: %s",
            warpctc_version_, platform::dynload::ctcGetStatusString(status)));
  }
 
@@ -315,8 +314,8 @@ class WarpCTCKernel : public framework::OpKernel<T> {
 
      math::PaddingLoDTensorFunctor<DeviceContext, T>()(
          ctx.template device_context<DeviceContext>(), *logits,
-          &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */, false,
-          false, math::kLengthBatchWidth);
+          &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
+          math::kLengthBatchWidth);
     }
     const T* warpctc_logits_data = warpctc_logits.data<T>();
 
@@ -351,7 +350,7 @@
      math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
          ctx.template device_context<DeviceContext>(), *label, &warpctc_label,
          label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
-          false /*norm_by_times*/, false, false, math::kBatchLengthWidth);
+          false /*norm_by_times*/, math::kBatchLengthWidth);
     } else {
       LoDTensor gpu_label;
       gpu_label.mutable_data<int>(
@@ -361,7 +360,7 @@
      math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
          ctx.template device_context<DeviceContext>(), *label, &gpu_label,
          label->dims()[1] /*pad_seq_len*/, 0 /*lod_level*/,
-          false /*norm_by_times*/, false, false, math::kBatchLengthWidth);
+          false /*norm_by_times*/, math::kBatchLengthWidth);
       TensorCopySync(gpu_label, platform::CPUPlace(), &warpctc_label);
     }
   } else {
@@ -390,23 +389,12 @@ template <typename DeviceContext, typename T>
 class WarpCTCGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
     auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
     auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
+    const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
 
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
-    bool norm_by_batchsize = ctx.Attr<bool>("norm_by_batchsize");
-    bool norm_by_total_logits_len = ctx.Attr<bool>("norm_by_total_logits_len");
-
-    if ((norm_by_times && norm_by_batchsize) ||
-        (norm_by_times && norm_by_total_logits_len) ||
-        (norm_by_batchsize && norm_by_total_logits_len)) {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "[warpctc grad] norm_by_times, norm_by_batchsize and "
-          "norm_by_total_logits_len "
-          "should one be true."));
-    }
 
     if (ctx.HasInput("LogitsLength")) {
       int max_seq_length = warpctc_grad->dims()[0];  // Tmax
@@ -430,20 +418,7 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
           loss_grad_e.reshape(grad_shape).broadcast(bcast).eval();
 
       auto* place = ctx.template device_context<DeviceContext>().eigen_device();
-      if (norm_by_total_logits_len) {
-        // Compute the avg. log-probability per batch sample and frame.
-        // Rank is 0
-        auto inv_len = logits_len_e.sum().cast<T>().inverse().eval();
-        logits_grad_e.device(*place) =
-            logits_g *
-            inv_len.reshape(Eigen::DSizes<int, 3>{1, 1, 1})
-                .broadcast(Eigen::DSizes<int, 3>{max_seq_length, num_sequences,
-                                                 seq_width});
-      } else if (norm_by_batchsize) {
-        // Compute the avg. log-probability per batch sample.
-        T scale = 1.0 / static_cast<T>(num_sequences);
-        logits_grad_e.device(*place) = logits_g * scale;
-      } else if (norm_by_times) {
+      if (norm_by_times) {
         auto scales = logits_len_e.cast<T>()
                           .inverse()
                           .reshape(grad_shape)
@@ -456,8 +431,7 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     } else {
       math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
          ctx.template device_context<DeviceContext>(), *warpctc_grad,
-          logits_grad, -1, 0, norm_by_times, norm_by_batchsize,
-          norm_by_total_logits_len, math::kLengthBatchWidth);
+          logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
 
       const T* loss_grad_data = loss_grad->data<T>();
       math::ScaleLoDTensorFunctor<DeviceContext, T>()(
diff --git a/python/paddle/fluid/layers/loss.py b/python/paddle/fluid/layers/loss.py
index eaac99fc5b5..3db4a894d1a 100644
--- a/python/paddle/fluid/layers/loss.py
+++ b/python/paddle/fluid/layers/loss.py
@@ -479,9 +479,7 @@ def warpctc(input,
             blank=0,
             norm_by_times=False,
             input_length=None,
-            label_length=None,
-            norm_by_batchsize=False,
-            norm_by_total_logits_len=False):
+            label_length=None):
     """
     An operator integrating the open source Warp-CTC library
     (https://github.com/baidu-research/warp-ctc)
@@ -518,12 +516,6 @@
             of Tensor type, it should have shape `[batch_size]` and dtype int64.
         label_length(Variable): The length for each label sequence if it is
             of Tensor type, it should have shape `[batch_size]` and dtype int64.
-        norm_by_batchsize (bool): normalize the loss by the batch size.
-            If `True`, supersedes `norm_by_times`
-            (default: `False`)
-        norm_by_total_logits_len (bool): normalize the loss by the total number of frames
-            in the batch. If `True`, supersedes `norm_by_batchsize` and `norm_by_times`
-            (default: `False`)
 
     Returns:
         Variable: The Connectionist Temporal Classification (CTC) loss,
@@ -611,12 +603,15 @@ def warpctc(input,
                 "input_length and label_length must not be None in dygraph mode!"
             )
         grad, loss_out = _C_ops.warpctc(
-            input, label, input_length, label_length, 'blank', blank,
-            'norm_by_times', norm_by_times, 'norm_by_batchsize',
-            norm_by_batchsize, 'norm_by_total_logits_len',
-            norm_by_total_logits_len)
+            input,
+            label,
+            input_length,
+            label_length,
+            'blank',
+            blank,
+            'norm_by_times',
+            norm_by_times, )
         return loss_out
-
     helper = LayerHelper('warpctc', **locals())
     check_variable_and_dtype(input, 'input', ['float32', 'float64'], "warpctc")
     check_variable_and_dtype(label, 'label', ['int32'], "warpctc")
@@ -640,8 +635,6 @@
         attrs={
             'blank': blank,
             'norm_by_times': norm_by_times,
-            'norm_by_batchsize': norm_by_batchsize,
-            'norm_by_total_logits_len': norm_by_total_logits_len,
         })
     return loss_out
 
diff --git a/python/paddle/fluid/tests/unittests/test_warpctc_op.py b/python/paddle/fluid/tests/unittests/test_warpctc_op.py
index 6358cbcf0bb..53f3b3cf53d 100644
--- a/python/paddle/fluid/tests/unittests/test_warpctc_op.py
+++ b/python/paddle/fluid/tests/unittests/test_warpctc_op.py
@@ -18,7 +18,6 @@ import sys
 import unittest
 import numpy as np
 from op_test import OpTest
-from op_test import skip_check_grad_ci
 from test_softmax_op import stable_softmax
 import paddle.fluid as fluid
 import paddle.fluid.core as core
@@ -457,220 +456,6 @@ class TestWarpCTCOpFp64(OpTest):
         self.check_grad(["Logits"], "Loss")
 
 
-@skip_check_grad_ci(reason="For warpctc, not check grad.")
-class TestWarpCTCOpAttr(OpTest):
-    def config(self):
-        self.batch_size = 4
-        self.num_classes = 8
-        self.logits_lod = [[4, 1, 5, 5]]
-        self.labels_lod = [[3, 1, 4, 2]]
-        self.logits_length = np.array([4, 1, 5, 5], dtype=np.int64)
-        self.labels_length = np.array([3, 1, 4, 2], dtype=np.int64)
-        self.blank = self.num_classes - 1
-        self.norm_by_times = False
-        self.norm_by_batchsize = False
-        self.norm_by_total_logits_len = False
-
-    def setUp(self):
-        self.op_type = "warpctc"
-        self.config()
-
-        logits = np.random.uniform(
-            0.1, 1.0,
-            [sum(self.logits_length), self.num_classes]).astype("float64")
-        softmax = np.apply_along_axis(stable_softmax, 1, logits)
-        # labels should not be blank
-        labels = np.random.randint(
-            0,
-            self.num_classes - 1, [sum(self.labels_length), 1],
-            dtype="int32")
-
-        ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
-                         self.num_classes, self.batch_size, self.blank,
-                         self.norm_by_times)
-        loss = ctc.forward()
-
-        max_sequence_length = 0
-        for i in range(self.batch_size):
-            max_sequence_length = max(max_sequence_length,
-                                      self.logits_length[i])
-        # reshape logits to T*N*S
-        new_logits = np.zeros(
-            [max_sequence_length, self.batch_size, self.num_classes],
-            dtype=logits.dtype)
-
-        cur = 0
-        for batch_id in range(self.batch_size):
-            for i in range(self.logits_length[batch_id]):
-                for j in range(self.num_classes):
-                    new_logits[i, batch_id, j] = logits[cur + i, j]
-            cur = cur + self.logits_length[batch_id]
-
-        # reshape labels to N*S
-        max_target_seq_length = 0
-        for i in range(self.batch_size):
-            max_target_seq_length = max(max_target_seq_length,
-                                        self.labels_length[i])
-        new_labels = np.zeros(
-            [self.batch_size, max_target_seq_length], dtype="int32")
-
-        cur = 0
-        for batch_id in range(self.batch_size):
-            for i in range(self.labels_length[batch_id]):
-                new_labels[batch_id, i] = labels[cur + i]
-            cur = cur + self.labels_length[batch_id]
-
-        self.gradient = np.zeros(
-            [max_sequence_length, self.batch_size, self.num_classes],
-            dtype=logits.dtype)
-
-        self.inputs = {
-            "Logits": new_logits,
-            "Label": new_labels,
-            "LogitsLength": self.logits_length,
"LabelLength": self.labels_length - } - self.outputs = {"Loss": loss} - self.attrs = { - "blank": self.blank, - "norm_by_times": self.norm_by_times, - "norm_by_batchsize": self.norm_by_batchsize, - "norm_by_total_logits_len": self.norm_by_total_logits_len, - } - - def test_check_output(self): - self.check_output() - - -@skip_check_grad_ci(reason="For warpctc, not check grad.") -class TestWarpCTCOpFp64NormByTimes(TestWarpCTCOpAttr): - def config(self): - self.batch_size = 4 - self.num_classes = 8 - self.logits_lod = [[4, 1, 5, 5]] - self.labels_lod = [[3, 1, 4, 2]] - self.logits_length = np.array([4, 1, 5, 5], dtype=np.int64) - self.labels_length = np.array([3, 1, 4, 2], dtype=np.int64) - self.blank = self.num_classes - 1 - self.norm_by_times = True - self.norm_by_batchsize = False - self.norm_by_total_logits_len = False - - -@skip_check_grad_ci(reason="For warpctc, not check grad.") -class TestWarpCTCOpFp64SizeAverage(TestWarpCTCOpAttr): - def config(self): - self.batch_size = 4 - self.num_classes = 8 - self.logits_lod = [[4, 1, 5, 5]] - self.labels_lod = [[3, 1, 4, 2]] - self.logits_length = np.array([4, 1, 5, 5], dtype=np.int64) - self.labels_length = np.array([3, 1, 4, 2], dtype=np.int64) - self.blank = self.num_classes - 1 - self.norm_by_times = False - self.norm_by_batchsize = True - self.norm_by_total_logits_len = False - - -@skip_check_grad_ci(reason="For warpctc, not check grad.") -class TestWarpCTCOpFp64LengthAverage(TestWarpCTCOpAttr): - def config(self): - self.batch_size = 4 - self.num_classes = 8 - self.logits_lod = [[4, 1, 5, 5]] - self.labels_lod = [[3, 1, 4, 2]] - self.logits_length = np.array([4, 1, 5, 5], dtype=np.int64) - self.labels_length = np.array([3, 1, 4, 2], dtype=np.int64) - self.blank = self.num_classes - 1 - self.norm_by_times = False - self.norm_by_batchsize = False - self.norm_by_total_logits_len = True - - -class TestWarpCTCOpDygraph(unittest.TestCase): - def test_dygraph(self): - places = ['cpu'] - if paddle.is_compiled_with_cuda(): - places += ['gpu:0'] - - for p in places: - paddle.set_device(p) - paddle.disable_static() - paddle.seed(1) - np.random.seed(1) - #(B=2) - log_probs = np.array( - [[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04], - [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]], [ - [1.86260208e-01, 3.45560730e-01, 3.96767467e-01], - [5.38816750e-01, 4.19194520e-01, 6.85219526e-01] - ], [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02], - [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]], - [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01], - [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]], - [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02], - [3.90547849e-02, 1.69830427e-01, - 8.78142476e-01]]]).astype("float32") - labels = np.array([[1, 2, 2], [1, 2, 2]]).astype("int32") - input_lengths = np.array([5, 5]).astype("int64") - label_lengths = np.array([3, 3]).astype("int64") - - log_probs = paddle.to_tensor(log_probs, stop_gradient=False) - labels = paddle.to_tensor(labels) - input_lengths = paddle.to_tensor(input_lengths) - label_lengths = paddle.to_tensor(label_lengths) - - loss = paddle.nn.CTCLoss( - blank=0, reduction='sum')(log_probs, - labels, - input_lengths, - label_lengths, - norm_by_times=False, - norm_by_batchsize=False, - norm_by_total_logits_len=False) - self.assertTrue(np.allclose(loss, [6.82563686], atol=1)) - loss.backward() - log_probs.clear_gradient() - - loss = paddle.nn.CTCLoss( - blank=0, reduction='sum')(log_probs, - labels, - input_lengths, - label_lengths, - norm_by_times=True, - norm_by_batchsize=False, - 
-                                          norm_by_total_logits_len=False)
-            self.assertTrue(np.allclose(loss, [6.82563686], atol=1))
-            loss.backward()
-            log_probs.clear_gradient()
-
-            loss = paddle.nn.CTCLoss(
-                blank=0, reduction='sum')(log_probs,
-                                          labels,
-                                          input_lengths,
-                                          label_lengths,
-                                          norm_by_times=False,
-                                          norm_by_batchsize=True,
-                                          norm_by_total_logits_len=False)
-            self.assertTrue(np.allclose(loss, [6.82563686], atol=1))
-            loss.backward()
-            log_probs.clear_gradient()
-
-            loss = paddle.nn.CTCLoss(
-                blank=0, reduction='sum')(log_probs,
-                                          labels,
-                                          input_lengths,
-                                          label_lengths,
-                                          norm_by_times=False,
-                                          norm_by_batchsize=False,
-                                          norm_by_total_logits_len=True)
-            self.assertTrue(np.allclose(loss, [6.82563686], atol=1))
-            loss.backward()
-            log_probs.clear_gradient()
-
-            paddle.enable_static()
-
-
 class TestWarpCTCOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index b1db45ad506..c353451d0c8 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1001,9 +1001,7 @@ def ctc_loss(log_probs,
              label_lengths,
              blank=0,
              reduction='mean',
-             norm_by_times=False,
-             norm_by_batchsize=False,
-             norm_by_total_logits_len=False):
+             norm_by_times=False):
     """
     An operator integrating the open source Warp-CTC library
     (https://github.com/baidu-research/warp-ctc)
@@ -1019,9 +1017,7 @@
         blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
         reduction (string, optional): Indicate how to average the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
         norm_by_times (bool, default False) – Whether to normalize the gradients by the number of time-step, which is also the sequence’s length. There is no need to normalize the gradients if reduction mode is 'mean'.
-        norm_by_batchsize (bool): normalize the loss by the batch size (default: `False`). If `True`, supersedes `norm_by_times` (default: `False`)
-        norm_by_total_logits_len (bool): normalize the loss by the total number of frames in the batch. If `True`, supersedes `norm_by_batchsize` and `norm_by_times` (default: `False`)
-
+
     Returns:
         Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
 
@@ -1029,7 +1025,6 @@
         .. code-block:: python
 
-            # required: skiptest
             # declarative mode
             import paddle.nn.functional as F
             import numpy as np
@@ -1086,10 +1081,9 @@ def ctc_loss(log_probs,
     """
     loss_out = fluid.layers.warpctc(log_probs, labels, blank, norm_by_times,
-                                    input_lengths, label_lengths,
-                                    norm_by_batchsize, norm_by_total_logits_len)
+                                    input_lengths, label_lengths)
 
-    loss_out = fluid.layers.squeeze(loss_out, [-1])  # (B)
+    loss_out = fluid.layers.squeeze(loss_out, [-1])
     assert reduction in ['mean', 'sum', 'none']
     if reduction == 'mean':
         loss_out = paddle.mean(loss_out / label_lengths)
@@ -1544,7 +1538,7 @@ def cross_entropy(input,
             Indicate how to average the loss by batch_size,
             the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
-            If :attr:`norm_by_batchsize` is ``'sum'``, the reduced sum loss is returned.
+            If :attr:`size_average` is ``'sum'``, the reduced sum loss is returned.
             If :attr:`reduction` is ``'none'``, the unreduced loss is returned.
 
             Default is ``'mean'``.
diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py
index 781e13867f2..3ac0d675fb7 100644
--- a/python/paddle/nn/layer/loss.py
+++ b/python/paddle/nn/layer/loss.py
@@ -1119,9 +1119,7 @@ class CTCLoss(Layer):
                 labels,
                 input_lengths,
                 label_lengths,
-                norm_by_times=False,
-                norm_by_batchsize=False,
-                norm_by_total_logits_len=False):
+                norm_by_times=False):
         return paddle.nn.functional.ctc_loss(
             log_probs,
             labels,
@@ -1129,9 +1127,7 @@ class CTCLoss(Layer):
             label_lengths,
             self.blank,
             self.reduction,
-            norm_by_times=norm_by_times,
-            norm_by_batchsize=norm_by_batchsize,
-            norm_by_total_logits_len=norm_by_total_logits_len)
+            norm_by_times=norm_by_times)
 
 
 class SmoothL1Loss(Layer):
--
GitLab
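
Usage after this change (a minimal illustrative sketch, not part of the patch): `paddle.nn.CTCLoss` and `paddle.nn.functional.ctc_loss` now accept only `blank`, `reduction`, and `norm_by_times`; the deleted `norm_by_batchsize` and `norm_by_total_logits_len` flags are gone, so batch- or length-averaging has to come from the `reduction` argument instead. The snippet mirrors the inputs of the removed `TestWarpCTCOpDygraph` test; the random logits are stand-in values.

    import numpy as np
    import paddle

    paddle.disable_static()
    # (T=5, B=2, D=3) log-probabilities; random values are illustrative only.
    log_probs = paddle.to_tensor(
        np.random.rand(5, 2, 3).astype("float32"), stop_gradient=False)
    labels = paddle.to_tensor(np.array([[1, 2, 2], [1, 2, 2]]).astype("int32"))
    input_lengths = paddle.to_tensor(np.array([5, 5]).astype("int64"))
    label_lengths = paddle.to_tensor(np.array([3, 3]).astype("int64"))

    # Post-patch signature: the removed normalization flags are not accepted.
    loss = paddle.nn.CTCLoss(blank=0, reduction='sum')(
        log_probs, labels, input_lengths, label_lengths, norm_by_times=False)
    # Per the patch title, this backward pass now runs on the GPU
    # when a GPU place is set.
    loss.backward()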