From 89de5d5e6628ffaebe9e1fb7070d5308e5b8ab9a Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 11 Jan 2018 10:10:03 +0800
Subject: [PATCH] Fix cuda kernel of sequence scale functor

---
 paddle/operators/math/sequence_scale.cu |  6 +++---
 paddle/operators/math/sequence_scale.h  | 16 +++++++++-------
 paddle/operators/warpctc_op.h           | 19 -------------------
 3 files changed, 12 insertions(+), 29 deletions(-)

diff --git a/paddle/operators/math/sequence_scale.cu b/paddle/operators/math/sequence_scale.cu
index 23b0cce13f..fd1370c118 100644
--- a/paddle/operators/math/sequence_scale.cu
+++ b/paddle/operators/math/sequence_scale.cu
@@ -22,16 +22,16 @@ template <typename T>
 __global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                     const size_t num_seq,
                                     const size_t seq_width) {
-  size_t idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
 
-  if (idx < lod[num_seq]) {
+  if (idx < lod[num_seq] * seq_width) {
     size_t i = 0;
     for (i = 0; i < num_seq; ++i) {
       if (idx < lod[i + 1] * seq_width) {
         break;
       }
     }
-    seq[i] *= scales[i];
+    seq[idx] *= scales[i];
   }
 }
 
diff --git a/paddle/operators/math/sequence_scale.h b/paddle/operators/math/sequence_scale.h
index a42fc6d0db..8c47179b55 100644
--- a/paddle/operators/math/sequence_scale.h
+++ b/paddle/operators/math/sequence_scale.h
@@ -27,19 +27,21 @@ namespace math {
  * All sequences will be padded to the same length and stored in a transposed
  * shape.
  * Example:
- *   seq     (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
- *   padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
+ * Given:
+ *   seq    = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
+ *   scales = (2, 3, 4, 5)
+ * then:
+ *   result = (2*s0, 2*s0, 2*s0, 2*s0; 3*s1, 3*s1; 4*s2, 4*s2, 4*s2; 5*s3)
+ *
- * \param context       device context of this functor.
+ * \param context       Device context of this functor.
  * \param seq           LoDTensor which is stored in sequence format, the shape
  *                       is [total_sequence_length, sequence_width] where
  *                       total_sequence_length is the sum of all sequences'
  *                       length.
- * \param padding       Tensor which is padded to the same length, the shape is
- *                       [max_sequence_length, num_sequences, sequence_width].
- * \param norm_by_times whether dividing sequence's length.
+ * \param scales        Array<T>. The i-th sequence will be scaled by scales[i].
+ * \param num_seq       Number of sequences.
  *
- * \note  transposition is also done in this functor.
  */
 template <typename DeviceContext, typename T>
 class ScaleLoDTensorFunctor {
diff --git a/paddle/operators/warpctc_op.h b/paddle/operators/warpctc_op.h
index c2bbceb6d1..d41752e733 100644
--- a/paddle/operators/warpctc_op.h
+++ b/paddle/operators/warpctc_op.h
@@ -14,7 +14,6 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence_padding.h"
@@ -209,12 +208,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     auto* logits_grad =
         ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
     const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
-    // LOG(ERROR) << "loss_grad_dims: " << loss_grad_dims;
-    // for (int i=0; i<loss_grad->numel();i++) {
-    //  LOG(ERROR) << "loss_grad: " << loss_grad_data[i];
-    //}
-
-    // T* logits_grad_data = logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
 
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
@@ -226,18 +219,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     math::ScaleLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits_grad,
         loss_grad_data, num_seq);
-    /*
-    int level = 0;
-    auto logits_grad_lod = framework::ToAbsOffset(logits_grad->lod());
-    const size_t num_sequences = logits_grad_lod[level].size() - 1;
-    for (int seq_index = 0; seq_index < num_sequences; ++seq_index) {
-      for (int token_index = logits_grad_lod[level][seq_index];
-           token_index < logits_grad_lod[level][seq_index + 1];
-           ++token_index) {
-        logits_grad_data[token_index] *= loss_grad_data[seq_index];
-      }
-    }
-    */
   }
 };
--
GitLab
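For reference, the fix above gives the kernel one thread per scalar element of the packed [total_length, seq_width] tensor, widens the bound check to lod[num_seq] * seq_width, and scales seq[idx] (instead of seq[i], the sequence index) by its sequence's factor. The standalone CUDA sketch below mirrors that indexing scheme; it is a minimal illustration only, not the operator's actual launch code, and the kernel name ScaleBySequence, the launch configuration, and the toy data are assumptions made for the example.

#include <cstdio>
#include <cuda_runtime.h>

// One thread per element of the packed [total_length, seq_width] tensor.
// lod holds absolute row offsets: sequence i covers rows [lod[i], lod[i+1]).
template <typename T>
__global__ void ScaleBySequence(T* seq, const size_t* lod, const T* scales,
                                size_t num_seq, size_t seq_width) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < lod[num_seq] * seq_width) {
    // Find the sequence this element belongs to.
    size_t i = 0;
    for (i = 0; i < num_seq; ++i) {
      if (idx < lod[i + 1] * seq_width) break;
    }
    seq[idx] *= scales[i];
  }
}

int main() {
  // Illustrative data: two sequences of lengths 3 and 2, width 2 (10 values).
  const size_t lod_h[] = {0, 3, 5};
  const size_t num_seq = 2, seq_width = 2;
  float seq_h[10] = {1, 1, 1, 1, 1, 1, 2, 2, 2, 2};
  const float scales_h[] = {0.5f, 3.0f};

  float *seq_d, *scales_d;
  size_t* lod_d;
  cudaMalloc(&seq_d, sizeof(seq_h));
  cudaMalloc(&scales_d, sizeof(scales_h));
  cudaMalloc(&lod_d, sizeof(lod_h));
  cudaMemcpy(seq_d, seq_h, sizeof(seq_h), cudaMemcpyHostToDevice);
  cudaMemcpy(scales_d, scales_h, sizeof(scales_h), cudaMemcpyHostToDevice);
  cudaMemcpy(lod_d, lod_h, sizeof(lod_h), cudaMemcpyHostToDevice);

  const int total = 10;  // lod_h[num_seq] * seq_width
  const int threads = 128;
  const int blocks = (total + threads - 1) / threads;
  ScaleBySequence<float><<<blocks, threads>>>(seq_d, lod_d, scales_d, num_seq,
                                              seq_width);
  cudaMemcpy(seq_h, seq_d, sizeof(seq_h), cudaMemcpyDeviceToHost);

  for (int i = 0; i < total; ++i) printf("%g ", seq_h[i]);
  printf("\n");  // Expected: 0.5 printed six times, then 6 printed four times.

  cudaFree(seq_d);
  cudaFree(scales_d);
  cudaFree(lod_d);
  return 0;
}

The linear search over the LoD offsets keeps the kernel simple; with a modest number of sequences its cost is small next to the global-memory traffic of the element-wise multiply.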