diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc index 2dd2cafa23babbc11271955b8cce84a629f8f738..5ceb26553c02c1be18358b7f110068303d2498fa 100644 --- a/paddle/fluid/operators/math/sequence_padding.cc +++ b/paddle/fluid/operators/math/sequence_padding.cc @@ -18,111 +18,114 @@ namespace paddle { namespace operators { namespace math { -template +template void CopyDataCPU(framework::LoDTensor* seq_tensor, - framework::Tensor* padding_tensor, - const framework::Vector& abs_offset, + framework::Tensor* pad_tensor, + const framework::Vector& seq_offset, const int64_t& max_seq_len, const int64_t& seq_width, - bool seq_to_padding, bool norm_by_len) { + bool seq_to_pad, bool norm_by_len, + OutputLayout output_layout) { T* seq_data = seq_tensor->data(); - T* padding_data = padding_tensor->data(); + T* pad_data = pad_tensor->data(); - int64_t seq_num = abs_offset.size() - 1; + int64_t seq_num = seq_offset.size() - 1; for (int64_t i = 0; i < seq_num; ++i) { - int64_t seq_start = abs_offset[i]; - int64_t seq_len = abs_offset[i + 1] - seq_start; - + int64_t seq_start = seq_offset[i]; + int64_t seq_len = seq_offset[i + 1] - seq_start; T scale = norm_by_len ? (1.0f / static_cast(seq_len)) : 1.0f; - for (int64_t j = 0; j < seq_len; ++j) { for (int64_t k = 0; k < seq_width; ++k) { - size_t padding_offset = 0; - if (padding_layout == BATCH_LENGTH_WIDTH) { - padding_offset = (i * max_seq_len * seq_width) + j * seq_width + k; + size_t pad_data_idx = 0; + size_t seq_data_idx = (seq_start + j) * seq_width + k; + if (output_layout == kBatchLengthWidth) { + pad_data_idx = (i * max_seq_len + j) * seq_width + k; } else { - padding_offset = (j * seq_num * seq_width) + i * seq_width + k; + pad_data_idx = (j * seq_num + i) * seq_width + k; } - if (seq_to_padding) { - padding_data[padding_offset] = - seq_data[(seq_start + j) * seq_width + k] * scale; + if (seq_to_pad) { + pad_data[pad_data_idx] = seq_data[seq_data_idx] * scale; } else { - seq_data[(seq_start + j) * seq_width + k] = - padding_data[padding_offset] * scale; + seq_data[seq_data_idx] = pad_data[pad_data_idx] * scale; } } } } } -template -class PaddingLoDTensorFunctor { +template +class PaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* padding_tensor, - T padding_value = static_cast(0), - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(seq_tensor, lod_level); + framework::Tensor* pad_tensor, + T pad_value = static_cast(0), bool norm_by_times = false, + size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(seq_tensor, lod_level); auto& lod = seq_tensor.lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto seq_dims = seq_tensor.dims(); - auto padding_dims = padding_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - int64_t seq_num = abs_offset.size() - 1; - int64_t seq_width = seq_tensor.numel() / seq_dims[0]; - int64_t numel = max_seq_len * seq_num * seq_width; + auto seq_tensor_dims = seq_tensor.dims(); + auto pad_tensor_dims = pad_tensor->dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; - ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, - seq_num, seq_width, padding_layout); + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); - T* padding_data = padding_tensor->data(); + T* pad_data = pad_tensor->data(); - memset(padding_data, padding_value, numel * sizeof(T)); + memset(pad_data, pad_value, max_seq_len * seq_num * seq_width * sizeof(T)); - CopyDataCPU( - const_cast(&seq_tensor), padding_tensor, - abs_offset, max_seq_len, seq_width, true /* seq_to_padding */, - norm_by_times); + CopyDataCPU(const_cast(&seq_tensor), pad_tensor, + seq_offset, max_seq_len, seq_width, true /* seq_to_pad */, + norm_by_times, output_layout); } }; -template -class UnpaddingLoDTensorFunctor { +template +class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CPUDeviceContext& context, framework::LoDTensor* seq_tensor, - const framework::Tensor& padding_tensor, - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(*seq_tensor, lod_level); + const framework::Tensor& pad_tensor, + bool norm_by_times = false, size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(*seq_tensor, lod_level); auto& lod = seq_tensor->lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto& seq_dims = seq_tensor->dims(); - auto& padding_dims = padding_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - int64_t seq_num = abs_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_dims[0]; + auto& seq_tensor_dims = seq_tensor->dims(); + auto& pad_tensor_dims = pad_tensor.dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; - ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, - seq_num, seq_width, padding_layout); + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); T* seq_data = seq_tensor->data(); memset(seq_data, static_cast(0), seq_tensor->numel() * sizeof(T)); - CopyDataCPU( - seq_tensor, const_cast(&padding_tensor), abs_offset, - max_seq_len, seq_width, false /* seq_to_padding */, norm_by_times); + CopyDataCPU(seq_tensor, const_cast(&pad_tensor), + seq_offset, max_seq_len, seq_width, false /* seq_to_pad */, + norm_by_times, output_layout); } }; -template class PaddingLoDTensorFunctor; -template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.cu b/paddle/fluid/operators/math/sequence_padding.cu index 2377bef0247fa88b26079fbda16c0ffc56318fa3..20e3e3de2aa681d7474501b296b0445aa2dceca7 100644 --- a/paddle/fluid/operators/math/sequence_padding.cu +++ b/paddle/fluid/operators/math/sequence_padding.cu @@ -21,74 +21,74 @@ namespace math { template __global__ void SequencePaddingKernel( - T* padding_data, T* seq_data, const size_t* abs_offset, - const size_t& seq_num, const size_t& max_seq_len, const size_t& seq_width, - const PaddingLayout& padding_layout, bool norm_by_times = false, - const T& padding_value = 0) { - size_t padding_idx = blockIdx.y; - size_t seq_start = abs_offset[padding_idx]; - size_t seq_len = abs_offset[padding_idx + 1] - seq_start; + T* pad_data, T* seq_data, const size_t* seq_offset, const size_t& seq_num, + const size_t& max_seq_len, const size_t& seq_width, bool norm_by_times, + const T& pad_value, const OutputLayout& output_layout) { + size_t seq_idx = blockIdx.y; + size_t seq_start = seq_offset[seq_idx]; + size_t seq_len = seq_offset[seq_idx + 1] - seq_start; - size_t seq_idx = blockIdx.x * blockDim.y + threadIdx.y; + size_t seq_step_idx = blockIdx.x * blockDim.y + threadIdx.y; - size_t seq_offset = (seq_start + seq_idx) * seq_width; + size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width; - size_t padding_offset = 0; + size_t pad_data_offset = 0; - if (padding_layout == LENGTH_BATCH_WIDTH) { - padding_offset = (seq_idx * seq_num + padding_idx) * seq_width; + if (output_layout == kLengthBatchWidth) { + pad_data_offset = (seq_step_idx * seq_num + seq_idx) * seq_width; } else { - padding_offset = (padding_idx * max_seq_len + seq_idx) * seq_width; + pad_data_offset = (seq_idx * max_seq_len + seq_step_idx) * seq_width; } - if (seq_idx < seq_len) { + if (seq_step_idx < seq_len) { T scale = norm_by_times ? (1.0f / static_cast(seq_len)) : 1.0f; if (Padding) { - /* sequence -> padding */ + /* seq -> pad */ for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - padding_data[padding_offset + i] = scale * seq_data[seq_offset + i]; + pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i]; } } else { - /* padding -> sequence */ + /* pad -> seq */ for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - seq_data[seq_offset + i] = scale * padding_data[padding_offset + i]; + seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i]; } } - } else if (seq_idx < max_seq_len) { + } else if (seq_step_idx < max_seq_len) { if (Padding) { - /* sequence -> padding */ + /* seq -> pad */ for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) { - padding_data[padding_offset + i] = padding_value; + pad_data[pad_data_offset + i] = pad_value; } } } } -template -class PaddingLoDTensorFunctor { +template +class PaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* padding_tensor, - T padding_value = static_cast(0), - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(seq_tensor, lod_level); + framework::Tensor* pad_tensor, + T pad_value = static_cast(0), bool norm_by_times = false, + size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(seq_tensor, lod_level); auto& lod = seq_tensor.lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto seq_dims = seq_tensor.dims(); - auto padding_dims = padding_tensor->dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - const int64_t seq_num = abs_offset.size() - 1; - const int64_t seq_width = seq_tensor.numel() / seq_dims[0]; + auto seq_tensor_dims = seq_tensor.dims(); + auto pad_tensor_dims = pad_tensor->dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0]; - ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len, - seq_num, seq_width, padding_layout); + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); if (!norm_by_times && seq_num == 1UL) { - TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor); - padding_tensor->Resize(padding_dims); + TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor); + pad_tensor->Resize(pad_tensor_dims); return; } @@ -107,37 +107,40 @@ class PaddingLoDTensorFunctor { dim3 grid(grid_dim_x, grid_dim_y); const T* seq_data = seq_tensor.data(); - T* padding_data = padding_tensor->data(); + T* pad_data = pad_tensor->data(); SequencePaddingKernel<<>>( - padding_data, const_cast(seq_data), - abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, padding_layout, norm_by_times, padding_value); + pad_data, const_cast(seq_data), + seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, + seq_width, norm_by_times, pad_value, output_layout); } }; -template -class UnpaddingLoDTensorFunctor { +template +class UnpaddingLoDTensorFunctor { public: void operator()(const platform::CUDADeviceContext& context, framework::LoDTensor* seq_tensor, - const framework::Tensor& padding_tensor, - bool norm_by_times = false, size_t lod_level = 0) { - ValidateLoD(*seq_tensor, lod_level); + const framework::Tensor& pad_tensor, + bool norm_by_times = false, size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth) { + CheckLoD(*seq_tensor, lod_level); auto& lod = seq_tensor->lod(); - auto& abs_offset = framework::ToAbsOffset(lod)[lod_level]; + auto& seq_offset = framework::ToAbsOffset(lod)[lod_level]; - auto seq_dims = seq_tensor->dims(); - auto padding_dims = padding_tensor.dims(); - int64_t max_seq_len = MaximumSequenceLength(lod, lod_level); - int64_t seq_num = abs_offset.size() - 1; - int64_t seq_width = seq_tensor->numel() / seq_dims[0]; + auto seq_tensor_dims = seq_tensor->dims(); + auto pad_tensor_dims = pad_tensor.dims(); + int64_t max_seq_len = MaximumSequenceLength(seq_offset); + int64_t seq_num = seq_offset.size() - 1; + int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0]; + + CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len, + seq_num, seq_width, output_layout); if (!norm_by_times && seq_num == 1UL) { - TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor); - seq_tensor->Resize(seq_dims); + TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor); + seq_tensor->Resize(seq_tensor_dims); return; } @@ -155,20 +158,25 @@ class UnpaddingLoDTensorFunctor(); + const T* pad_data = pad_tensor.data(); T* seq_data = seq_tensor->data(); - SequencePaddingKernel<<>>( - const_cast(padding_data), seq_data, - abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, - seq_width, padding_layout, norm_by_times); + SequencePaddingKernel<<>>( + const_cast(pad_data), seq_data, + seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len, + seq_width, norm_by_times, static_cast(0), output_layout); } }; -template class PaddingLoDTensorFunctor; -template class UnpaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; +template class PaddingLoDTensorFunctor; + +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; +template class UnpaddingLoDTensorFunctor; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_padding.h b/paddle/fluid/operators/math/sequence_padding.h index 91d205641acd6e674d0f60c6d9e373b08d0fec19..44d6404335966bde70ef1f19e80fc860b26a9bde 100644 --- a/paddle/fluid/operators/math/sequence_padding.h +++ b/paddle/fluid/operators/math/sequence_padding.h @@ -22,49 +22,46 @@ namespace paddle { namespace operators { namespace math { -enum PaddingLayout { BATCH_LENGTH_WIDTH, LENGTH_BATCH_WIDTH }; +enum OutputLayout { kBatchLengthWidth = 0, kLengthBatchWidth }; -inline static size_t MaximumSequenceLength(const framework::LoD& lod, - const size_t level) { - const size_t seq_num = lod[level].size() - 1; +inline static size_t MaximumSequenceLength( + const framework::Vector& seq_offset) { + size_t seq_num = seq_offset.size() - 1; size_t max_seq_len = 0; - auto abs_offset = framework::ToAbsOffset(lod)[level]; for (size_t i = 0; i < seq_num; ++i) { - max_seq_len = std::max(max_seq_len, abs_offset[i + 1] - abs_offset[i]); + max_seq_len = std::max(max_seq_len, seq_offset[i + 1] - seq_offset[i]); } return max_seq_len; } -inline static void ValidateLoD(const framework::LoDTensor& seq_tensor, - const size_t& lod_level) { +inline static void CheckLoD(const framework::LoDTensor& seq_tensor, + const size_t& lod_level) { PADDLE_ENFORCE(lod_level < seq_tensor.lod().size(), - "Invalid `lod_level` which should be at least 0 and less " - "than maximum lod level of `seq_tensor`."); + "Invalid lod level which should be at least 0 and less " + "than maximum lod level of sequence tensor."); } -inline static void ValidateShape(const framework::DDim& seq_tensor_dims, - const size_t& abs_offset_back_value, - const framework::DDim& padding_tensor_dims, - const int64_t& max_seq_len, - const int64_t& seq_num, - const int64_t& seq_width, - const PaddingLayout& padding_layout) { - PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), - abs_offset_back_value, - "The 1st dimension of `seq_tensor` should be equal to " - "sum of lengths of all sequences."); +inline static void CheckDims(const framework::DDim& seq_tensor_dims, + const size_t& last_offset, + const framework::DDim& pad_tensor_dims, + const int64_t& max_seq_len, const int64_t& seq_num, + const int64_t& seq_width, + const OutputLayout& output_layout) { + PADDLE_ENFORCE_EQ(static_cast(seq_tensor_dims[0]), last_offset, + "Value of 1st dimension of the sequence tensor should be " + "equal to sum of lengths of all sequences."); - PADDLE_ENFORCE_EQ(padding_tensor_dims.size(), 3UL, - "`padding_tensor` should be a 3-D tensor."); + PADDLE_ENFORCE_EQ(pad_tensor_dims.size(), 3UL, + "Padded tensor should be a 3-D tensor."); - if (padding_layout == BATCH_LENGTH_WIDTH) { - PADDLE_ENFORCE_EQ(padding_tensor_dims, + if (output_layout == kBatchLengthWidth) { + PADDLE_ENFORCE_EQ(pad_tensor_dims, framework::make_ddim({seq_num, max_seq_len, seq_width})); - } else if (padding_layout == LENGTH_BATCH_WIDTH) { - PADDLE_ENFORCE_EQ(padding_tensor_dims, + } else if (output_layout == kLengthBatchWidth) { + PADDLE_ENFORCE_EQ(pad_tensor_dims, framework::make_ddim({max_seq_len, seq_num, seq_width})); } else { - PADDLE_THROW("Unsupported padding layout."); + PADDLE_THROW("Unsupported output layout."); } } @@ -94,23 +91,25 @@ inline static void ValidateShape(const framework::DDim& seq_tensor_dims, * * \note transposition is also done in this functor. */ -template +template class PaddingLoDTensorFunctor { public: void operator()(const DeviceContext& context, const framework::LoDTensor& seq_tensor, - framework::Tensor* padding_tensor, - T padding_value = static_cast(0), - bool norm_by_times = false, size_t lod_level = 0); + framework::Tensor* pad_tensor, + T pad_value = static_cast(0), bool norm_by_times = false, + size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth); }; -template +template class UnpaddingLoDTensorFunctor { public: void operator()(const DeviceContext& context, framework::LoDTensor* seq_tensor, - const framework::Tensor& padding_tensor, - bool norm_by_times = false, size_t lod_level = 0); + const framework::Tensor& pad_tensor, + bool norm_by_times = false, size_t lod_level = 0, + OutputLayout output_layout = kBatchLengthWidth); }; } // namespace math diff --git a/paddle/fluid/operators/math/sequence_padding_test.cc b/paddle/fluid/operators/math/sequence_padding_test.cc index b0c201db0ccbe81d8f57cd984d2cdfd2f6a48f25..82459274c46e1df3ce5a63b868107c88666c857f 100644 --- a/paddle/fluid/operators/math/sequence_padding_test.cc +++ b/paddle/fluid/operators/math/sequence_padding_test.cc @@ -46,20 +46,24 @@ void TestSequencePadding(const paddle::framework::LoD& lod, } const size_t max_sequence_length = - paddle::operators::math::MaximumSequenceLength(lod, level); + paddle::operators::math::MaximumSequenceLength(lod[level]); const size_t num_sequences = lod[level].size() - 1; auto padding_dims = paddle::framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); + padding.mutable_data(padding_dims, *place); + paddle::operators::math::PaddingLoDTensorFunctor()( - *context, seq, &padding, false); + *context, seq, &padding, 0, false, 0, + paddle::operators::math::kLengthBatchWidth); seq_back.set_lod(lod); seq_back.mutable_data(seq_dims, *place); paddle::operators::math::UnpaddingLoDTensorFunctor()( - *context, &seq_back, padding, false); + *context, &seq_back, padding, false, 0, + paddle::operators::math::kLengthBatchWidth); if (paddle::platform::is_cpu_place(*place)) { cpu_seq_back = seq_back; diff --git a/paddle/fluid/operators/sequence_pad_op.cc b/paddle/fluid/operators/sequence_pad_op.cc index f3a6fff0e10a745f00c9582d7d43115d5bd4d7a6..dc79b252c7d36bce846e29a6f8b84f24612b8bd9 100644 --- a/paddle/fluid/operators/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_pad_op.cc @@ -54,7 +54,7 @@ class SequencePadOp : public framework::OperatorWithKernel { seq_num = x_abs_offset.size() - 1; - for (size_t i = 1; i <= seq_num; ++i) { + for (int64_t i = 1; i <= seq_num; ++i) { int64_t seq_len = x_abs_offset[i] - x_abs_offset[i - 1]; max_len = max_len < seq_len ? seq_len : max_len; } diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index cb9db1682032c3c181e0d93e38268893609b9fae..cb56f42a8db3d4765f543e42ddb38d8382968e24 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -155,15 +155,16 @@ class WarpCTCKernel : public framework::OpKernel { // warpctc needs sequences data stored in transposed padding format Tensor warpctc_logits; const size_t max_sequence_length = - math::MaximumSequenceLength(logits_lod, level); + math::MaximumSequenceLength(logits_lod[level]); auto warpctc_logits_dims = framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); - math::PaddingLoDTensorFunctor()( + math::PaddingLoDTensorFunctor()( ctx.template device_context(), *logits, &warpctc_logits, - false); + static_cast(0), false /* norm_by_times */, 0, + math::kLengthBatchWidth); const T* warpctc_logits_data = warpctc_logits.data(); std::vector warpctc_label_lengths(num_sequences); @@ -215,10 +216,9 @@ class WarpCTCGradKernel : public framework::OpKernel { logits_grad->mutable_data(ctx.GetPlace()); bool norm_by_times = ctx.Attr("norm_by_times"); - math::UnpaddingLoDTensorFunctor()( + math::UnpaddingLoDTensorFunctor()( ctx.template device_context(), logits_grad, - *warpctc_grad, norm_by_times); + *warpctc_grad, norm_by_times, 0, math::kLengthBatchWidth); const T* loss_grad_data = loss_grad->data(); math::ScaleLoDTensorFunctor()(