提交 47cd178a 编写于 作者: L LielinJiang

fix eval samples

上级 abc1ecaa
......@@ -110,6 +110,10 @@ class DistributedBatchSampler(BatchSampler):
return num_samples // self.batch_size
def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
    """Forward the arguments to ``_c_allgather`` and return its result.

    Args:
        x: Passed through as the first positional argument.
        nranks: Passed through as the second positional argument.
        ring_id: Forwarded as the ``ring_id`` keyword argument (default 0).
        use_calc_stream: Forwarded as the ``use_calc_stream`` keyword
            argument (default True).

    Returns:
        Whatever ``_c_allgather`` returns for these arguments.
    """
    # NOTE(review): presumably this gathers ``x`` from all ``nranks``
    # workers -- confirm against the definition of ``_c_allgather``.
    gathered = _c_allgather(
        x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
    return gathered
def get_local_rank():
    """Return the ``local_rank`` attribute of a newly constructed ``Env``."""
    env = Env()
    return env.local_rank
......@@ -203,11 +207,7 @@ def prepare_distributed_context(place=None):
exe.run(communicator_prog)
if fluid.in_dygraph_mode():
cnt = 0
while fluid.in_dygraph_mode():
cnt += 1
print('debug', cnt)
fluid.disable_dygraph()
fluid.disable_dygraph()
_init_context()
fluid.enable_dygraph(place)
else:
......
......@@ -143,7 +143,9 @@ class StaticGraphAdapter(object):
self._progs = {}
self._compiled_progs = {}
self._merge_count = {'eval': 0, 'test': 0}
self._merge_count = {'eval_total': 0, 'test_total': 0,
'eval_batch': 0, 'test_batch': 0}
self._nranks = distributed.Env().nranks
self._local_rank = distributed.Env().local_rank
......@@ -354,12 +356,14 @@ class StaticGraphAdapter(object):
total_size = len(self.model._test_dataloader.dataset)
# TODO: fixme if have better way to get batch size
samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode, 0)
current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size:
state = [s[:total_size - current_count, ...] for s in state]
self._merge_count[self.mode] = 0
self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count
else:
self._merge_count[self.mode] += samples
self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metrics.append(metric.update(*state))
return (losses, metrics) if len(metrics) > 0 else losses
......@@ -498,7 +502,8 @@ class DynamicGraphAdapter(object):
self.model = model
self._nranks = distributed.Env().nranks
self._local_rank = distributed.Env().local_rank
self._merge_count = {'eval': 0, 'test': 0}
self._merge_count = {'eval_total': 0, 'test_total': 0,
'eval_batch': 0, 'test_batch': 0}
if self._nranks > 1:
self.ddp_model = distributed.DistributedDataParallel(self.model)
......@@ -564,13 +569,16 @@ class DynamicGraphAdapter(object):
if self.model._test_dataloader is not None and self._nranks > 1:
total_size = len(self.model._test_dataloader.dataset)
samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode, 0)
current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size:
outputs = [o[:total_size - metric.count[0]] for o in outputs]
labels = [l[:total_size - metric.count[0]] for l in labels]
self._merge_count[self.mode] = 0
self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count
else:
self._merge_count[self.mode] += samples
self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(to_list(outputs), labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
......@@ -966,7 +974,10 @@ class Model(fluid.dygraph.Layer):
logs[k] = v
logs['step'] = step
logs['batch_size'] = batch_size
if mode == 'train' or self._adapter._merge_count[mode + '_batch'] <= 0:
logs['batch_size'] = batch_size * distributed.Env().nranks
else:
logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
cbks.on_batch_end(mode, step, logs)
self._reset_metrics()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册