diff --git a/callbacks.py b/callbacks.py
index a055940776791bebefb4ae4c6eadf120f3d3450f..66690cf288efe8ba0d8dcc9eec64031674c8a18b 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -242,6 +242,12 @@ class ProgBarLogger(Callback):
         samples = logs.get('batch_size', 1)
         self.evaled_samples += samples
 
+        if self.eval_step % self.log_freq == 0 and self.verbose and ParallelEnv(
+        ).local_rank == 0:
+            # if steps is not None, last step will update in on_epoch_end
+            if self.eval_steps and self.eval_step < self.eval_steps:
+                self._updates(logs, 'eval')
+
     def on_eval_end(self, logs=None):
         logs = logs or {}
         if self.verbose and ParallelEnv().local_rank == 0:
diff --git a/distributed.py b/distributed.py
index d4302254738f8354e4e1a41cd58f158f34069249..87818545671c45cf4faba234406e87762e897784 100644
--- a/distributed.py
+++ b/distributed.py
@@ -25,7 +25,6 @@ from paddle.fluid.layers import collective
 from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
 from paddle.fluid.io import BatchSampler
 
-
 _parallel_context_initialized = False
 
 
@@ -67,7 +66,8 @@ class DistributedBatchSampler(BatchSampler):
         self.nranks = ParallelEnv().nranks
         self.local_rank = ParallelEnv().local_rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+        self.num_samples = int(
+            math.ceil(len(self.dataset) * 1.0 / self.nranks))
         self.total_size = self.num_samples * self.nranks
 
     def __iter__(self):
@@ -78,9 +78,28 @@ class DistributedBatchSampler(BatchSampler):
         if self.shuffle:
             np.random.RandomState(self.epoch).shuffle(indices)
             self.epoch += 1
+
         # subsample
-        indices = indices[self.local_rank * self.num_samples:
-                          (self.local_rank + 1) * self.num_samples]
+        def _get_indices_by_batch_size(indices):
+            subsampled_indices = []
+            last_batch_size = self.total_size % (self.batch_size * self.nranks)
+            assert last_batch_size % self.nranks == 0
+            last_local_batch_size = last_batch_size // self.nranks
+
+            for i in range(self.local_rank * self.batch_size,
+                           len(indices) - last_batch_size,
+                           self.batch_size * self.nranks):
+                subsampled_indices.extend(indices[i:i + self.batch_size])
+
+            indices = indices[len(indices) - last_batch_size:]
+            subsampled_indices.extend(indices[
+                self.local_rank * last_local_batch_size:(
+                    self.local_rank + 1) * last_local_batch_size])
+            return subsampled_indices
+
+        if self.nranks > 1:
+            indices = _get_indices_by_batch_size(indices)
+
         assert len(indices) == self.num_samples
 
         _sample_iter = iter(indices)
@@ -103,7 +122,8 @@
 
 
 def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
-    return collective._c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
+    return collective._c_allgather(
+        x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
 
 
 def wait_server_ready(endpoints):
@@ -114,8 +134,7 @@ def wait_server_ready(endpoints):
     for ep in endpoints:
         ip_port = ep.split(":")
         with contextlib.closing(
-                socket.socket(socket.AF_INET,
-                              socket.SOCK_STREAM)) as sock:
+                socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
             sock.settimeout(2)
             result = sock.connect_ex((ip_port[0], int(ip_port[1])))
             if result != 0:
@@ -127,8 +146,8 @@
             break
 
 
-def init_communicator(program, rank, nranks, wait_port,
-                      current_endpoint, endpoints):
+def init_communicator(program, rank, nranks, wait_port, current_endpoint,
+                      endpoints):
     if nranks < 2:
         return
     other_endpoints = endpoints[:]
@@ -166,7 +185,7 @@ def prepare_distributed_context(place=None):
     if place is None:
         place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \
             else fluid.CUDAPlace(0)
-    
+
     strategy = ParallelStrategy()
     strategy.nranks = ParallelEnv().nranks
     strategy.local_rank = ParallelEnv().local_rank
@@ -178,11 +197,14 @@
 
     global _parallel_context_initialized
 
-    if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace):
+    if not _parallel_context_initialized and isinstance(place,
+                                                        fluid.CUDAPlace):
+
        def _init_context():
            communicator_prog = fluid.Program()
-            init_communicator(communicator_prog, strategy.local_rank, strategy.nranks,
-                              True, strategy.current_endpoint, strategy.trainer_endpoints)
+            init_communicator(communicator_prog, strategy.local_rank,
+                              strategy.nranks, True, strategy.current_endpoint,
+                              strategy.trainer_endpoints)
            exe = fluid.Executor(place)
            exe.run(communicator_prog)
 
@@ -197,4 +219,4 @@ def prepare_distributed_context(place=None):
         assert ("Only support CUDAPlace for now.")
 
     _parallel_context_initialized = True
-    return strategy
\ No newline at end of file
+    return strategy
diff --git a/model.py b/model.py
index 91340488ff60348f881e895d7b9587b452b5fb97..ba80bea0c137158bebbee2537bae3788d0229800 100644
--- a/model.py
+++ b/model.py
@@ -20,6 +20,7 @@ import pickle
 import numpy as np
 import six
 import warnings
+import tqdm
 
 from collections import Iterable
 from paddle import fluid
@@ -587,10 +588,8 @@ class DynamicGraphAdapter(object):
                 samples = outputs[0].shape[0]
                 current_count = self._merge_count.get(self.mode + '_total', 0)
                 if current_count + samples >= total_size:
-                    outputs = [
-                        o[:total_size - metric.count[0]] for o in outputs
-                    ]
-                    labels = [l[:total_size - metric.count[0]] for l in labels]
+                    outputs = [o[:total_size - current_count] for o in outputs]
+                    labels = [l[:total_size - current_count] for l in labels]
                     self._merge_count[self.mode + '_total'] = 0
                     self._merge_count[self.mode +
                                       '_batch'] = total_size - current_count
@@ -612,8 +611,9 @@ class DynamicGraphAdapter(object):
         self.mode = 'test'
         inputs = [to_variable(x) for x in to_list(inputs)]
         outputs = self.model.forward(*inputs)
-        if self._nranks > 2:
+        if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
             outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
+
         return [to_numpy(o) for o in to_list(outputs)]
 
     def parameters(self, *args, **kwargs):
@@ -696,7 +696,6 @@ class Model(fluid.dygraph.Layer):
         self._loss_weights = None
         self._optimizer = None
         self._device = None
-        self._device_ids = None
         self._optimizer = None
         self._test_dataloader = None
 
@@ -794,8 +793,7 @@ class Model(fluid.dygraph.Layer):
                 metrics=None,
                 inputs=None,
                 labels=None,
-                device=None,
-                device_ids=None):
+                device=None):
         """
         FIXME: add comments
         Args:
@@ -818,17 +816,6 @@ class Model(fluid.dygraph.Layer):
             device (str|None): specify device type, 'CPU' or 'GPU'. If None,
                 automatically select device according to
                 installation package version.
-            device_ids (list[int]|None): specify device index. If None,
-                the available device will be obtained from the environment
-                variable when the model is executed: If the GPU is used, the
-                currently available device ID is obtained from the environment
-                variable FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES when the
-                model is executed; CPU, when the model is executed,
-                the currently available CPU number is obtained from the
-                environment variable CPU_NUM. For example, export CPU_NUM=4,
-                if the environment variable is not set, the executor will add
-                the variable to the environment variable and set its value to 1.
-                The default is None.
         """
 
         if isinstance(device, fluid.CUDAPlace) or \
@@ -918,7 +905,7 @@
             eval_freq (int): The frequency, in number of epochs, an evalutation
                 is performed.
             log_freq (int): The frequency, in number of steps, the training logs
-                is printed.
+                are printed.
             save_dir(str|None): The directory to save checkpoint during training.
                 If None, will not save checkpoint.
             save_freq (int): The frequency, in number of epochs, to save checkpoint.
@@ -991,71 +978,243 @@ class Model(fluid.dygraph.Layer):
             verbose=verbose,
             metrics=self._metrics_name(), )
 
-        def _run_one_epoch(data_loader, callbacks, mode):
-            size = len(data_loader) if hasattr(data_loader,
-                                               '__len__') else None
-            logs = {
-                'steps': size,
-                'metrics_name': metrics_name,
-            }
-            for step, data in enumerate(data_loader):
-                if not fluid.in_dygraph_mode():
-                    data = data[0]
-                    batch_size = data[0].shape()[0]
-                else:
-                    batch_size = data[0].shape[0]
-
-                cbks.on_batch_begin(mode, step, logs)
-                if mode == 'train':
-                    outs = self.train(*data)
-                else:
-                    outs = self.eval(*data)
-
-                # losses
-                loss = outs[0] if self._metrics else outs
-                metrics = [[l[0] for l in loss]]
-
-                # metrics
-                for metric in self._metrics:
-                    res = metric.accumulate()
-                    metrics.extend(to_list(res))
-
-                assert len(metrics_name) == len(metrics)
-                for k, v in zip(metrics_name, metrics):
-                    logs[k] = v
-
-                logs['step'] = step
-                if mode == 'train' or self._adapter._merge_count.get(
-                        mode + '_batch', 0) <= 0:
-                    logs['batch_size'] = batch_size * ParallelEnv().nranks
-                else:
-                    logs['batch_size'] = self._adapter._merge_count[mode +
-                                                                    '_batch']
-
-                cbks.on_batch_end(mode, step, logs)
-            self._reset_metrics()
-            return logs
-
         cbks.on_begin('train')
         for epoch in range(epochs):
-            cbks.on_epoch_begin(epoch)
+
             # FIXME: adapt to DataLoader
             loader = train_loader
             if not isinstance(train_loader, Iterable):
                 loader = train_loader()
-            logs = _run_one_epoch(loader, cbks, 'train')
-            cbks.on_epoch_end(epoch, logs)
+            logs = self._run_one_epoch(
+                loader, cbks, 'train', metrics_name, epoch=epoch)
 
             if do_eval and epoch % eval_freq == 0:
-                cbks.on_begin('eval', logs)
                 # FIXME: adapt to DataLoader
                 loader = eval_loader
                 if not isinstance(eval_loader, Iterable):
                     loader = eval_loader()
-                logs = _run_one_epoch(loader, cbks, 'eval')
+
+                eval_steps = len(loader) if hasattr(loader,
+                                                    '__len__') else None
+                cbks.on_begin('eval', {
+                    'steps': eval_steps,
+                    'metrics_name': metrics_name
+                })
+
+                logs = self._run_one_epoch(loader, cbks, 'eval', metrics_name)
+
                 cbks.on_end('eval', logs)
 
         cbks.on_end('train', logs)
+        self._test_dataloader = None
+
+    def evaluate(
+            self,
+            eval_data,
+            batch_size=1,
+            log_freq=10,
+            verbose=2,
+            num_workers=0,
+            callbacks=None, ):
+        """
+        FIXME: add more comments and usage
+        Args:
+            eval_data (Dataset|DataLoader): An iterable data loader is used for
+                evaluation. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.Dataloader is recommended.
+            batch_size (int): Integer number. The batch size of eval_data.
+                When eval_data is an instance of Dataloader, this
+                parameter will be ignored.
+            log_freq (int): The frequency, in number of steps, the eval logs
+                are printed.
+            verbose (int): The verbosity mode, should be 0, 1, or 2.
+                0 = silent, 1 = progress bar, 2 = one line per epoch.
+            num_workers (int): The number of subprocesses to load data, 0 for no subprocess
+                used and loading data in main process. When eval_data is an
+                instance of Dataloader, this parameter will be ignored.
+            callbacks (Callback|None): A list of `Callback` instances to apply
+                during evaluation. If None, `ProgBarLogger` and `ModelCheckpoint`
+                are automatically inserted.
+        """
+
+        if fluid.in_dygraph_mode():
+            feed_list = None
+        else:
+            feed_list = [x.forward() for x in self._inputs + self._labels]
+
+        if eval_data is not None and isinstance(eval_data, Dataset):
+            eval_sampler = DistributedBatchSampler(
+                eval_data, batch_size=batch_size)
+            eval_loader = DataLoader(
+                eval_data,
+                batch_sampler=eval_sampler,
+                places=self._place,
+                feed_list=feed_list,
+                num_workers=num_workers,
+                return_list=True)
+        else:
+            eval_loader = eval_data
+
+        self._test_dataloader = eval_loader
+        metrics_name = self._metrics_name()
+
+        cbks = config_callbacks(
+            callbacks,
+            model=self,
+            log_freq=log_freq,
+            verbose=verbose,
+            metrics=self._metrics_name(), )
+
+        loader = eval_loader
+        if not isinstance(eval_loader, Iterable):
+            loader = eval_loader()
+
+        eval_steps = len(loader) if hasattr(loader, '__len__') else None
+        cbks.on_begin('eval',
+                      {'steps': eval_steps,
+                       'metrics_name': metrics_name})
+
+        logs = self._run_one_epoch(loader, cbks, 'eval', metrics_name)
+
+        cbks.on_end('eval', logs)
+
+        self._test_dataloader = None
+
+        eval_result = {}
+        for k in self._metrics_name():
+            eval_result[k] = logs[k]
+
+        return eval_result
+
+    def predict(self, test_data, batch_size=1, num_workers=0):
+        """
+        FIXME: add more comments and usage
+        Args:
+            test_data (Dataset|DataLoader): An iterable data loader is used for
+                prediction. An instance of paddle.fluid.io.Dataset or paddle.fluid.io.Dataloader
+                is recommended.
+            batch_size (int): Integer number. The batch size of test_data.
+                When test_data is an instance of Dataloader, this
+                parameter will be ignored.
+            num_workers (int): The number of subprocesses to load data, 0 for no subprocess
+                used and loading data in main process. When test_data is an
+                instance of Dataloader, this parameter will be ignored.
+        """
+
+        if fluid.in_dygraph_mode():
+            feed_list = None
+        else:
+            feed_list = [x.forward() for x in self._inputs + self._labels]
+
+        if test_data is not None and isinstance(test_data, Dataset):
+            test_sampler = DistributedBatchSampler(
+                test_data, batch_size=batch_size)
+            test_loader = DataLoader(
+                test_data,
+                batch_sampler=test_sampler,
+                places=self._place,
+                feed_list=feed_list,
+                num_workers=num_workers,
+                return_list=True)
+        else:
+            test_loader = test_data
+
+        self._test_dataloader = test_loader
+
+        loader = test_loader
+        if not isinstance(test_loader, Iterable):
+            loader = test_loader()
+
+        outputs = None
+        for data in tqdm.tqdm(loader):
+            if not fluid.in_dygraph_mode():
+                data = data[0]
+
+            outs = self.test(*data)
+
+            if outputs is None:
+                outputs = outs
+            else:
+                outputs = [
+                    np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
+                ]
+
+        self._test_dataloader = None
+        if test_loader is not None and self._adapter._nranks > 1 \
+                and isinstance(test_loader, DataLoader):
+            outputs = [o[:len(test_loader.dataset)] for o in outputs]
+        return outputs
+
+    def set_eval_data(self, eval_data):
+        """
+        Args:
+            eval_data (Dataset|DataLoader|None): An iterable data loader is used for
+                evaluation. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.Dataloader is recommended.
+        """
+        assert isinstance(
+            eval_data,
+            DataLoader), "eval_data must be an instance of Dataloader!"
+        self._test_dataloader = eval_data
+
+    def _run_one_epoch(self,
+                       data_loader,
+                       callbacks,
+                       mode,
+                       metrics_name,
+                       epoch=None):
+        size = len(data_loader) if hasattr(data_loader, '__len__') else None
+        logs = {
+            'steps': size,
+            'metrics_name': metrics_name,
+        }
+
+        if mode == 'train':
+            assert epoch is not None, 'when mode is train, epoch must be given'
+            callbacks.on_epoch_begin(epoch)
+
+        for step, data in enumerate(data_loader):
+            if not fluid.in_dygraph_mode():
+                data = data[0]
+                batch_size = data[0].shape()[0]
+            else:
+                batch_size = data[0].shape[0]
+
+            callbacks.on_batch_begin(mode, step, logs)
+            if mode == 'train':
+                outs = self.train(*data)
+            else:
+                outs = self.eval(*data)
+
+            # losses
+            loss = outs[0] if self._metrics else outs
+            metrics = [[l[0] for l in loss]]
+
+            # metrics
+            for metric in self._metrics:
+                res = metric.accumulate()
+                metrics.extend(to_list(res))
+
+            assert len(metrics_name) == len(metrics)
+            for k, v in zip(metrics_name, metrics):
+                logs[k] = v
+
+            logs['step'] = step
+            if mode == 'train' or self._adapter._merge_count.get(
+                    mode + '_batch', 0) <= 0:
+                logs['batch_size'] = batch_size * ParallelEnv().nranks
+            else:
+                logs['batch_size'] = self._adapter._merge_count[mode +
+                                                                '_batch']
+
+            callbacks.on_batch_end(mode, step, logs)
+        self._reset_metrics()
+
+        if mode == 'train':
+            assert epoch is not None, 'when mode is train, epoch must be given'
+            callbacks.on_epoch_end(epoch)
+
+        return logs
 
     def _reset_metrics(self):
         for metric in self._metrics:
diff --git a/tests/test_model.py b/tests/test_model.py
index 87c9e5731f4f73d93b0a7dda70b771f76fdabae3..9e8c880e461c684bc46e392c362ace3d00e67f53 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -139,6 +139,26 @@ class MyCrossEntropy(Loss):
         return [loss1, loss2]
 
 
+class TestMnistDataset(MnistDataset):
+    def __init__(self):
+        super(TestMnistDataset, self).__init__(mode='test')
+
+    def __getitem__(self, idx):
+        return self.images[idx],
+
+    def __len__(self):
+        return len(self.images)
+
+
+def get_predict_accuracy(pred, gt):
+    pred = np.argmax(pred, -1)
+    gt = np.array(gt)
+
+    correct = pred[:, np.newaxis] == gt
+
+    return np.sum(correct) / correct.shape[0]
+
+
 class TestModel(unittest.TestCase):
     def fit(self, dynamic, is_mlp=False):
         device = set_device('gpu')
@@ -152,6 +172,7 @@ class TestModel(unittest.TestCase):
 
         train_dataset = MnistDataset(mode='train')
         val_dataset = MnistDataset(mode='test')
+        test_dataset = TestMnistDataset()
 
         model = MNIST() if not is_mlp else MLP()
         optim = fluid.optimizer.Momentum(
@@ -159,12 +180,23 @@ class TestModel(unittest.TestCase):
         loss = CrossEntropy() if not is_mlp else MyCrossEntropy()
         model.prepare(optim, loss, Accuracy(), inputs, labels, device=device)
         cbk = ProgBarLogger(50)
+
         model.fit(train_dataset,
                   val_dataset,
                   epochs=2,
                   batch_size=batch_size,
                   callbacks=cbk)
+
+        eval_result = model.evaluate(val_dataset, batch_size=batch_size)
+
+        output = model.predict(test_dataset, batch_size=batch_size)
+
+        np.testing.assert_equal(output[0].shape[0], len(test_dataset))
+
+        acc = get_predict_accuracy(output[0], val_dataset.labels)
+
+        np.testing.assert_allclose(acc, eval_result['acc'])
+
     def test_fit_static(self):
         self.fit(False)
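
Usage note (illustrative sketch, not part of the patch): the snippet below mirrors tests/test_model.py above to show how the new Model.evaluate and Model.predict entry points are exercised together with fit. The helper names (set_device, Input, CrossEntropy, Accuracy, MNIST, MnistDataset, TestMnistDataset) are assumed to be importable from this repository as the test does, and the input shapes and optimizer hyper-parameters are assumptions, not values defined by this change.

    # Hedged sketch; helpers come from this repo's modules as in tests/test_model.py.
    import numpy as np
    from paddle import fluid

    device = set_device('gpu')
    fluid.enable_dygraph(device)  # dynamic-mode path, as in the test's dynamic branch

    # Input shapes below are assumptions for a flattened MNIST image / integer label.
    inputs = [Input([None, 784], 'float32', name='image')]
    labels = [Input([None, 1], 'int64', name='label')]

    model = MNIST()
    optim = fluid.optimizer.Momentum(
        learning_rate=0.01, momentum=0.9,  # illustrative hyper-parameters
        parameter_list=model.parameters())
    model.prepare(optim, CrossEntropy(), Accuracy(), inputs, labels, device=device)

    model.fit(MnistDataset(mode='train'), MnistDataset(mode='test'),
              epochs=2, batch_size=64)

    # New in this change: standalone evaluation and prediction loops.
    eval_result = model.evaluate(MnistDataset(mode='test'), batch_size=64)
    output = model.predict(TestMnistDataset(), batch_size=64)

    # evaluate() returns a dict keyed by metric names; predict() returns one
    # stacked numpy array per model output, trimmed to the dataset length.
    print(eval_result['acc'], np.argmax(output[0], -1)[:10])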