Commit 368d6302 authored by LielinJiang

refine fit, distributedsampler

Parent ba723731
@@ -220,7 +220,7 @@ class ProgBarLogger(Callback):
     def on_epoch_end(self, epoch, logs=None):
         logs = logs or {}
-        if self.verbose:
+        if self.verbose and get_local_rank() == 0:
             self._updates(logs, 'train')

     def on_eval_begin(self, logs=None):
......
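Note (not part of the commit): the hunk above gates the per-epoch progress update on the local rank, so a multi-process launch prints a single progress bar instead of one per GPU. A minimal sketch of that guard, assuming the rank comes from the launcher-provided PADDLE_TRAINER_ID environment variable; the repository's own get_local_rank helper may resolve it differently.

import os

def get_local_rank():
    # Distributed launchers commonly export the process's trainer id; a
    # single-process run falls back to rank 0. This is an assumption made
    # only to keep the sketch self-contained.
    return int(os.environ.get("PADDLE_TRAINER_ID", "0"))

def on_epoch_end_sketch(verbose, logs):
    # Only rank 0 prints, so N processes do not produce N interleaved logs.
    if verbose and get_local_rank() == 0:
        print(logs)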
@@ -80,20 +80,18 @@ class DistributedBatchSampler(BatchSampler):
         self.total_size = self.num_samples * self.nranks

     def __iter__(self):
-        _sample_iter = self.sample_iter
-        if _sample_iter is None:
-            num_samples = len(self.dataset)
-            indices = np.arange(num_samples).tolist()
-            indices += indices[:(self.total_size - len(indices))]
-            assert len(indices) == self.total_size
-            if self.shuffle:
-                np.random.RandomState(self.epoch).shuffle(indices)
-                self.epoch += 1
-            # subsample
-            indices = indices[self.local_rank * self.num_samples:
-                              (self.local_rank + 1) * self.num_samples]
-            assert len(indices) == self.num_samples
-            _sample_iter = iter(indices)
+        num_samples = len(self.dataset)
+        indices = np.arange(num_samples).tolist()
+        indices += indices[:(self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+        if self.shuffle:
+            np.random.RandomState(self.epoch).shuffle(indices)
+            self.epoch += 1
+        # subsample
+        indices = indices[self.local_rank * self.num_samples:
+                          (self.local_rank + 1) * self.num_samples]
+        assert len(indices) == self.num_samples
+        _sample_iter = iter(indices)

         batch_indices = []
         for idx in _sample_iter:
......
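Not part of the commit: a standalone numpy walk-through of the index arithmetic in the rewritten __iter__ above. The index list is padded so it splits evenly across ranks, shuffled with an epoch-seeded RandomState so every rank applies the same permutation, and then each rank takes its contiguous slice. The per-rank count uses the usual ceil(len/nranks) convention, which is an assumption here since that line sits outside the hunk.

import numpy as np

dataset_len, nranks, local_rank, epoch = 10, 4, 1, 0
num_samples = int(np.ceil(dataset_len * 1.0 / nranks))      # 3 samples per rank (assumed formula)
total_size = num_samples * nranks                            # 12 after padding

indices = np.arange(dataset_len).tolist()
indices += indices[:(total_size - len(indices))]             # reuse the first samples as padding
assert len(indices) == total_size

np.random.RandomState(epoch).shuffle(indices)                # same seed => same order on every rank

subset = indices[local_rank * num_samples:(local_rank + 1) * num_samples]
assert len(subset) == num_samples                            # every rank gets an equal share
print(subset)                                                # this rank's indices for the epoch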
@@ -91,7 +91,7 @@ def init_context(backend):
     place = fluid.CUDAPlace(distributed.Env().dev_id) if \
         distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
-    distributed.prepare_distributed_context()
+    distributed.prepare_distributed_context(place)

     backend = backend.lower()
     if backend == 'dynamic':
         fluid.enable_dygraph(place)
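Not part of the commit: the fix above passes the already-resolved place into prepare_distributed_context instead of letting the helper pick a device on its own, so the distributed setup and the dygraph guard always refer to the same GPU. A hedged sketch of that flow, reusing the names visible in the hunk; `distributed` refers to the repository's own module as in the diff, and the helper's previous default-place behaviour is an assumption.

import paddle.fluid as fluid  # as imported in the patched file

def init_context_sketch(backend):
    # Resolve the device once from the launcher-assigned id (or GPU 0 when
    # running single-process), then reuse that exact place everywhere.
    nranks = distributed.Env().nranks
    place = fluid.CUDAPlace(distributed.Env().dev_id if nranks > 1 else 0)
    distributed.prepare_distributed_context(place)   # previously called with no argument
    if backend.lower() == 'dynamic':
        fluid.enable_dygraph(place)                  # same place object, not a re-derived one
    return place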
@@ -419,22 +419,10 @@ class StaticGraphAdapter(object):
             labels = [k.forward() for k in to_list(lbls)]
             self._label_vars[mode] = labels
             outputs = to_list(self.model.forward(*inputs))

-            if mode != 'test':
-                if self.model._loss_function:
+            if mode != 'test' and self.model._loss_function:
                 losses = self.model._loss_function(outputs, labels)
-            if mode == 'train' and self.model._optimizer:
-                self._loss_endpoint = fluid.layers.sum(losses)
-                if self._nranks > 1:
-                    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
-                    fleet.init(role)
-                    dist_strategy = DistributedStrategy()
-                    dist_strategy.mode = "collective"
-                    dist_strategy.collective_mode = "grad_allreduce"
-                    self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer,
-                                                                        strategy=dist_strategy)
-                self.model._optimizer.minimize(self._loss_endpoint)

             if self._nranks > 1 and mode != 'train':
                 outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
                 if mode != 'test':
@@ -442,8 +430,21 @@ class StaticGraphAdapter(object):
             if mode != 'test':
                 for metric in self.model._metrics:
-                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))
+                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))
+
+            if mode == 'train' and self.model._optimizer:
+                self._loss_endpoint = fluid.layers.sum(losses)
+                if self._nranks > 1:
+                    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+                    fleet.init(role)
+                    dist_strategy = DistributedStrategy()
+                    dist_strategy.mode = "collective"
+                    dist_strategy.collective_mode = "grad_allreduce"
+                    self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer,
+                                                                        strategy=dist_strategy)
+
+                self.model._optimizer.minimize(self._loss_endpoint)

         if mode != 'train':  # clone again to put it in test mode
             prog = prog.clone(for_test=True)
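Not part of the commit: the block moved above now calls optimizer.minimize only after the metric ops (and, for eval, the all_gather outputs) have been added to the program, and under multi-GPU it wraps the optimizer with fleet's collective "grad_allreduce" strategy. Conceptually, grad_allreduce sums each parameter's gradient across ranks before the update; a tiny numpy illustration with made-up per-rank gradients (the real reduction is done by fleet/NCCL, not by this code):

import numpy as np

# Gradients computed independently on two ranks for the same parameter.
per_rank_grads = [np.array([0.1, 0.4]), np.array([0.3, 0.0])]

# After allreduce every rank holds the same summed gradient and therefore
# applies the same parameter update, keeping the replicas in sync.
allreduced = np.sum(per_rank_grads, axis=0)
print(allreduced)   # [0.4 0.4]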
@@ -870,6 +871,8 @@ class Model(fluid.dygraph.Layer):
             log_freq=10,
             save_freq=1,
             verbose=2,
+            drop_last=False,
+            shuffle=True,
             num_workers=0,
             callbacks=None, ):
         """
@@ -901,43 +904,27 @@
            feed_list = [x.forward() for x in self._inputs + self._labels]

        if train_loader is None:
-            if distributed.get_nranks() > 1:
-                train_sampler = DistributedBatchSampler(train_dataset,
-                                                        batch_size=batch_size,
-                                                        shuffle=True)
-                train_loader = DataLoader(train_dataset,
-                                          batch_sampler=train_sampler,
-                                          places=self._place,
-                                          feed_list=feed_list,
-                                          num_workers=num_workers,
-                                          return_list=True)
-            else:
-                train_loader = DataLoader(train_dataset,
-                                          batch_size=batch_size,
-                                          places=self._place,
-                                          feed_list=feed_list,
-                                          num_workers=4,
-                                          return_list=True)
+            train_sampler = DistributedBatchSampler(train_dataset,
+                                                    batch_size=batch_size,
+                                                    shuffle=shuffle,
+                                                    drop_last=drop_last)
+            train_loader = DataLoader(train_dataset,
+                                      batch_sampler=train_sampler,
+                                      places=self._place,
+                                      feed_list=feed_list,
+                                      num_workers=num_workers,
+                                      return_list=True)

        if eval_loader is None and eval_dataset is not None:
-            if distributed.get_nranks() > 1:
-                eval_sampler = DistributedBatchSampler(eval_dataset,
-                                                       batch_size=batch_size)
-                eval_loader = DataLoader(eval_dataset,
-                                         batch_sampler=eval_sampler,
-                                         places=self._place,
-                                         feed_list=feed_list,
-                                         num_workers=num_workers,
-                                         return_list=True)
-            else:
-                eval_loader = DataLoader(eval_dataset,
-                                         batch_size=batch_size,
-                                         places=self._place,
-                                         feed_list=feed_list,
-                                         num_workers=num_workers,
-                                         return_list=True)
+            eval_sampler = DistributedBatchSampler(eval_dataset,
+                                                   batch_size=batch_size)
+            eval_loader = DataLoader(eval_dataset,
+                                     batch_sampler=eval_sampler,
+                                     places=self._place,
+                                     feed_list=feed_list,
+                                     num_workers=num_workers,
+                                     return_list=True)

        do_eval = eval_loader is not None
        self._test_dataloader = eval_loader
        metrics_name = self._metrics_name()
......
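Not part of the commit: the unified code path above always builds a DistributedBatchSampler, even for single-process runs. A quick pure-numpy check that this is safe: with nranks == 1 the padding and the per-rank slice are both no-ops, so the sampler yields exactly the indices a plain sampler would (the ceil per-rank formula is an assumption, as before).

import numpy as np

dataset_len, nranks, local_rank = 10, 1, 0
num_samples = int(np.ceil(dataset_len * 1.0 / nranks))      # 10
total_size = num_samples * nranks                            # 10, so no padding is added

indices = np.arange(dataset_len).tolist()
indices += indices[:(total_size - len(indices))]             # appends nothing
subset = indices[local_rank * num_samples:(local_rank + 1) * num_samples]
assert subset == list(range(dataset_len))                    # identical to the non-distributed case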
@@ -141,8 +141,8 @@ class MyCrossEntropy(Loss):

 class TestModel(unittest.TestCase):
     def fit(self, dynamic, is_mlp=False):
-        init_context('dynamic' if FLAGS.dynamic else 'static')
+        init_context('dynamic' if dynamic else 'static')
         im_shape = (-1, 784)
         batch_size = 128
......