未验证 提交 6dd64b0a 编写于 作者: Z zhupengyang 提交者: GitHub

randperm run error in multi-gpus (#27942)

上级 74fadeb4
...@@ -57,7 +57,7 @@ class RandpermKernel : public framework::OpKernel<T> { ...@@ -57,7 +57,7 @@ class RandpermKernel : public framework::OpKernel<T> {
tmp_tensor.Resize(framework::make_ddim({n})); tmp_tensor.Resize(framework::make_ddim({n}));
T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace()); T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
random_permate<T>(tmp_data, n, seed); random_permate<T>(tmp_data, n, seed);
framework::TensorCopy(tmp_tensor, platform::CUDAPlace(), out_tensor); framework::TensorCopy(tmp_tensor, ctx.GetPlace(), out_tensor);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册