diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h
index 1859c748d783519971ffb43cd695a9d22d09dbb6..98c31e30dc53169422ddacc72b0b279bbcf9ed9a 100644
--- a/paddle/fluid/operators/warpctc_op.h
+++ b/paddle/fluid/operators/warpctc_op.h
@@ -73,11 +73,12 @@ class WarpCTCFunctor {
                       "Bytes of workspace got by warp-ctc function, "
                       "get_workspace_size(), should be larger than 0.");
 
-    Tensor workspace;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
-    float* workspace_data = workspace.mutable_data<float>(
+    Tensor workspace = ctx.AllocateTmpTensor<float, DeviceContext>(
         framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
-        ctx.GetPlace());
+        dev_ctx);
+    float* workspace_data = workspace.data<float>();
     math::SetConstant<DeviceContext, float>()(
         ctx.template device_context<DeviceContext>(), &workspace,
         static_cast<float>(0));
@@ -186,8 +187,10 @@ class WarpCTCKernel : public framework::OpKernel<T> {
         framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                               static_cast<int64_t>(num_sequences),
                               static_cast<int64_t>(sequence_width)});
-    warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
-
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    Tensor warpctc_logits_tmp =
+        ctx.AllocateTmpTensor<T, DeviceContext>(warpctc_logits_dims, dev_ctx);
+    warpctc_logits.ShareDataWith(warpctc_logits_tmp);
     if (ctx.HasInput("LogitsLength")) {
       TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits);
     } else {