diff --git a/mnist.py b/mnist.py
index 7beac484dc1ff8bbc015d5ac56a8199e15e25ec5..f8e1883844108f03d2c360f5466012d44ffd4980 100644
--- a/mnist.py
+++ b/mnist.py
@@ -26,7 +26,7 @@ from paddle.fluid.optimizer import Momentum
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
 from paddle.fluid.io import MNIST as MnistDataset
 
-from model import Model, CrossEntropy, Input
+from model import Model, CrossEntropy, Input, init_context
 from metrics import Accuracy
 
 
@@ -106,10 +106,8 @@ class MNIST(Model):
 
 
 def main():
-    place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-        if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
-    fluid.enable_dygraph(place) if FLAGS.dynamic else None
-
+    init_context('dynamic' if FLAGS.dynamic else 'static')
+
     train_dataset = MnistDataset(mode='train')
     val_dataset = MnistDataset(mode='test')
 
diff --git a/model.py b/model.py
index c38a6e7bcd368a6aebac6220df9814a0b2152cb7..8262a33c83a75619606219542fdd4e026f8858c8 100644
--- a/model.py
+++ b/model.py
@@ -85,6 +85,18 @@ def extract_args(func):
     return inspect.getargspec(func)[0]
 
 
+def init_context(backend):
+    assert isinstance(backend, str) and backend.lower() in ['dynamic', 'static'], \
+        "Expected backend in ['dynamic', 'static'], but got {}".format(backend)
+
+    place = fluid.CUDAPlace(distributed.Env().dev_id) if \
+        distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
+    distributed.prepare_distributed_context()
+    backend = backend.lower()
+    if backend == 'dynamic':
+        fluid.enable_dygraph(place)
+
+
 class Input(fluid.dygraph.Layer):
     def __init__(self, shape=None, dtype=None, name=None):
         super(Input, self).__init__()
@@ -357,7 +369,7 @@ class StaticGraphAdapter(object):
                 # TODO: fixme if have better way to get batch size
                 samples = state[0].shape[0]
                 current_count = self._merge_count.get(self.mode + '_total', 0)
-                if current_count + samples > total_size:
+                if current_count + samples >= total_size:
                     state = [s[:total_size - current_count, ...] for s in state]
                     self._merge_count[self.mode + '_total'] = 0
                     self._merge_count[self.mode + '_batch'] = total_size - current_count
@@ -423,12 +435,11 @@ class StaticGraphAdapter(object):
                 strategy=dist_strategy)
             self.model._optimizer.minimize(self._loss_endpoint)
 
-        if self._nranks > 1 and mode != 'train' and self.model._test_dataloader is not None \
-                and isinstance(self.model._test_dataloader, DataLoader):
+        if self._nranks > 1 and mode != 'train':
             outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
             if mode != 'test':
                 labels = [distributed._all_gather(l, self._nranks) for l in labels]
-
+
         if mode != 'test':
             for metric in self.model._metrics:
                 metrics.append(to_list(metric.add_metric_op(outputs, labels)))
@@ -566,11 +577,12 @@ class DynamicGraphAdapter(object):
         metrics = []
         for metric in self.model._metrics:
             # cut off padding value.
-            if self.model._test_dataloader is not None and self._nranks > 1:
+            if self.model._test_dataloader is not None and self._nranks > 1 \
+                    and isinstance(self.model._test_dataloader, DataLoader):
                 total_size = len(self.model._test_dataloader.dataset)
                 samples = outputs[0].shape[0]
                 current_count = self._merge_count.get(self.mode + '_total', 0)
-                if current_count + samples > total_size:
+                if current_count + samples >= total_size:
                     outputs = [o[:total_size - metric.count[0]] for o in outputs]
                     labels = [l[:total_size - metric.count[0]] for l in labels]
                     self._merge_count[self.mode + '_total'] = 0
@@ -685,11 +697,9 @@ class Model(fluid.dygraph.Layer):
         # init multiple gpus context
         self._place = fluid.CUDAPlace(distributed.Env().dev_id) \
             if distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
-        if distributed.get_nranks() > 1:
-            distributed.prepare_distributed_context(self._place)
 
         # init backend
-        if fluid.in_dygraph_mode(): 
+        if fluid.in_dygraph_mode():
             self._adapter = DynamicGraphAdapter(self)
         else:
             self._adapter = StaticGraphAdapter(self)
@@ -974,7 +984,7 @@ class Model(fluid.dygraph.Layer):
                 logs[k] = v
         logs['step'] = step
 
-        if mode == 'train' or self._adapter._merge_count[mode + '_batch'] <= 0:
+        if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
             logs['batch_size'] = batch_size * distributed.Env().nranks
         else:
             logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
diff --git a/tests/test_model.py b/tests/test_model.py
index f3829becb40e34775ea02d4347a95ecf34d87cc1..e41e496b92d823e585c2bfaf9c9b163fa2ade74e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -28,7 +28,7 @@ import contextlib
 import paddle
 from paddle import fluid
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
-from model import Model, CrossEntropy, Input, Loss
+from model import Model, CrossEntropy, Input, Loss, init_context
 from metrics import Accuracy
 from callbacks import ProgBarLogger
 from paddle.fluid.io import BatchSampler, DataLoader
@@ -110,11 +110,6 @@ class MNIST(Model):
         return x
 
 
-@contextlib.contextmanager
-def null_guard():
-    yield
-
-
 class MLP(Model):
     def __init__(self):
         super(MLP, self).__init__()
@@ -146,12 +141,10 @@ class MyCrossEntropy(Loss):
 
 class TestModel(unittest.TestCase):
     def fit(self, dynamic, is_mlp=False):
+        init_context('dynamic' if dynamic else 'static')
+
         im_shape = (-1, 784)
         batch_size = 128
-
-        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
-            if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
-        fluid.enable_dygraph(place) if dynamic else None
 
         inputs = [Input(im_shape, 'float32', name='image')]
         labels = [Input([None, 1], 'int64', name='label')]
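
For reference, a minimal usage sketch of the init_context entry point added in model.py (an illustration, not part of the patch; it assumes the patched model.py is importable and, like the patch itself, that a CUDA device is present, since init_context always selects a fluid.CUDAPlace):

    from model import Model, CrossEntropy, Input, init_context

    # init_context must run before a Model is constructed: it optionally
    # enables dygraph, so fluid.in_dygraph_mode() steers Model.__init__
    # toward the DynamicGraphAdapter or the StaticGraphAdapter.
    init_context('dynamic')  # or 'static'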