提交 ba723731 编写于 作者: L LielinJiang

add init context

上级 47cd178a
...@@ -26,7 +26,7 @@ from paddle.fluid.optimizer import Momentum ...@@ -26,7 +26,7 @@ from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.io import MNIST as MnistDataset from paddle.fluid.io import MNIST as MnistDataset
from model import Model, CrossEntropy, Input from model import Model, CrossEntropy, Input, init_context
from metrics import Accuracy from metrics import Accuracy
...@@ -106,10 +106,8 @@ class MNIST(Model): ...@@ -106,10 +106,8 @@ class MNIST(Model):
def main(): def main():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \ init_context('dynamic' if FLAGS.dynamic else 'static')
if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
train_dataset = MnistDataset(mode='train') train_dataset = MnistDataset(mode='train')
val_dataset = MnistDataset(mode='test') val_dataset = MnistDataset(mode='test')
......
...@@ -85,6 +85,18 @@ def extract_args(func): ...@@ -85,6 +85,18 @@ def extract_args(func):
return inspect.getargspec(func)[0] return inspect.getargspec(func)[0]
def init_context(backend):
    """Initialize the execution context for training/inference.

    Picks the CUDA place for this process (the per-rank device id when
    running distributed with more than one rank, otherwise GPU 0),
    prepares the distributed context, and enables dygraph mode when the
    'dynamic' backend is requested.

    Args:
        backend (str): Execution backend, either 'dynamic' or 'static'
            (case-insensitive).

    Raises:
        AssertionError: If ``backend`` is not a string or is not one of
            'dynamic' / 'static'.
    """
    assert isinstance(backend, str) and backend.lower() in ['dynamic', 'static'], \
        "Expected backend in ['dynamic', 'static'], but got {}".format(backend)
    # Normalize once; the original recomputed .lower() a second time below.
    backend = backend.lower()

    # Construct Env() a single time instead of twice in the conditional.
    # NOTE(review): assumes CUDA is available — there is no CPU fallback here.
    env = distributed.Env()
    place = fluid.CUDAPlace(env.dev_id) if env.nranks > 1 else fluid.CUDAPlace(0)

    distributed.prepare_distributed_context()
    if backend == 'dynamic':
        fluid.enable_dygraph(place)
class Input(fluid.dygraph.Layer): class Input(fluid.dygraph.Layer):
def __init__(self, shape=None, dtype=None, name=None): def __init__(self, shape=None, dtype=None, name=None):
super(Input, self).__init__() super(Input, self).__init__()
...@@ -357,7 +369,7 @@ class StaticGraphAdapter(object): ...@@ -357,7 +369,7 @@ class StaticGraphAdapter(object):
# TODO: fixme if have better way to get batch size # TODO: fixme if have better way to get batch size
samples = state[0].shape[0] samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size: if current_count + samples >= total_size:
state = [s[:total_size - current_count, ...] for s in state] state = [s[:total_size - current_count, ...] for s in state]
self._merge_count[self.mode + '_total'] = 0 self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count self._merge_count[self.mode + '_batch'] = total_size - current_count
...@@ -423,12 +435,11 @@ class StaticGraphAdapter(object): ...@@ -423,12 +435,11 @@ class StaticGraphAdapter(object):
strategy=dist_strategy) strategy=dist_strategy)
self.model._optimizer.minimize(self._loss_endpoint) self.model._optimizer.minimize(self._loss_endpoint)
if self._nranks > 1 and mode != 'train' and self.model._test_dataloader is not None \ if self._nranks > 1 and mode != 'train':
and isinstance(self.model._test_dataloader, DataLoader):
outputs = [distributed._all_gather(o, self._nranks) for o in outputs] outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
if mode != 'test': if mode != 'test':
labels = [distributed._all_gather(l, self._nranks) for l in labels] labels = [distributed._all_gather(l, self._nranks) for l in labels]
if mode != 'test': if mode != 'test':
for metric in self.model._metrics: for metric in self.model._metrics:
metrics.append(to_list(metric.add_metric_op(outputs, labels))) metrics.append(to_list(metric.add_metric_op(outputs, labels)))
...@@ -566,11 +577,12 @@ class DynamicGraphAdapter(object): ...@@ -566,11 +577,12 @@ class DynamicGraphAdapter(object):
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
# cut off padding value. # cut off padding value.
if self.model._test_dataloader is not None and self._nranks > 1: if self.model._test_dataloader is not None and self._nranks > 1 \
and isinstance(self.model._test_dataloader, DataLoader):
total_size = len(self.model._test_dataloader.dataset) total_size = len(self.model._test_dataloader.dataset)
samples = outputs[0].shape[0] samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size: if current_count + samples >= total_size:
outputs = [o[:total_size - metric.count[0]] for o in outputs] outputs = [o[:total_size - metric.count[0]] for o in outputs]
labels = [l[:total_size - metric.count[0]] for l in labels] labels = [l[:total_size - metric.count[0]] for l in labels]
self._merge_count[self.mode + '_total'] = 0 self._merge_count[self.mode + '_total'] = 0
...@@ -685,11 +697,9 @@ class Model(fluid.dygraph.Layer): ...@@ -685,11 +697,9 @@ class Model(fluid.dygraph.Layer):
# init multiple gpus context # init multiple gpus context
self._place = fluid.CUDAPlace(distributed.Env().dev_id) \ self._place = fluid.CUDAPlace(distributed.Env().dev_id) \
if distributed.Env().nranks > 1 else fluid.CUDAPlace(0) if distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
if distributed.get_nranks() > 1:
distributed.prepare_distributed_context(self._place)
# init backend # init backend
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self) self._adapter = DynamicGraphAdapter(self)
else: else:
self._adapter = StaticGraphAdapter(self) self._adapter = StaticGraphAdapter(self)
...@@ -974,7 +984,7 @@ class Model(fluid.dygraph.Layer): ...@@ -974,7 +984,7 @@ class Model(fluid.dygraph.Layer):
logs[k] = v logs[k] = v
logs['step'] = step logs['step'] = step
if mode == 'train' or self._adapter._merge_count[mode + '_batch'] <= 0: if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * distributed.Env().nranks logs['batch_size'] = batch_size * distributed.Env().nranks
else: else:
logs['batch_size'] = self._adapter._merge_count[mode + '_batch'] logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
......
...@@ -28,7 +28,7 @@ import contextlib ...@@ -28,7 +28,7 @@ import contextlib
import paddle import paddle
from paddle import fluid from paddle import fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input, Loss from model import Model, CrossEntropy, Input, Loss, init_context
from metrics import Accuracy from metrics import Accuracy
from callbacks import ProgBarLogger from callbacks import ProgBarLogger
from paddle.fluid.io import BatchSampler, DataLoader from paddle.fluid.io import BatchSampler, DataLoader
...@@ -110,11 +110,6 @@ class MNIST(Model): ...@@ -110,11 +110,6 @@ class MNIST(Model):
return x return x
@contextlib.contextmanager
def null_guard():
yield
class MLP(Model): class MLP(Model):
def __init__(self): def __init__(self):
super(MLP, self).__init__() super(MLP, self).__init__()
...@@ -146,12 +141,10 @@ class MyCrossEntropy(Loss): ...@@ -146,12 +141,10 @@ class MyCrossEntropy(Loss):
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
def fit(self, dynamic, is_mlp=False): def fit(self, dynamic, is_mlp=False):
init_context('dynamic' if FLAGS.dynamic else 'static')
im_shape = (-1, 784) im_shape = (-1, 784)
batch_size = 128 batch_size = 128
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
fluid.enable_dygraph(place) if dynamic else None
inputs = [Input(im_shape, 'float32', name='image')] inputs = [Input(im_shape, 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')] labels = [Input([None, 1], 'int64', name='label')]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册