提交 47cd178a 编写于 作者: L LielinJiang

fix eval samples

上级 abc1ecaa
...@@ -110,6 +110,10 @@ class DistributedBatchSampler(BatchSampler): ...@@ -110,6 +110,10 @@ class DistributedBatchSampler(BatchSampler):
return num_samples // self.batch_size return num_samples // self.batch_size
def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
    """Forward an all-gather request to ``_c_allgather``.

    This is a thin wrapper: ``x`` and ``nranks`` are passed positionally,
    while ``ring_id`` and ``use_calc_stream`` are passed by keyword, and the
    callee's result is returned unchanged.

    Args:
        x: Value handed to ``_c_allgather`` as its first argument
            (presumably the per-rank tensor to gather -- confirm against
            ``_c_allgather``).
        nranks: Value handed to ``_c_allgather`` as its second argument
            (presumably the number of participating ranks -- confirm).
        ring_id: Forwarded as the ``ring_id`` keyword; defaults to ``0``.
        use_calc_stream: Forwarded as the ``use_calc_stream`` keyword;
            defaults to ``True``.

    Returns:
        Whatever ``_c_allgather`` returns for these arguments.
    """
    gathered = _c_allgather(
        x,
        nranks,
        ring_id=ring_id,
        use_calc_stream=use_calc_stream,
    )
    return gathered
def get_local_rank(): def get_local_rank():
return Env().local_rank return Env().local_rank
...@@ -203,11 +207,7 @@ def prepare_distributed_context(place=None): ...@@ -203,11 +207,7 @@ def prepare_distributed_context(place=None):
exe.run(communicator_prog) exe.run(communicator_prog)
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
cnt = 0 fluid.disable_dygraph()
while fluid.in_dygraph_mode():
cnt += 1
print('debug', cnt)
fluid.disable_dygraph()
_init_context() _init_context()
fluid.enable_dygraph(place) fluid.enable_dygraph(place)
else: else:
......
...@@ -143,7 +143,9 @@ class StaticGraphAdapter(object): ...@@ -143,7 +143,9 @@ class StaticGraphAdapter(object):
self._progs = {} self._progs = {}
self._compiled_progs = {} self._compiled_progs = {}
self._merge_count = {'eval': 0, 'test': 0} self._merge_count = {'eval_total': 0, 'test_total': 0,
'eval_batch': 0, 'test_batch': 0}
self._nranks = distributed.Env().nranks self._nranks = distributed.Env().nranks
self._local_rank = distributed.Env().local_rank self._local_rank = distributed.Env().local_rank
...@@ -354,12 +356,14 @@ class StaticGraphAdapter(object): ...@@ -354,12 +356,14 @@ class StaticGraphAdapter(object):
total_size = len(self.model._test_dataloader.dataset) total_size = len(self.model._test_dataloader.dataset)
# TODO: fixme if have better way to get batch size # TODO: fixme if have better way to get batch size
samples = state[0].shape[0] samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode, 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size: if current_count + samples > total_size:
state = [s[:total_size - current_count, ...] for s in state] state = [s[:total_size - current_count, ...] for s in state]
self._merge_count[self.mode] = 0 self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count
else: else:
self._merge_count[self.mode] += samples self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metrics.append(metric.update(*state)) metrics.append(metric.update(*state))
return (losses, metrics) if len(metrics) > 0 else losses return (losses, metrics) if len(metrics) > 0 else losses
...@@ -498,7 +502,8 @@ class DynamicGraphAdapter(object): ...@@ -498,7 +502,8 @@ class DynamicGraphAdapter(object):
self.model = model self.model = model
self._nranks = distributed.Env().nranks self._nranks = distributed.Env().nranks
self._local_rank = distributed.Env().local_rank self._local_rank = distributed.Env().local_rank
self._merge_count = {'eval': 0, 'test': 0} self._merge_count = {'eval_total': 0, 'test_total': 0,
'eval_batch': 0, 'test_batch': 0}
if self._nranks > 1: if self._nranks > 1:
self.ddp_model = distributed.DistributedDataParallel(self.model) self.ddp_model = distributed.DistributedDataParallel(self.model)
...@@ -564,13 +569,16 @@ class DynamicGraphAdapter(object): ...@@ -564,13 +569,16 @@ class DynamicGraphAdapter(object):
if self.model._test_dataloader is not None and self._nranks > 1: if self.model._test_dataloader is not None and self._nranks > 1:
total_size = len(self.model._test_dataloader.dataset) total_size = len(self.model._test_dataloader.dataset)
samples = outputs[0].shape[0] samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode, 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size: if current_count + samples > total_size:
outputs = [o[:total_size - metric.count[0]] for o in outputs] outputs = [o[:total_size - metric.count[0]] for o in outputs]
labels = [l[:total_size - metric.count[0]] for l in labels] labels = [l[:total_size - metric.count[0]] for l in labels]
self._merge_count[self.mode] = 0 self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count
else: else:
self._merge_count[self.mode] += samples self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(to_list(outputs), labels) metric_outs = metric.add_metric_op(to_list(outputs), labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
...@@ -966,7 +974,10 @@ class Model(fluid.dygraph.Layer): ...@@ -966,7 +974,10 @@ class Model(fluid.dygraph.Layer):
logs[k] = v logs[k] = v
logs['step'] = step logs['step'] = step
logs['batch_size'] = batch_size if mode == 'train' or self._adapter._merge_count[mode + '_batch'] <= 0:
logs['batch_size'] = batch_size * distributed.Env().nranks
else:
logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
cbks.on_batch_end(mode, step, logs) cbks.on_batch_end(mode, step, logs)
self._reset_metrics() self._reset_metrics()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册