Unverified commit 2c0664b2, authored by topduke, committed by GitHub

Fix bug when training with multiple GPUs (sampler) (#9963)

* Fix gris_sample data type bug when using fp16

* Fix gris_sample data type bug when using fp16

* Fix v4rec batch size

* Fix hang when training with multiple GPUs (sampler)
Parent abc4be00
@@ -52,8 +52,7 @@ class MultiScaleSampler(Sampler):
         num_replicas = dist.get_world_size()
         rank = dist.get_rank()
         # adjust the total samples to avoid batch dropping
-        num_samples_per_replica = int(
-            math.ceil(self.n_data_samples * 1.0 / num_replicas))
+        num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)
         img_indices = [idx for idx in range(self.n_data_samples)]
@@ -92,22 +91,23 @@ class MultiScaleSampler(Sampler):
         self.batch_list = []
         self.current = 0
-        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
+        last_index = num_samples_per_replica * num_replicas
+        indices_rank_i = self.img_indices[self.rank:last_index:
                                           self.num_replicas]
         while self.current < self.n_samples_per_replica:
-            curr_w, curr_h, curr_bsz = random.choice(self.img_batch_pairs)
-            end_index = min(self.current + curr_bsz, self.n_samples_per_replica)
-            batch_ids = indices_rank_i[self.current:end_index]
-            n_batch_samples = len(batch_ids)
-            if n_batch_samples != curr_bsz:
-                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
-            self.current += curr_bsz
-            if len(batch_ids) > 0:
-                batch = [curr_w, curr_h, len(batch_ids)]
-                self.batch_list.append(batch)
+            for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
+                end_index = min(self.current + curr_bsz,
+                                self.n_samples_per_replica)
+                batch_ids = indices_rank_i[self.current:end_index]
+                n_batch_samples = len(batch_ids)
+                if n_batch_samples != curr_bsz:
+                    batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
+                self.current += curr_bsz
+                if len(batch_ids) > 0:
+                    batch = [curr_w, curr_h, len(batch_ids)]
+                    self.batch_list.append(batch)
+        random.shuffle(self.batch_list)
         self.length = len(self.batch_list)
         self.batchs_in_one_epoch = self.iter()
         self.batchs_in_one_epoch_id = [
...
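
A minimal sketch (not part of the commit) of why the old sampler could desynchronize the ranks: with math.ceil, the per-replica sample count overstates what the strided slice img_indices[rank::num_replicas] actually delivers, so some ranks build more batches than others and the rank that exhausts its batches first leaves the rest blocked in the next collective op. The numbers below are hypothetical; only the ceil-vs-floor computation, the strided slicing, and last_index mirror the diff above.

import math

# Hypothetical illustration: 10 samples split across 4 replicas.
n_data_samples, num_replicas = 10, 4

# Old behaviour: math.ceil claims 3 samples per replica, but the strided
# slice img_indices[rank::num_replicas] yields 3 indices on ranks 0-1 and
# only 2 on ranks 2-3, so ranks emit different numbers of batches.
old_per_replica = int(math.ceil(n_data_samples * 1.0 / num_replicas))  # 3
old_counts = [
    len(range(rank, n_data_samples, num_replicas))
    for rank in range(num_replicas)
]  # [3, 3, 2, 2]

# New behaviour: floor division plus truncating the index list at
# last_index = num_samples_per_replica * num_replicas gives every rank the
# same number of indices, so all ranks produce the same number of batches
# and no rank is left waiting on the others.
new_per_replica = int(n_data_samples * 1.0 / num_replicas)  # 2
last_index = new_per_replica * num_replicas  # 8
new_counts = [
    len(range(rank, last_index, num_replicas)) for rank in range(num_replicas)
]  # [2, 2, 2, 2]

print("old:", old_per_replica, old_counts)  # old: 3 [3, 3, 2, 2]
print("new:", new_per_replica, new_counts)  # new: 2 [2, 2, 2, 2]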