未验证 提交 624a2b9c 编写于 作者: G Guoxia Wang 提交者: GitHub

fix cuda seed bug of class_center_sample traning on multi gpu (#38817)

上级 ad92fa61
......@@ -397,7 +397,9 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
auto gen_cuda = framework::GetDefaultCUDAGenerator(rank);
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (!fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册