未验证 提交 c8e12587 编写于 作者: L Li Fuchen 提交者: GitHub

Fixed warpctc, test=develop (#20011)

Use AllocateTmpTensor() for creating temporary tensors in warpctc.
上级 63dd3183
@@ -73,11 +73,12 @@ class WarpCTCFunctor {
 "Bytes of workspace got by warp-ctc function, "
 "get_workspace_size(), should be larger than 0.");
-Tensor workspace;
+auto& dev_ctx = ctx.template device_context<DeviceContext>();
 size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
-float* workspace_data = workspace.mutable_data<float>(
+Tensor workspace = ctx.AllocateTmpTensor<float, DeviceContext>(
 framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
-ctx.GetPlace());
+dev_ctx);
+float* workspace_data = workspace.data<float>();
 math::SetConstant<DeviceContext, float>()(
 ctx.template device_context<DeviceContext>(), &workspace,
 static_cast<float>(0));
@@ -186,8 +187,10 @@ class WarpCTCKernel : public framework::OpKernel<T> {
 framework::make_ddim({static_cast<int64_t>(max_sequence_length),
 static_cast<int64_t>(num_sequences),
 static_cast<int64_t>(sequence_width)});
-warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
+auto& dev_ctx = ctx.template device_context<DeviceContext>();
+Tensor warpctc_logits_tmp =
+ctx.AllocateTmpTensor<T, DeviceContext>(warpctc_logits_dims, dev_ctx);
+warpctc_logits.ShareDataWith(warpctc_logits_tmp);
 if (ctx.HasInput("LogitsLength")) {
 TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits);
 } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册