未验证 提交 c8e12587 编写于 作者: L Li Fuchen 提交者: GitHub

Fixed warpctc, test=develop (#20011)

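Use AllocateTmpTensor() to create the temporary tensors in warpctc (the warp-ctc workspace and the warpctc_logits buffer) instead of allocating them with mutable_data() on local Tensor objects.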
上级 63dd3183
......@@ -73,11 +73,12 @@ class WarpCTCFunctor {
"Bytes of workspace got by warp-ctc function, "
"get_workspace_size(), should be larger than 0.");
Tensor workspace;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
float* workspace_data = workspace.mutable_data<float>(
Tensor workspace = ctx.AllocateTmpTensor<float, DeviceContext>(
framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
ctx.GetPlace());
dev_ctx);
float* workspace_data = workspace.data<float>();
math::SetConstant<DeviceContext, float>()(
ctx.template device_context<DeviceContext>(), &workspace,
static_cast<float>(0));
......@@ -186,8 +187,10 @@ class WarpCTCKernel : public framework::OpKernel<T> {
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
Tensor warpctc_logits_tmp =
ctx.AllocateTmpTensor<T, DeviceContext>(warpctc_logits_dims, dev_ctx);
warpctc_logits.ShareDataWith(warpctc_logits_tmp);
if (ctx.HasInput("LogitsLength")) {
TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits);
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册