未验证 提交 6dd64b0a 编写于 作者: Z zhupengyang 提交者: GitHub

randperm run error in multi-gpus (#27942)

上级 74fadeb4
......@@ -57,7 +57,7 @@ class RandpermKernel : public framework::OpKernel<T> {
tmp_tensor.Resize(framework::make_ddim({n}));
T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
random_permate<T>(tmp_data, n, seed);
framework::TensorCopy(tmp_tensor, platform::CUDAPlace(), out_tensor);
framework::TensorCopy(tmp_tensor, ctx.GetPlace(), out_tensor);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册