diff --git a/paddle/operators/math/sequence_scale.cc b/paddle/operators/math/sequence_scale.cc
index 0f66e43a1a650a53b0753a481535c2951ccec9f6..7e439e9a2cebaa5d494b185fd878e293a6895e45 100644
--- a/paddle/operators/math/sequence_scale.cc
+++ b/paddle/operators/math/sequence_scale.cc
@@ -22,10 +22,10 @@ template <typename T>
 class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  framework::LoDTensor& seq, const T* scales,
-                  const size_t num_seq) {
+                  framework::LoDTensor& seq, const T* scales) {
     const size_t level = 0;
     auto lod = seq.lod();
+    const size_t num_seq = lod[level].size() - 1;
     size_t seq_width = seq.dims()[1];
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
 
diff --git a/paddle/operators/math/sequence_scale.cu b/paddle/operators/math/sequence_scale.cu
index fd1370c118407386d9fb5a93b5dba20349410cba..bc89711fcb957e7ec899085c00e962354220d5d8 100644
--- a/paddle/operators/math/sequence_scale.cu
+++ b/paddle/operators/math/sequence_scale.cu
@@ -20,18 +20,10 @@ namespace math {
 
 template <typename T>
 __global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
-                                    const size_t num_seq,
                                     const size_t seq_width) {
-  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
-
-  if (idx < lod[num_seq] * seq_width) {
-    size_t i = 0;
-    for (i = 0; i < num_seq; ++i) {
-      if (idx < lod[i + 1] * seq_width) {
-        break;
-      }
-    }
-    seq[idx] *= scales[i];
+  if (threadIdx.x < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width) {
+    int idx = lod[blockIdx.x] * seq_width + threadIdx.x;
+    seq[idx] *= scales[blockIdx.x];
   }
 }
 
@@ -39,18 +31,17 @@ template <typename T>
 class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  framework::LoDTensor& seq, const T* scales,
-                  const size_t num_seq) {
-    auto lod = seq.lod();
-    const size_t seq_width = seq.dims()[1];
+                  framework::LoDTensor& seq, const T* scales) {
     const size_t level = 0;
+    auto lod = seq.lod();
+    const size_t num_seq = lod[level].size() - 1;
+    const size_t seq_width = seq.numel() / seq.dims()[0];
     framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
     T* seq_data = seq.mutable_data<T>(context.GetPlace());
 
     int threads = 1024;
-    int grid = (seq.numel() * seq_width + threads - 1) / threads;
-    SequenceScaleKernel<T><<<grid, threads, 0, context.stream()>>>(
-        seq_data, abs_offset_lod[level].data(), scales, num_seq, seq_width);
+    SequenceScaleKernel<T><<<num_seq, threads, 0, context.stream()>>>(
+        seq_data, abs_offset_lod[level].data(), scales, seq_width);
   }
 };
 
diff --git a/paddle/operators/math/sequence_scale.h b/paddle/operators/math/sequence_scale.h
index 8c47179b55d08bd0794d9546254493a97a8e78f1..ecd9a57c3f4d8d91bfb8933a0fd38355c227744d 100644
--- a/paddle/operators/math/sequence_scale.h
+++ b/paddle/operators/math/sequence_scale.h
@@ -47,7 +47,7 @@ template <typename DeviceContext, typename T>
 class ScaleLoDTensorFunctor {
  public:
   void operator()(const DeviceContext& context, framework::LoDTensor& seq,
-                  const T* scales, const size_t num_seq);
+                  const T* scales);
 };
 
 }  // namespace math
diff --git a/paddle/operators/warpctc_op.h b/paddle/operators/warpctc_op.h
index d41752e7333539dffc5ae4be2f183c45de9eb6cc..8aea061c00cc9614db37ed408b6d330ef707d1cf 100644
--- a/paddle/operators/warpctc_op.h
+++ b/paddle/operators/warpctc_op.h
@@ -179,6 +179,10 @@ class WarpCTCKernel : public framework::OpKernel<T> {
     T* warpctc_grad_data =
         warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
 
+    math::SetConstant<DeviceContext, T>()(
+        ctx.template device_context<DeviceContext>(), warpctc_grad,
+        static_cast<T>(0));
+
     // warpctc accesses labels in CPU memory
     Tensor warpctc_label;
     Copy(*label, platform::CPUPlace(), ctx.device_context(),
         &warpctc_label);
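
(Aside: the sequence_scale changes above hinge on the level-0 LoD storing one more offset than there are sequences, so num_seq can be recovered as lod[level].size() - 1 instead of being threaded through every caller. Below is a minimal standalone C++ sketch of the equivalent CPU-side scaling; the offsets {0, 3, 7, 9} — three sequences of lengths 3, 4, 2 — and the scale values are made up for illustration, and no Paddle types are used.)

#include <cstdio>
#include <vector>

// Scale every row of sequence i by scales[i]; num_seq is derived from the
// offset vector exactly as the patched functor derives it.
void ScaleSequences(std::vector<float>& seq, const std::vector<size_t>& lod,
                    const std::vector<float>& scales, size_t seq_width) {
  const size_t num_seq = lod.size() - 1;  // {0, 3, 7, 9} -> 3 sequences
  for (size_t i = 0; i < num_seq; ++i) {
    for (size_t j = lod[i] * seq_width; j < lod[i + 1] * seq_width; ++j) {
      seq[j] *= scales[i];
    }
  }
}

int main() {
  std::vector<size_t> lod = {0, 3, 7, 9};  // sequence lengths 3, 4, 2
  const size_t seq_width = 2;
  std::vector<float> seq(lod.back() * seq_width, 1.0f);
  ScaleSequences(seq, lod, {2.0f, 3.0f, 5.0f}, seq_width);
  for (float v : seq) printf("%.0f ", v);  // six 2s, eight 3s, four 5s
  printf("\n");
  return 0;
}
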
@@ -215,10 +219,9 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
                             *warpctc_grad, norm_by_times);
 
     const T* loss_grad_data = loss_grad->data<T>();
-    const size_t num_seq = loss_grad->dims()[0];
     math::ScaleLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits_grad,
-        loss_grad_data, num_seq);
+        loss_grad_data);
   }
 };
diff --git a/python/paddle/v2/fluid/tests/test_warpctc_op.py b/python/paddle/v2/fluid/tests/test_warpctc_op.py
index a1c4e40111ffcd7f2c1104c3f588239c0e400335..07be05d2b03524a283925b02337e9957eae421cf 100644
--- a/python/paddle/v2/fluid/tests/test_warpctc_op.py
+++ b/python/paddle/v2/fluid/tests/test_warpctc_op.py
@@ -193,7 +193,7 @@ class TestWarpCTCOp(OpTest):
     def test_check_grad(self):
         self.outputs['WarpCTCGrad'] = self.gradient
-        self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.007)
 
 
 if __name__ == "__main__":
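
(Note: the rewritten CUDA kernel maps one block per sequence and one thread per element of that sequence's flattened length * seq_width slab, replacing the old global-index-plus-linear-search scheme. The standalone sketch below uses made-up lod/scale values and none of Paddle's wrappers; it only illustrates the launch shape and expected result. One design consequence worth noting: with 1024 threads per block, a sequence whose length * seq_width exceeds 1024 would be left partially unscaled, so the block size is an implicit upper bound in this scheme.)

#include <cstdio>
#include <cuda_runtime.h>

// Same kernel logic as the patch: block b owns sequence b, thread t owns
// element t of that sequence's flattened slab.
template <typename T>
__global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                    const size_t seq_width) {
  if (threadIdx.x < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width) {
    int idx = lod[blockIdx.x] * seq_width + threadIdx.x;
    seq[idx] *= scales[blockIdx.x];
  }
}

int main() {
  const size_t seq_width = 2;
  size_t h_lod[] = {0, 3, 7, 9};  // illustrative offsets: lengths 3, 4, 2
  const size_t num_seq = 3;
  const size_t total = h_lod[num_seq] * seq_width;  // 18 elements
  float h_seq[18];
  float h_scales[] = {2.0f, 3.0f, 5.0f};
  for (size_t i = 0; i < total; ++i) h_seq[i] = 1.0f;

  float *d_seq, *d_scales;
  size_t* d_lod;
  cudaMalloc(&d_seq, total * sizeof(float));
  cudaMalloc(&d_scales, num_seq * sizeof(float));
  cudaMalloc(&d_lod, (num_seq + 1) * sizeof(size_t));
  cudaMemcpy(d_seq, h_seq, total * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_scales, h_scales, num_seq * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_lod, h_lod, (num_seq + 1) * sizeof(size_t),
             cudaMemcpyHostToDevice);

  // One block per sequence, as in the patched functor; 1024 threads caps
  // the largest (length * seq_width) a single sequence may have.
  SequenceScaleKernel<float><<<num_seq, 1024>>>(d_seq, d_lod, d_scales,
                                                seq_width);
  cudaMemcpy(h_seq, d_seq, total * sizeof(float), cudaMemcpyDeviceToHost);

  for (size_t i = 0; i < total; ++i) printf("%.0f ", h_seq[i]);
  printf("\n");  // expect six 2s, eight 3s, four 5s
  cudaFree(d_seq);
  cudaFree(d_scales);
  cudaFree(d_lod);
  return 0;
}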