Commit 1b798365 authored by LielinJiang

refine fit

Parent 59be4ec2
@@ -27,10 +27,10 @@ from paddle.fluid.framework import in_dygraph_mode, Variable
 from paddle.fluid.executor import global_scope
 from paddle.fluid.io import is_belong_to_optimizer
 from paddle.fluid.dygraph.base import to_variable
-from paddle.fluid.dygraph.parallel import Env
+from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
 from paddle.fluid.incubate.fleet.base import role_maker
-from paddle.fluid.io import DataLoader
+from paddle.fluid.io import DataLoader, Dataset
 from distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
 from metrics import Metric
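For readers following the rename: `ParallelEnv` is the new public name for the old `Env` helper, and every call site below is updated mechanically. A minimal sketch of what it exposes, assuming a `paddle.distributed.launch`-style environment (fluid 1.x API, as used in this file); run standalone it falls back to single-process defaults:

```python
from paddle.fluid.dygraph.parallel import ParallelEnv

# ParallelEnv reads the launcher-provided environment variables and
# exposes the per-process distributed context used throughout this file.
env = ParallelEnv()
print(env.nranks)             # total number of trainer processes
print(env.local_rank)         # rank of the current process
print(env.dev_id)             # GPU device id assigned to this process
print(env.trainer_endpoints)  # endpoints of all trainers
print(env.current_endpoint)   # endpoint of this trainer
```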
@@ -147,8 +147,8 @@ class StaticGraphAdapter(object):
             'test_batch': 0
         }
-        self._nranks = Env().nranks
-        self._local_rank = Env().local_rank
+        self._nranks = ParallelEnv().nranks
+        self._local_rank = ParallelEnv().local_rank

     @property
     def mode(self):
@@ -469,7 +469,7 @@ class StaticGraphAdapter(object):
         # therefore startup program only needs to run once
         if self._executor is None:
             if self._nranks > 1 and device.lower() == 'gpu':
-                gpu_id = int(Env().dev_id)
+                gpu_id = int(ParallelEnv().dev_id)
                 place = fluid.CUDAPlace(
                     gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace()
             else:
@@ -506,8 +506,8 @@ class DynamicGraphAdapter(object):
     def __init__(self, model):
         super(DynamicGraphAdapter, self).__init__()
         self.model = model
-        self._nranks = Env().nranks
-        self._local_rank = Env().local_rank
+        self._nranks = ParallelEnv().nranks
+        self._local_rank = ParallelEnv().local_rank
         self._merge_count = {
             'eval_total': 0,
             'test_total': 0,
@@ -517,10 +517,10 @@ class DynamicGraphAdapter(object):
         if self._nranks > 1:
             stradegy = fluid.dygraph.parallel.ParallelStrategy()
-            stradegy.nranks = Env().nranks
-            stradegy.local_rank = Env().local_rank
-            stradegy.trainer_endpoints = Env().trainer_endpoints
-            stradegy.current_endpoint = Env().current_endpoint
+            stradegy.nranks = ParallelEnv().nranks
+            stradegy.local_rank = ParallelEnv().local_rank
+            stradegy.trainer_endpoints = ParallelEnv().trainer_endpoints
+            stradegy.current_endpoint = ParallelEnv().current_endpoint
             self.ddp_model = fluid.dygraph.parallel.DataParallel(
                 self.model, stradegy)
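For context, the `ddp_model` built here is what the adapter's train step drives. A hedged sketch of the usual fluid 1.x dygraph data-parallel step; `inputs`, `labels`, `loss_fn`, and `optimizer` are placeholders, not part of this commit:

```python
# Hypothetical single training step with a DataParallel-wrapped model
# (fluid 1.x dygraph API).
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)       # placeholder loss computation
loss = ddp_model.scale_loss(loss)     # scale loss by 1/nranks
loss.backward()
ddp_model.apply_collective_grads()    # all-reduce gradients across ranks
optimizer.minimize(loss)
ddp_model.clear_gradients()
```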
@@ -703,11 +703,11 @@ class Model(fluid.dygraph.Layer):
         self._test_dataloader = None

         # init multiple gpus context
-        self._place = fluid.CUDAPlace(Env().dev_id) \
-            if Env().nranks > 1 else fluid.CUDAPlace(0)
+        self._place = fluid.CUDAPlace(ParallelEnv().dev_id) \
+            if ParallelEnv().nranks > 1 else fluid.CUDAPlace(0)

         global _parallel_context_initialized
-        if Env().nranks > 1 and not _parallel_context_initialized:
+        if ParallelEnv().nranks > 1 and not _parallel_context_initialized:
             if fluid.in_dygraph_mode():
                 fluid.disable_dygraph()
                 fluid.enable_dygraph(self._place)
@@ -733,7 +733,7 @@ class Model(fluid.dygraph.Layer):
         return self._adapter.test(*args, **kwargs)

     def save(self, *args, **kwargs):
-        if Env().local_rank == 0:
+        if ParallelEnv().local_rank == 0:
             return self._adapter.save(*args, **kwargs)

     def load(self, path, skip_mismatch=False, reset_optimizer=False):
@@ -880,10 +880,8 @@ class Model(fluid.dygraph.Layer):
     def fit(
             self,
-            train_dataset=None,
-            eval_dataset=None,
-            train_loader=None,
-            eval_loader=None,
+            train_data=None,
+            eval_data=None,
             batch_size=1,
             epochs=1,
             eval_freq=1,
@@ -898,11 +896,16 @@ class Model(fluid.dygraph.Layer):
         """
         FIXME: add more comments and usage

         Args:
-            train_dataset (Dataset): An instance of paddle.fluid.io.Dataset.
-            eval_dataset (Dataset): An instance of paddle.fluid.io.Dataset.
-            train_loader (DataLoader): An iterable data loader is used for train.
-            eval_loader (DataLoader): An iterable data loader is used for
-                evaluation at the end of epoch. If None, will not do evaluation.
+            train_data (Dataset|DataLoader): An iterable data loader used for
+                training. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.DataLoader is recommended.
+            eval_data (Dataset|DataLoader): An iterable data loader used for
+                evaluation at the end of each epoch. If None, no evaluation is
+                performed. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.DataLoader is recommended.
+            batch_size (int): The batch size of train_data and eval_data. When
+                train_data and eval_data are both instances of DataLoader, this
+                parameter is ignored.
             epochs (int): Integer number. The number of epochs to train the model.
             eval_freq (int): The frequency, in number of epochs, an evaluation
                 is performed.
@@ -913,47 +916,57 @@ class Model(fluid.dygraph.Layer):
             save_freq (int): The frequency, in number of epochs, to save checkpoint.
             verbose (int): The verbosity mode, should be 0, 1, or 2.
                 0 = silent, 1 = progress bar, 2 = one line per epoch.
+            drop_last (bool): Whether to drop the last incomplete batch of
+                train_data when the dataset size is not divisible by the batch
+                size. When train_data is an instance of DataLoader, this
+                parameter is ignored.
+            shuffle (bool): Whether to shuffle train_data. When train_data is
+                an instance of DataLoader, this parameter is ignored.
+            num_workers (int): The number of subprocesses used to load data, 0
+                for no subprocess and loading data in the main process. When
+                train_data and eval_data are both instances of DataLoader, this
+                parameter is ignored.
             callbacks (Callback|None): A list of `Callback` instances to apply
                 during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                 are automatically inserted.
         """
-        assert train_dataset is not None or train_loader is not None, \
-            "train_dataset or train_loader must be given"
-        assert (train_loader is not None and train_dataset is None) or \
-            (train_loader is None and train_dataset is not None), \
-            "train_dataset should not be set when train_loader is given"
+        assert train_data is not None, \
+            "train_data must be given!"

         if fluid.in_dygraph_mode():
             feed_list = None
         else:
             feed_list = [x.forward() for x in self._inputs + self._labels]

-        if train_loader is None:
+        if isinstance(train_data, Dataset):
             train_sampler = DistributedBatchSampler(
-                train_dataset,
+                train_data,
                 batch_size=batch_size,
                 shuffle=shuffle,
                 drop_last=drop_last)
             train_loader = DataLoader(
-                train_dataset,
+                train_data,
                 batch_sampler=train_sampler,
                 places=self._place,
                 feed_list=feed_list,
                 num_workers=num_workers,
                 return_list=True)
+        else:
+            train_loader = train_data

-        if eval_loader is None and eval_dataset is not None:
+        if eval_data is not None and isinstance(eval_data, Dataset):
             eval_sampler = DistributedBatchSampler(
-                eval_dataset, batch_size=batch_size)
+                eval_data, batch_size=batch_size)
             eval_loader = DataLoader(
-                eval_dataset,
+                eval_data,
                 batch_sampler=eval_sampler,
                 places=self._place,
                 feed_list=feed_list,
                 num_workers=num_workers,
                 return_list=True)
+        elif eval_data is not None:
+            eval_loader = eval_data
+        else:
+            eval_loader = None

         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader
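Putting the refined signature together: `fit` now accepts either a `Dataset` (and builds the distributed loader itself) or a prebuilt `DataLoader` (in which case `batch_size`, `shuffle`, `drop_last`, and `num_workers` are ignored). A hypothetical usage sketch; `model`, `train_dataset`, and `eval_dataset` are placeholders standing in for a hapi Model and paddle.fluid.io.Dataset instances built elsewhere:

```python
import paddle.fluid as fluid
from paddle.fluid.io import DataLoader
from distributed import DistributedBatchSampler

# 1) Pass Datasets: fit() wraps them in a DistributedBatchSampler + DataLoader.
model.fit(train_data=train_dataset,
          eval_data=eval_dataset,
          batch_size=64,
          epochs=2,
          shuffle=True,
          num_workers=2)

# 2) Pass a prebuilt DataLoader, mirroring what fit() does internally;
#    batch_size/shuffle/drop_last/num_workers given to fit() are then ignored.
sampler = DistributedBatchSampler(train_dataset, batch_size=64, shuffle=True)
loader = DataLoader(train_dataset,
                    batch_sampler=sampler,
                    places=fluid.CUDAPlace(0),  # single-GPU example
                    return_list=True)
model.fit(train_data=loader, epochs=2)
```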
@@ -1005,7 +1018,7 @@ class Model(fluid.dygraph.Layer):
             logs['step'] = step
             if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
-                logs['batch_size'] = batch_size * Env().nranks
+                logs['batch_size'] = batch_size * ParallelEnv().nranks
             else:
                 logs['batch_size'] = self._adapter._merge_count[mode +
                                                                 '_batch']
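One detail worth noting in this last hunk: the logged `batch_size` is the global batch across all ranks, not the per-rank size. A trivial worked example:

```python
# With a per-rank batch size of 32 on 4 trainers (ParallelEnv().nranks == 4),
# the progress bar reports the global batch size:
batch_size = 32
nranks = 4
print(batch_size * nranks)  # -> 128, the logged 'batch_size'
```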