提交 ba723731 编写于 作者: L LielinJiang

add init context

上级 47cd178a
......@@ -26,7 +26,7 @@ from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.io import MNIST as MnistDataset
from model import Model, CrossEntropy, Input
from model import Model, CrossEntropy, Input, init_context
from metrics import Accuracy
......@@ -106,10 +106,8 @@ class MNIST(Model):
def main():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
fluid.enable_dygraph(place) if FLAGS.dynamic else None
init_context('dynamic' if FLAGS.dynamic else 'static')
train_dataset = MnistDataset(mode='train')
val_dataset = MnistDataset(mode='test')
......
......@@ -85,6 +85,18 @@ def extract_args(func):
return inspect.getargspec(func)[0]
def init_context(backend):
assert isinstance(backend, str) and backend.lower() in ['dynamic', 'static'], \
"Expected backend in ['dynamic', 'static'], but got {}".format(backend)
place = fluid.CUDAPlace(distributed.Env().dev_id) if \
distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
distributed.prepare_distributed_context()
backend = backend.lower()
if backend == 'dynamic':
fluid.enable_dygraph(place)
class Input(fluid.dygraph.Layer):
def __init__(self, shape=None, dtype=None, name=None):
super(Input, self).__init__()
......@@ -357,7 +369,7 @@ class StaticGraphAdapter(object):
# TODO: fixme if have better way to get batch size
samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size:
if current_count + samples >= total_size:
state = [s[:total_size - current_count, ...] for s in state]
self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count
......@@ -423,12 +435,11 @@ class StaticGraphAdapter(object):
strategy=dist_strategy)
self.model._optimizer.minimize(self._loss_endpoint)
if self._nranks > 1 and mode != 'train' and self.model._test_dataloader is not None \
and isinstance(self.model._test_dataloader, DataLoader):
if self._nranks > 1 and mode != 'train':
outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
if mode != 'test':
labels = [distributed._all_gather(l, self._nranks) for l in labels]
if mode != 'test':
for metric in self.model._metrics:
metrics.append(to_list(metric.add_metric_op(outputs, labels)))
......@@ -566,11 +577,12 @@ class DynamicGraphAdapter(object):
metrics = []
for metric in self.model._metrics:
# cut off padding value.
if self.model._test_dataloader is not None and self._nranks > 1:
if self.model._test_dataloader is not None and self._nranks > 1 \
and isinstance(self.model._test_dataloader, DataLoader):
total_size = len(self.model._test_dataloader.dataset)
samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples > total_size:
if current_count + samples >= total_size:
outputs = [o[:total_size - metric.count[0]] for o in outputs]
labels = [l[:total_size - metric.count[0]] for l in labels]
self._merge_count[self.mode + '_total'] = 0
......@@ -685,11 +697,9 @@ class Model(fluid.dygraph.Layer):
# init multiple gpus context
self._place = fluid.CUDAPlace(distributed.Env().dev_id) \
if distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
if distributed.get_nranks() > 1:
distributed.prepare_distributed_context(self._place)
# init backend
if fluid.in_dygraph_mode():
if fluid.in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self)
else:
self._adapter = StaticGraphAdapter(self)
......@@ -974,7 +984,7 @@ class Model(fluid.dygraph.Layer):
logs[k] = v
logs['step'] = step
if mode == 'train' or self._adapter._merge_count[mode + '_batch'] <= 0:
if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * distributed.Env().nranks
else:
logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
......
......@@ -28,7 +28,7 @@ import contextlib
import paddle
from paddle import fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input, Loss
from model import Model, CrossEntropy, Input, Loss, init_context
from metrics import Accuracy
from callbacks import ProgBarLogger
from paddle.fluid.io import BatchSampler, DataLoader
......@@ -110,11 +110,6 @@ class MNIST(Model):
return x
@contextlib.contextmanager
def null_guard():
yield
class MLP(Model):
def __init__(self):
super(MLP, self).__init__()
......@@ -146,12 +141,10 @@ class MyCrossEntropy(Loss):
class TestModel(unittest.TestCase):
def fit(self, dynamic, is_mlp=False):
init_context('dynamic' if FLAGS.dynamic else 'static')
im_shape = (-1, 784)
batch_size = 128
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if fluid.dygraph.parallel.Env().nranks > 1 else fluid.CUDAPlace(0)
fluid.enable_dygraph(place) if dynamic else None
inputs = [Input(im_shape, 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册