提交 cf3b3d60 编写于 作者: F fengjiayi

fix warpctc

上级 0285a2b9
...@@ -186,8 +186,7 @@ class WarpCTCKernel : public framework::OpKernel<T> { ...@@ -186,8 +186,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
// warpctc accesses labels in CPU memory // warpctc accesses labels in CPU memory
Tensor warpctc_label; Tensor warpctc_label;
TensorCopy(*label, platform::CPUPlace(), ctx.device_context(), TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
&warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>(); const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory // warpctc stores loss in CPU memory
Tensor warpctc_loss; Tensor warpctc_loss;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册