提交 352c5a96 编写于 作者: W wanghaox

update some code

...@@ -124,14 +124,14 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> { ...@@ -124,14 +124,14 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
const int64_t* offset_data = offset->data<int64_t>(); const int64_t* offset_data = offset->data<int64_t>();
const int64_t* length_data = length->data<int64_t>(); const int64_t* length_data = length->data<int64_t>();
framework::Tensor offset_cpu;
framework::Tensor length_cpu;
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
framework::Tensor offset_cpu;
offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace()); offset_cpu.mutable_data<T>(offset->dims(), platform::CPUPlace());
offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context()); offset_cpu.CopyFrom(*offset, platform::CPUPlace(), ctx.device_context());
offset_data = offset_cpu.data<int64_t>(); offset_data = offset_cpu.data<int64_t>();
framework::Tensor length_cpu;
length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace()); length_cpu.mutable_data<T>(length->dims(), platform::CPUPlace());
length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context()); length_cpu.CopyFrom(*length, platform::CPUPlace(), ctx.device_context());
length_data = length_cpu.data<int64_t>(); length_data = length_cpu.data<int64_t>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册