Commit 368d6302 authored by LielinJiang

refine fit, distributedsampler

Parent ba723731
@@ -220,7 +220,7 @@ class ProgBarLogger(Callback):
     def on_epoch_end(self, epoch, logs=None):
         logs = logs or {}
-        if self.verbose:
+        if self.verbose and get_local_rank() == 0:
             self._updates(logs, 'train')

     def on_eval_begin(self, logs=None):
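Note: the get_local_rank() == 0 guard keeps multi-process training from drawing one progress bar per worker. A minimal sketch of the pattern, assuming get_local_rank() returns 0 only on the primary worker (as the change implies):

    # Sketch, not part of the patch: gate any console output on the local rank.
    def log_on_rank0(message):
        if get_local_rank() == 0:  # only the primary worker prints
            print(message)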
@@ -80,8 +80,6 @@ class DistributedBatchSampler(BatchSampler):
         self.total_size = self.num_samples * self.nranks

     def __iter__(self):
-        _sample_iter = self.sample_iter
-        if _sample_iter is None:
-            num_samples = len(self.dataset)
-            indices = np.arange(num_samples).tolist()
-            indices += indices[:(self.total_size - len(indices))]
+        num_samples = len(self.dataset)
+        indices = np.arange(num_samples).tolist()
+        indices += indices[:(self.total_size - len(indices))]
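The new body keeps the padding trick: the index list is extended with its own head so that it divides evenly across ranks. A worked illustration with assumed sizes (10 samples, 4 ranks; num_samples in the constructor is the per-rank count):

    import numpy as np
    num_samples_per_rank = 3                  # ceil(10 / 4)
    total_size = num_samples_per_rank * 4     # 12, as in self.total_size above
    indices = np.arange(10).tolist()
    indices += indices[:(total_size - len(indices))]
    # indices is now [0, 1, ..., 9, 0, 1] -- length 12, evenly splittable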
@@ -91,7 +91,7 @@ def init_context(backend):
     place = fluid.CUDAPlace(distributed.Env().dev_id) if \
         distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
-    distributed.prepare_distributed_context()
+    distributed.prepare_distributed_context(place)
     backend = backend.lower()
     if backend == 'dynamic':
         fluid.enable_dygraph(place)
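Passing place into prepare_distributed_context initializes the parallel context on the same device that dygraph mode is later enabled on. A hypothetical call site, using only names visible in this diff:

    # Hypothetical usage: under multiple ranks init_context picks
    # CUDAPlace(dev_id); on a single rank it falls back to CUDAPlace(0).
    init_context('dynamic')   # lower-cased internally; 'static' skips enable_dygraph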
@@ -419,10 +419,19 @@ class StaticGraphAdapter(object):
             labels = [k.forward() for k in to_list(lbls)]
             self._label_vars[mode] = labels
             outputs = to_list(self.model.forward(*inputs))
-            if mode != 'test':
-                if self.model._loss_function:
-                    losses = self.model._loss_function(outputs, labels)
+
+            if mode != 'test' and self.model._loss_function:
+                losses = self.model._loss_function(outputs, labels)
+
+            if self._nranks > 1 and mode != 'train':
+                outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
+                if mode != 'test':
+                    labels = [distributed._all_gather(l, self._nranks) for l in labels]
+
+            if mode != 'test':
+                for metric in self.model._metrics:
+                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))

             if mode == 'train' and self.model._optimizer:
                 self._loss_endpoint = fluid.layers.sum(losses)
                 if self._nranks > 1:
@@ -435,14 +444,6 @@ class StaticGraphAdapter(object):
                         strategy=dist_strategy)
                 self.model._optimizer.minimize(self._loss_endpoint)

-            if self._nranks > 1 and mode != 'train':
-                outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
-                if mode != 'test':
-                    labels = [distributed._all_gather(l, self._nranks) for l in labels]
-            if mode != 'test':
-                for metric in self.model._metrics:
-                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))
-
             if mode != 'train':  # clone again to put it in test mode
                 prog = prog.clone(for_test=True)
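Taken together, these two hunks move the _all_gather and metric bookkeeping from after the optimizer block to before it, so eval-mode metrics are computed on outputs gathered from every rank before the program is cloned for test. A condensed sketch of the resulting order (names taken from the diff; _all_gather is assumed to concatenate tensors across ranks):

    if mode != 'test' and model._loss_function:
        losses = model._loss_function(outputs, labels)               # 1. loss
    if nranks > 1 and mode != 'train':
        outputs = [distributed._all_gather(o, nranks) for o in outputs]  # 2. gather
    if mode != 'test':
        metrics = [metric.add_metric_op(outputs, labels)
                   for metric in model._metrics]                     # 3. metrics
    # 4. only afterwards: optimizer.minimize(...), then clone(for_test=True)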
@@ -870,6 +871,8 @@ class Model(fluid.dygraph.Layer):
             log_freq=10,
             save_freq=1,
             verbose=2,
+            drop_last=False,
+            shuffle=True,
             num_workers=0,
             callbacks=None, ):
         """
@@ -901,10 +904,10 @@ class Model(fluid.dygraph.Layer):
         feed_list = [x.forward() for x in self._inputs + self._labels]
         if train_loader is None:
-            if distributed.get_nranks() > 1:
-                train_sampler = DistributedBatchSampler(train_dataset,
-                                                        batch_size=batch_size,
-                                                        shuffle=True)
+            train_sampler = DistributedBatchSampler(train_dataset,
+                                                    batch_size=batch_size,
+                                                    shuffle=shuffle,
+                                                    drop_last=drop_last)
             train_loader = DataLoader(train_dataset,
                                       batch_sampler=train_sampler,
                                       places=self._place,
@@ -912,16 +915,7 @@ class Model(fluid.dygraph.Layer):
                                       num_workers=num_workers,
                                       return_list=True)
-            else:
-                train_loader = DataLoader(train_dataset,
-                                          batch_size=batch_size,
-                                          places=self._place,
-                                          feed_list=feed_list,
-                                          num_workers=4,
-                                          return_list=True)

         if eval_loader is None and eval_dataset is not None:
-            if distributed.get_nranks() > 1:
-                eval_sampler = DistributedBatchSampler(eval_dataset,
-                                                       batch_size=batch_size)
+            eval_sampler = DistributedBatchSampler(eval_dataset,
+                                                   batch_size=batch_size)
             eval_loader = DataLoader(eval_dataset,
@@ -930,13 +924,6 @@ class Model(fluid.dygraph.Layer):
                                      feed_list=feed_list,
                                      num_workers=num_workers,
                                      return_list=True)
-            else:
-                eval_loader = DataLoader(eval_dataset,
-                                         batch_size=batch_size,
-                                         places=self._place,
-                                         feed_list=feed_list,
-                                         num_workers=num_workers,
-                                         return_list=True)

         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader
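Both deleted else branches built a plain DataLoader for the single-process case; after this change the DistributedBatchSampler path runs unconditionally (which also drops the stray hard-coded num_workers=4 in the train branch). A sketch of the unified path, assuming the sampler degrades gracefully when nranks == 1:

    # Sketch: one loader-construction path regardless of rank count.
    eval_sampler = DistributedBatchSampler(eval_dataset, batch_size=batch_size)
    eval_loader = DataLoader(eval_dataset,
                             batch_sampler=eval_sampler,
                             places=self._place,
                             feed_list=feed_list,
                             num_workers=num_workers,
                             return_list=True)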
@@ -141,7 +141,7 @@ class MyCrossEntropy(Loss):
 class TestModel(unittest.TestCase):
     def fit(self, dynamic, is_mlp=False):
-        init_context('dynamic' if FLAGS.dynamic else 'static')
+        init_context('dynamic' if dynamic else 'static')
         im_shape = (-1, 784)
         batch_size = 128
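The helper previously ignored its own dynamic argument and consulted the global FLAGS instead, so both parameterized variants ran in the same mode. After the fix, a test that exercises both backends behaves as intended (hypothetical calls):

    self.fit(dynamic=True)    # dygraph path
    self.fit(dynamic=False)   # static-graph path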