提交 b4182125 编写于 作者: J Jiayuan Gu 提交者: Francisco Massa

Update balanced_positive_negative_sampler.py (#93)

use cuda version torch.randperm to avoid copy from gpu to cpu and a fatal bug in multi-thread cpu version
上级 c5475085
......@@ -46,8 +46,8 @@ class BalancedPositiveNegativeSampler(object):
num_neg = min(negative.numel(), num_neg)
# randomly select positive and negative examples
perm1 = torch.randperm(positive.numel())[:num_pos]
perm2 = torch.randperm(negative.numel())[:num_neg]
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
pos_idx_per_image = positive[perm1]
neg_idx_per_image = negative[perm2]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册