diff --git a/distributed.py b/distributed.py
index c83f8093020b3ac4397c180e57577f02e4963447..50590fc313345bb4bc2b7e3a75bbd4963cf9c1ea 100644
--- a/distributed.py
+++ b/distributed.py
@@ -110,6 +110,10 @@ class DistributedBatchSampler(BatchSampler):
         return num_samples // self.batch_size
 
 
+def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
+    return _c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
+
+
 def get_local_rank():
     return Env().local_rank
 
@@ -203,11 +207,7 @@ def prepare_distributed_context(place=None):
         exe.run(communicator_prog)
 
     if fluid.in_dygraph_mode():
-        cnt = 0
-        while fluid.in_dygraph_mode():
-            cnt += 1
-            print('debug', cnt)
-            fluid.disable_dygraph()
+        fluid.disable_dygraph()
         _init_context()
         fluid.enable_dygraph(place)
     else:
diff --git a/model.py b/model.py
index 01f452fd4d213999238ac90d40519102196f160b..c38a6e7bcd368a6aebac6220df9814a0b2152cb7 100644
--- a/model.py
+++ b/model.py
@@ -143,7 +143,9 @@ class StaticGraphAdapter(object):
         self._progs = {}
         self._compiled_progs = {}
 
-        self._merge_count = {'eval': 0, 'test': 0}
+        self._merge_count = {'eval_total': 0, 'test_total': 0,
+                             'eval_batch': 0, 'test_batch': 0}
+
         self._nranks = distributed.Env().nranks
         self._local_rank = distributed.Env().local_rank
 
@@ -354,12 +356,14 @@ class StaticGraphAdapter(object):
                 total_size = len(self.model._test_dataloader.dataset)
                 # TODO: fixme if have better way to get batch size
                 samples = state[0].shape[0]
-                current_count = self._merge_count.get(self.mode, 0)
+                current_count = self._merge_count.get(self.mode + '_total', 0)
                 if current_count + samples > total_size:
                     state = [s[:total_size - current_count, ...] for s in state]
-                    self._merge_count[self.mode] = 0
+                    self._merge_count[self.mode + '_total'] = 0
+                    self._merge_count[self.mode + '_batch'] = total_size - current_count
                 else:
-                    self._merge_count[self.mode] += samples
+                    self._merge_count[self.mode + '_total'] += samples
+                    self._merge_count[self.mode + '_batch'] = samples
 
             metrics.append(metric.update(*state))
         return (losses, metrics) if len(metrics) > 0 else losses
@@ -498,7 +502,8 @@ class DynamicGraphAdapter(object):
         self.model = model
         self._nranks = distributed.Env().nranks
         self._local_rank = distributed.Env().local_rank
-        self._merge_count = {'eval': 0, 'test': 0}
+        self._merge_count = {'eval_total': 0, 'test_total': 0,
+                             'eval_batch': 0, 'test_batch': 0}
 
         if self._nranks > 1:
             self.ddp_model = distributed.DistributedDataParallel(self.model)
@@ -564,13 +569,16 @@ class DynamicGraphAdapter(object):
             if self.model._test_dataloader is not None and self._nranks > 1:
                 total_size = len(self.model._test_dataloader.dataset)
                 samples = outputs[0].shape[0]
-                current_count = self._merge_count.get(self.mode, 0)
+                current_count = self._merge_count.get(self.mode + '_total', 0)
                 if current_count + samples > total_size:
                     outputs = [o[:total_size - metric.count[0]] for o in outputs]
                     labels = [l[:total_size - metric.count[0]] for l in labels]
-                    self._merge_count[self.mode] = 0
+                    self._merge_count[self.mode + '_total'] = 0
+                    self._merge_count[self.mode + '_batch'] = total_size - current_count
                 else:
-                    self._merge_count[self.mode] += samples
+                    self._merge_count[self.mode + '_total'] += samples
+                    self._merge_count[self.mode + '_batch'] = samples
+
             metric_outs = metric.add_metric_op(to_list(outputs), labels)
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
 
@@ -966,7 +974,10 @@ class Model(fluid.dygraph.Layer):
                     logs[k] = v
 
                 logs['step'] = step
-                logs['batch_size'] = batch_size
+                if mode == 'train' or self._adapter._merge_count[mode + '_batch'] <= 0:
+                    logs['batch_size'] = batch_size * distributed.Env().nranks
+                else:
+                    logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
 
                 cbks.on_batch_end(mode, step, logs)
         self._reset_metrics()