提交 bf99396a 编写于 作者: F fengjiayi

fix errors in sequence_slice_op

上级 baa9f50d
...@@ -66,13 +66,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> { ...@@ -66,13 +66,11 @@ class SequenceSliceOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*offset, platform::CPUPlace(), &offset_cpu);
&offset_cpu);
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*length, platform::CPUPlace(), &length_cpu);
&length_cpu);
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
} }
...@@ -127,13 +125,11 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> { ...@@ -127,13 +125,11 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
framework::TensorCopy(*offset, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*offset, platform::CPUPlace(), &offset_cpu);
&offset_cpu);
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
framework::TensorCopy(*length, platform::CPUPlace(), ctx.device_context(), framework::TensorCopySync(*length, platform::CPUPlace(), &length_cpu);
&length_cpu);
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册