From c8e125872c2cffeb2018d8a98504ce15989a3b2d Mon Sep 17 00:00:00 2001 From: Li Fuchen Date: Fri, 27 Sep 2019 11:00:34 +0800 Subject: [PATCH] Fixed warpctc, test=develop (#20011) Use AllocateTmpTensor() for creating temporary tensors in warpctc. --- paddle/fluid/operators/warpctc_op.h | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index 1859c748d78..98c31e30dc5 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -73,11 +73,12 @@ class WarpCTCFunctor { "Bytes of workspace got by warp-ctc function, " "get_workspace_size(), should be larger than 0."); - Tensor workspace; + auto& dev_ctx = ctx.template device_context(); size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL; - float* workspace_data = workspace.mutable_data( + Tensor workspace = ctx.AllocateTmpTensor( framework::make_ddim({static_cast(workspace_elements)}), - ctx.GetPlace()); + dev_ctx); + float* workspace_data = workspace.data(); math::SetConstant()( ctx.template device_context(), &workspace, static_cast(0)); @@ -186,8 +187,10 @@ class WarpCTCKernel : public framework::OpKernel { framework::make_ddim({static_cast(max_sequence_length), static_cast(num_sequences), static_cast(sequence_width)}); - warpctc_logits.mutable_data(warpctc_logits_dims, ctx.GetPlace()); - + auto& dev_ctx = ctx.template device_context(); + Tensor warpctc_logits_tmp = + ctx.AllocateTmpTensor(warpctc_logits_dims, dev_ctx); + warpctc_logits.ShareDataWith(warpctc_logits_tmp); if (ctx.HasInput("LogitsLength")) { TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits); } else { -- GitLab