未验证 提交 3fa302b8 编写于 作者: G Guoxia Wang 提交者: GitHub

fix seed for class_center_sample using paddle.seed (#38248) (#38498)

上级 8ef7102c
......@@ -390,18 +390,26 @@ class ClassCenterSampleCUDAKernel : public framework::OpKernel<T> {
ctx.cuda_device_context().stream())));
// step 5: random sample negative class center
uint64_t seed_data;
uint64_t increment;
int vec_size = VectorizedSize<T>(cub_sort_keys_ptr);
int increment = ((num_classes - 1) /
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
if (!fix_seed) {
auto offset = ((num_classes - 1) /
(NumBlocks(num_classes) * kNumCUDAThreads * vec_size) +
1) *
vec_size;
auto gen_cuda = framework::GetDefaultCUDAGenerator(rank);
if (gen_cuda->GetIsInitPy() && (!fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
seed_data = seed_offset.first;
increment = seed_offset.second;
} else {
std::random_device rnd;
seed = rnd();
seed_data = fix_seed ? seed + rank : rnd();
increment = offset;
}
RandomSampleClassCenter<T><<<NumBlocks(num_classes), kNumCUDAThreads, 0,
ctx.cuda_device_context().stream()>>>(
num_classes, seed + rank, increment, num_classes, cub_sort_keys_ptr);
num_classes, seed_data, increment, num_classes, cub_sort_keys_ptr);
// step 6: mark positive class center as negative value
// fill the sort values to index 0, 1, ..., batch_size-1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册