diff --git a/callbacks.py b/callbacks.py
index 2e4d37d34e61f2f10a856bbead243f069eb6ed92..a953fcfea94c1de5d7905c401f8ef3da87da84df 100644
--- a/callbacks.py
+++ b/callbacks.py
@@ -211,7 +211,7 @@ class ProgBarLogger(Callback):
         logs = logs or {}
         self.train_step = step

-        if self.train_step % self.log_freq == 0 and self.verbose:
+        if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank() == 0:
             # if steps is not None, last step will update in on_epoch_end
             if self.steps and self.train_step < self.steps:
                 self._updates(logs, 'train')
diff --git a/distributed.py b/distributed.py
index c9b60e3d0cb3080f4fe691b6d88cba124134e8e7..e8a424584e4ce45d5dbe448dcb2f03eb2eaaa74e 100644
--- a/distributed.py
+++ b/distributed.py
@@ -32,7 +32,8 @@ from paddle.fluid.framework import Variable
 from paddle.fluid.executor import global_scope
 from paddle.fluid.dygraph.parallel import Env, DataParallel, ParallelStrategy
-from paddle.fluid.layers.collective import _c_allreduce, _c_allgather, _c_broadcast, _c_sync_comm_stream, _c_sync_calc_stream
+from paddle.fluid.layers.collective import _c_allreduce, _c_allgather, _c_broadcast, \
+    _c_sync_comm_stream, _c_sync_calc_stream
 from paddle.fluid.io import BatchSampler, DataLoader
@@ -52,7 +53,7 @@ class DistributedBatchSampler(BatchSampler):
             `__len__` for BatchSampler to get sample number of data source.
         batch_size(int): sample indice number in a mini-batch indices.
-        shuffle(bool): whther to shuffle indices order before genrate
+        shuffle(bool): whether to shuffle indices order before generating
             batch indices. Default False.
         drop_last(bool): whether drop the last incomplete batch dataset size
             is not divisible by the batch size. Default False
@@ -88,7 +89,8 @@
             np.random.RandomState(self.epoch).shuffle(indices)
             self.epoch += 1
         # subsample
-        indices = indices[self.local_rank * self.num_samples: (self.local_rank + 1) * self.num_samples]
+        indices = indices[self.local_rank * self.num_samples:
+                          (self.local_rank + 1) * self.num_samples]
         assert len(indices) == self.num_samples
         _sample_iter = iter(indices)
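Not part of the patch — editor's note: the DistributedBatchSampler docstring above describes a per-rank batch sampler; below is a minimal usage sketch assuming the constructor signature DistributedBatchSampler(dataset, batch_size, shuffle=False, drop_last=False) implied by that docstring. `train_dataset` is a placeholder for any object implementing `__len__`/`__getitem__`.

# Editor's sketch (not part of the patch): per-rank batching.
from distributed import DistributedBatchSampler

sampler = DistributedBatchSampler(train_dataset, batch_size=64,
                                  shuffle=True, drop_last=True)
for batch_indices in sampler:
    # Each rank only iterates its own subsampled slice:
    # indices[local_rank * num_samples:(local_rank + 1) * num_samples]
    batch = [train_dataset[i] for i in batch_indices]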
@@ -187,7 +189,7 @@ def wait_server_ready(endpoints):
             break

-def initCommunicator(program, rank, nranks, wait_port,
+def init_communicator(program, rank, nranks, wait_port,
                      current_endpoint, endpoints):
     if nranks < 2:
         return
@@ -234,12 +236,11 @@ def prepare_context(place):
     if isinstance(place, core.CUDAPlace):
         communicator_prog = framework.Program()
-        initCommunicator(communicator_prog, strategy.local_rank, strategy.nranks, True,
+        init_communicator(communicator_prog, strategy.local_rank, strategy.nranks, True,
                          strategy.current_endpoint, strategy.trainer_endpoints)
         exe = fluid.Executor(place)
         exe.run(communicator_prog)
     else:
-        # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
         assert ("Only support CUDAPlace for now.")

     return strategy
@@ -273,10 +274,6 @@ class DistributedDataParallel(DataParallel):
                 assert g_var not in grad_var_set
                 grad_var_set.add(g_var)
-        # FIXME(zcd): the type of the var should be LoDTensor, i.e
-        # the gradients should be dense, otherwise, the following
-        # logic should be updated.
-
         # 128 MB as a group
         mega_bytes = 128 * 1024 * 1024
         group_idx = 0
         memory_counter = 0
diff --git a/mnist.py b/mnist.py
index 00d12990fd7cd636cd4a2183a3e7ece54a641aff..764502bb71b46b341c06e393ed57e3a2b17a9cc5 100644
--- a/mnist.py
+++ b/mnist.py
@@ -134,17 +134,6 @@ def main():
     if not os.path.exists('mnist_checkpoints'):
         os.mkdir('mnist_checkpoints')

-    # train_loader = fluid.io.xmap_readers(
-    #     lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
-    #                np.array([x[1] for x in b]).reshape(-1, 1)],
-    #     paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 6e4),
-    #                  batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
-    # val_loader = fluid.io.xmap_readers(
-    #     lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
-    #                np.array([x[1] for x in b]).reshape(-1, 1)],
-    #     paddle.batch(paddle.dataset.mnist.test(),
-    #                  batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
-
     with guard:
         train_dataset = CustromMnistDataset(mode='train')
diff --git a/model.py b/model.py
index f3ef27e4fbd0022ae8c6800a51ed0d2f298c21e0..5460241c0624025ae2c88aaa60b31323530b39c9 100644
--- a/model.py
+++ b/model.py
@@ -18,6 +18,8 @@ import inspect
 import os
 import pickle
 import numpy as np
+import six
+import warnings

 from collections import Iterable
 from collections import OrderedDict
@@ -167,7 +169,7 @@ class StaticGraphAdapter(object):
         return self._run(inputs, None)

     def parameters(self, *args, **kwargs):
-        return None
+        return super(Model, self.model).parameters(*args, **kwargs)

     def save(self, path):
         def _save(state, path):
@@ -201,39 +203,23 @@ class StaticGraphAdapter(object):
         _save(optim, optim_path)

-    def load(self, path):
-        def _load(path):
-            if not os.path.exists(path):
-                return
-            with open(path, 'rb') as f:
-                return pickle.load(f)
-
-        param_path = path + ".pdparams"
-        param_state = _load(param_path)
-        assert param_state, "failed to load parameters, please check path"
-
+    def load(self, param_state_pairs, optim_state):
         if self._executor is None:
             executor = fluid.Executor(fluid.CPUPlace())._default_executor
         else:
             executor = self._executor._default_executor

+        # restore parameter states
         fluid.core._create_loaded_parameter(
-            list(self.model.state_dict().values()), global_scope(), executor)
-
-        for key, var in self.model.state_dict().items():
-            assert key in param_state, \
-                "parameter [{}] is not found in model file [{}]".format(
-                    key, param_path)
-            self._set_var(var, param_state[key])
+            [param for param, state in param_state_pairs],
+            global_scope(), executor)
+        for param, state in param_state_pairs:
+            self._set_var(param, state)

+        # restore optimizer states
         # FIXME what if a different optimizer is used?
-        if not self.model._optimizer:
-            return
-        optim_path = path + ".pdopt"
-        optim_state = _load(optim_path)
-        if optim_state is None:
+        if not self.model._optimizer or not optim_state:
             return
-
         self._load_optimizer(optim_state, executor)

     def _load_optimizer(self, state, executor):
@@ -361,7 +347,8 @@ class StaticGraphAdapter(object):
         metrics = []
         for metric, state in zip(self.model._metrics, metric_states):
             # cut off padding size
-            if self.mode != 'train' and self.model._test_dataloader is not None and self._nranks > 1:
+            if self.mode != 'train' and self.model._test_dataloader is not None \
+                    and self._nranks > 1:
                 total_size = len(self.model._test_dataloader.dataset)
                 samples = state[0].shape[0]
                 current_count = self._merge_count.get(self.mode, 0)
@@ -425,7 +412,8 @@ class StaticGraphAdapter(object):
             dist_strategy = DistributedStrategy()
             dist_strategy.mode = "collective"
             dist_strategy.collective_mode = "grad_allreduce"
-            self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer, strategy=dist_strategy)
+            self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer,
+                                                                strategy=dist_strategy)
             self.model._optimizer.minimize(self._loss_endpoint)

         if self._nranks > 1 and mode != 'train' and self.model._test_dataloader is not None:
@@ -477,7 +465,8 @@ class StaticGraphAdapter(object):
         uninitialized = []
         for var_py in self._startup_prog.list_vars():
             var = fluid.global_scope().find_var(var_py.name)
-            if not var_py.name.startswith('nccl_id') and var and var.get_tensor()._is_initialized():
+            if not var_py.name.startswith('nccl_id') and var and \
+                    var.get_tensor()._is_initialized():
                 continue
             uninitialized.append(var_py)

@@ -549,9 +538,7 @@ class DynamicGraphAdapter(object):
         return ([to_numpy(l) for l in losses], metrics) \
             if len(metrics) > 0 else [to_numpy(l) for l in losses]

-    def eval(self, inputs, labels, device='CPU', device_ids=None):
-        assert self.model._loss_function, \
-            "model not ready, please call `model.prepare()` first"
+    def eval(self, inputs, labels=None):
         super(Model, self.model).eval()
         self.mode = 'eval'
         inputs = to_list(inputs)
@@ -609,10 +596,13 @@
         optim = self.model._optimizer.state_dict()
         fluid.save_dygraph(optim, path)

-    def load(self, path):
-        params, optim = fluid.load_dygraph(path)
-        self.model.set_dict(params)
-        if self.model._optimizer is None or optim is None:
+    def load(self, param_state_pairs, optim_state):
+        # restore parameter states
+        for param, state in param_state_pairs:
+            param.set_value(state)
+
+        # restore optimizer states
+        if not self.model._optimizer or not optim_state:
             return

         # If optimizer performs set_dict when state vars haven't been created,
@@ -621,13 +611,13 @@
         # To contrive this when loading from static-graph saved states, extend
         # state dict to include keys named accoring to dygraph naming rules.
         # TODO: if len(self.model._optimizer._accumulators) > 0
-        converted_state = dict(optim)
+        converted_state = dict(optim_state)
         opt_unq_name = self.model._optimizer._name
         opt_cls_name = self.model._optimizer.__class__.__name__
         opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
         param_names = [param.name for param in self.model.parameters()]
         for var_name, state_var in sorted(
-                optim.items(), key=lambda x: len(x[0]), reverse=True):
+                optim_state.items(), key=lambda x: len(x[0]), reverse=True):
             if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                 # NOTE: dygraph saved global_step is 1 larger than that in
                 # static-graph, since the time of global_step to increase is
@@ -697,8 +687,71 @@ class Model(fluid.dygraph.Layer):
         if distributed.get_local_rank() == 0:
             return self._adapter.save(*args, **kwargs)

-    def load(self, *args, **kwargs):
-        return self._adapter.load(*args, **kwargs)
+    def load(self, path, skip_mismatch=False, reset_optimizer=False):
+        """
+        Load model states and optimizer states from files. The file for
+        optimizer states is not necessary if the optimizer does not need restoring.
+
+        NOTE: parameters are retrieved from the file storing model states
+        according to their structured names.
+
+        For fine-tuning or transfer-learning models where some layers have
+        changed, keep the parameters to be restored under the same structured
+        names in the pre-trained model and the fine-tuning model.
+
+        Args:
+            path (str): The prefix of files storing the model states and
+                optimizer states. The files would be `path.pdparams` and
+                `path.pdopt` separately, and the latter is not necessary
+                if the optimizer does not need restoring.
+            skip_mismatch (bool): Whether to skip loading mismatched
+                parameters or raise an error when a mismatch happens (the
+                parameter is not found in the file storing model states, or
+                it receives a mismatched shape). Default False.
+            reset_optimizer (bool): If True, ignore the provided file storing
+                optimizer states and initialize optimizer states from scratch.
+                Otherwise, restore optimizer states from `path.pdopt` if an
+                optimizer has been set on the model. Default False.
+        """
+
+        def _load_state_from_path(path):
+            if not os.path.exists(path):
+                return
+            with open(path, 'rb') as f:
+                return pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+        def _check_match(key, param):
+            state = param_state.get(key, None)
+            if state is None:
+                raise ValueError(
+                    "{} is not found in the provided file.".format(key))
+            if list(state.shape) != list(param.shape):
+                raise ValueError(
+                    "{} receives a shape {}, but the expected shape is {}.".
+                    format(key, list(state.shape), list(param.shape)))
+            return param, state
+
+        param_state = _load_state_from_path(path + ".pdparams")
+        assert param_state, "Failed to load parameters, please check path."
+
+        matched_param_state = []
+        for key, param in self.state_dict().items():
+            try:
+                match_res = _check_match(key, param)
+            except ValueError as err:
+                if skip_mismatch:
+                    warnings.warn(
".format(key) + err.message)) + # reset optimizer when mismatch happens + reset_optimizer = True + else: + raise err + matched_param_state.append(match_res) + + optim_state = None if reset_optimizer else _load_state_from_path( + path + ".pdopt") + return self._adapter.load(matched_param_state, optim_state) def parameters(self, *args, **kwargs): return self._adapter.parameters(*args, **kwargs) diff --git a/progressbar.py b/progressbar.py index 1f07424df3242ab9e44841ffb9f962aa817ba18b..1aa301229f7da61a4c01083626327db5dc32586c 100644 --- a/progressbar.py +++ b/progressbar.py @@ -2,7 +2,6 @@ import sys import time import numpy as np -from distributed import get_local_rank class ProgressBar(object): """progress bar """ @@ -60,106 +59,105 @@ class ProgressBar(object): else: fps = ' - %.0fus/%s' % (time_per_unit * 1e6, 'step') - if get_local_rank() == 0: - info = '' - if self._verbose == 1: - prev_total_width = self._total_width + info = '' + if self._verbose == 1: + prev_total_width = self._total_width - if self._dynamic_display: - sys.stdout.write('\b' * prev_total_width) - sys.stdout.write('\r') - else: - sys.stdout.write('\n') + if self._dynamic_display: + sys.stdout.write('\b' * prev_total_width) + sys.stdout.write('\r') + else: + sys.stdout.write('\n') - if self._num is not None: - numdigits = int(np.log10(self._num)) + 1 + if self._num is not None: + numdigits = int(np.log10(self._num)) + 1 - bar_chars = ('step %' + str(numdigits) + 'd/%d [') % ( - current_num, self._num) - prog = float(current_num) / self._num - prog_width = int(self._width * prog) + bar_chars = ('step %' + str(numdigits) + 'd/%d [') % ( + current_num, self._num) + prog = float(current_num) / self._num + prog_width = int(self._width * prog) - if prog_width > 0: - bar_chars += ('=' * (prog_width - 1)) - if current_num < self._num: - bar_chars += '>' - else: - bar_chars += '=' - bar_chars += ('.' * (self._width - prog_width)) - bar_chars += ']' - else: - bar_chars = 'step %3d' % current_num - - self._total_width = len(bar_chars) - sys.stdout.write(bar_chars) - - for k, val in values: - info += ' - %s:' % k - val = val if isinstance(val, list) else [val] - for i, v in enumerate(val): - if isinstance(v, (float, np.float32, np.float64)): - if abs(v) > 1e-3: - info += ' %.4f' % v - else: - info += ' %.4e' % v + if prog_width > 0: + bar_chars += ('=' * (prog_width - 1)) + if current_num < self._num: + bar_chars += '>' + else: + bar_chars += '=' + bar_chars += ('.' 
diff --git a/progressbar.py b/progressbar.py
index 1f07424df3242ab9e44841ffb9f962aa817ba18b..1aa301229f7da61a4c01083626327db5dc32586c 100644
--- a/progressbar.py
+++ b/progressbar.py
@@ -2,7 +2,6 @@ import sys
 import time
 import numpy as np
-from distributed import get_local_rank


 class ProgressBar(object):
     """progress bar """
@@ -60,106 +59,105 @@ class ProgressBar(object):
         else:
             fps = ' - %.0fus/%s' % (time_per_unit * 1e6, 'step')

-        if get_local_rank() == 0:
-            info = ''
-            if self._verbose == 1:
-                prev_total_width = self._total_width
+        info = ''
+        if self._verbose == 1:
+            prev_total_width = self._total_width

-                if self._dynamic_display:
-                    sys.stdout.write('\b' * prev_total_width)
-                    sys.stdout.write('\r')
-                else:
-                    sys.stdout.write('\n')
+            if self._dynamic_display:
+                sys.stdout.write('\b' * prev_total_width)
+                sys.stdout.write('\r')
+            else:
+                sys.stdout.write('\n')

-                if self._num is not None:
-                    numdigits = int(np.log10(self._num)) + 1
+            if self._num is not None:
+                numdigits = int(np.log10(self._num)) + 1

-                    bar_chars = ('step %' + str(numdigits) + 'd/%d [') % (
-                        current_num, self._num)
-                    prog = float(current_num) / self._num
-                    prog_width = int(self._width * prog)
+                bar_chars = ('step %' + str(numdigits) + 'd/%d [') % (
+                    current_num, self._num)
+                prog = float(current_num) / self._num
+                prog_width = int(self._width * prog)

-                    if prog_width > 0:
-                        bar_chars += ('=' * (prog_width - 1))
-                        if current_num < self._num:
-                            bar_chars += '>'
-                        else:
-                            bar_chars += '='
-                    bar_chars += ('.' * (self._width - prog_width))
-                    bar_chars += ']'
-                else:
-                    bar_chars = 'step %3d' % current_num
-
-                self._total_width = len(bar_chars)
-                sys.stdout.write(bar_chars)
-
-                for k, val in values:
-                    info += ' - %s:' % k
-                    val = val if isinstance(val, list) else [val]
-                    for i, v in enumerate(val):
-                        if isinstance(v, (float, np.float32, np.float64)):
-                            if abs(v) > 1e-3:
-                                info += ' %.4f' % v
-                            else:
-                                info += ' %.4e' % v
+                if prog_width > 0:
+                    bar_chars += ('=' * (prog_width - 1))
+                    if current_num < self._num:
+                        bar_chars += '>'
+                    else:
+                        bar_chars += '='
+                bar_chars += ('.' * (self._width - prog_width))
+                bar_chars += ']'
+            else:
+                bar_chars = 'step %3d' % current_num
+
+            self._total_width = len(bar_chars)
+            sys.stdout.write(bar_chars)
+
+            for k, val in values:
+                info += ' - %s:' % k
+                val = val if isinstance(val, list) else [val]
+                for i, v in enumerate(val):
+                    if isinstance(v, (float, np.float32, np.float64)):
+                        if abs(v) > 1e-3:
+                            info += ' %.4f' % v
                         else:
-                            info += ' %s' % v
-
-                if self._num is not None and current_num < self._num:
-                    eta = time_per_unit * (self._num - current_num)
-                    if eta > 3600:
-                        eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
-                                                       60, eta % 60)
-                    elif eta > 60:
-                        eta_format = '%d:%02d' % (eta // 60, eta % 60)
+                            info += ' %.4e' % v
                     else:
-                        eta_format = '%ds' % eta
-
-                    info += ' - ETA: %s' % eta_format
-
-                info += fps
-                self._total_width += len(info)
-                if prev_total_width > self._total_width:
-                    info += (' ' * (prev_total_width - self._total_width))
-
-                # newline for another epoch
-                if self._num is not None and current_num >= self._num:
-                    info += '\n'
-                if self._num is None:
-                    info += '\n'
-
-                sys.stdout.write(info)
-                sys.stdout.flush()
-                self._last_update = now
-            elif self._verbose == 2:
-                if self._num:
-                    numdigits = int(np.log10(self._num)) + 1
-                    count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
-                                                                    self._num)
+                        info += ' %s' % v
+
+            if self._num is not None and current_num < self._num:
+                eta = time_per_unit * (self._num - current_num)
+                if eta > 3600:
+                    eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
+                                                   60, eta % 60)
+                elif eta > 60:
+                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                 else:
-                    count = 'step %3d' % current_num
-                info = count + info
-
-                for k, val in values:
-                    info += ' - %s:' % k
-                    val = val if isinstance(val, list) else [val]
-                    for v in val:
-                        if isinstance(v, (float, np.float32, np.float64)):
-                            if abs(v) > 1e-3:
-                                info += ' %.4f' % v
-                            else:
-                                info += ' %.4e' % v
-                        elif isinstance(v, np.ndarray) and \
-                            isinstance(v.size, 1) and \
-                            isinstance(v.dtype, (np.float32, np.float64)):
-                            if abs(v[0]) > 1e-3:
-                                info += ' %.4f' % v[0]
-                            else:
-                                info += ' %.4e' % v[0]
-                        else:
-                            info += ' %s' % v
+                    eta_format = '%ds' % eta
+
+                info += ' - ETA: %s' % eta_format
+
+            info += fps
+            self._total_width += len(info)
+            if prev_total_width > self._total_width:
+                info += (' ' * (prev_total_width - self._total_width))
-                info += fps
+            # newline for another epoch
+            if self._num is not None and current_num >= self._num:
                 info += '\n'
-                sys.stdout.write(info)
-                sys.stdout.flush()
+            if self._num is None:
+                info += '\n'
+
+            sys.stdout.write(info)
+            sys.stdout.flush()
+            self._last_update = now
+        elif self._verbose == 2:
+            if self._num:
+                numdigits = int(np.log10(self._num)) + 1
+                count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
+                                                                self._num)
+            else:
+                count = 'step %3d' % current_num
+            info = count + info
+
+            for k, val in values:
+                info += ' - %s:' % k
+                val = val if isinstance(val, list) else [val]
+                for v in val:
+                    if isinstance(v, (float, np.float32, np.float64)):
+                        if abs(v) > 1e-3:
+                            info += ' %.4f' % v
+                        else:
+                            info += ' %.4e' % v
+                    elif isinstance(v, np.ndarray) and \
+                        isinstance(v.size, 1) and \
+                        isinstance(v.dtype, (np.float32, np.float64)):
+                        if abs(v[0]) > 1e-3:
+                            info += ' %.4f' % v[0]
+                        else:
+                            info += ' %.4e' % v[0]
+                    else:
+                        info += ' %s' % v
+
+            info += fps
+            info += '\n'
+            sys.stdout.write(info)
+            sys.stdout.flush()
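Not part of the patch — editor's note: taken together, the callbacks.py and progressbar.py hunks move the rank check out of ProgressBar and into ProgBarLogger, so gating console output on the local rank happens once, at the callback level. A minimal sketch of that pattern, with `write_log` as a hypothetical helper:

# Editor's sketch (not part of the patch): only local rank 0 writes progress output.
import sys

from distributed import get_local_rank

def write_log(message):
    if get_local_rank() == 0:  # non-zero ranks stay silent
        sys.stdout.write(message + '\n')
        sys.stdout.flush()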