Commit c3ba953b authored by LielinJiang

merge load

@@ -211,7 +211,7 @@ class ProgBarLogger(Callback):
         logs = logs or {}
         self.train_step = step
 
-        if self.train_step % self.log_freq == 0 and self.verbose:
+        if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank() == 0:
             # if steps is not None, last step will update in on_epoch_end
             if self.steps and self.train_step < self.steps:
                 self._updates(logs, 'train')
...
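
The change above gates progress logging on the local rank so that, under multi-process data-parallel training, only one process prints. A minimal sketch of that guard pattern; the helper below is hypothetical, only get_local_rank comes from this repository's distributed module:

    # hypothetical helper illustrating the rank-0 guard added above
    from distributed import get_local_rank

    def log_on_rank0(msg):
        # with N GPUs there are N copies of the training loop; printing only
        # on local rank 0 avoids N duplicated progress lines
        if get_local_rank() == 0:
            print(msg)
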
@@ -32,7 +32,8 @@ from paddle.fluid.framework import Variable
 from paddle.fluid.executor import global_scope
 from paddle.fluid.dygraph.parallel import Env, DataParallel, ParallelStrategy
-from paddle.fluid.layers.collective import _c_allreduce, _c_allgather, _c_broadcast, _c_sync_comm_stream, _c_sync_calc_stream
+from paddle.fluid.layers.collective import _c_allreduce, _c_allgather, _c_broadcast, \
+    _c_sync_comm_stream, _c_sync_calc_stream
 from paddle.fluid.io import BatchSampler, DataLoader
@@ -52,7 +53,7 @@ class DistributedBatchSampler(BatchSampler):
             `__len__` for BatchSampler to get sample
             number of data source.
         batch_size(int): sample indice number in a mini-batch indices.
-        shuffle(bool): whther to shuffle indices order before genrate
+        shuffle(bool): whther to shuffle indices order before genrating
             batch indices. Default False.
         drop_last(bool): whether drop the last incomplete batch dataset size
             is not divisible by the batch size. Default False
@@ -88,7 +89,8 @@ class DistributedBatchSampler(BatchSampler):
             np.random.RandomState(self.epoch).shuffle(indices)
             self.epoch += 1
 
         # subsample
-        indices = indices[self.local_rank * self.num_samples: (self.local_rank + 1) * self.num_samples]
+        indices = indices[self.local_rank * self.num_samples:
+                          (self.local_rank + 1) * self.num_samples]
         assert len(indices) == self.num_samples
         _sample_iter = iter(indices)
@@ -187,7 +189,7 @@ def wait_server_ready(endpoints):
             break
 
-def initCommunicator(program, rank, nranks, wait_port,
+def init_communicator(program, rank, nranks, wait_port,
                      current_endpoint, endpoints):
     if nranks < 2:
         return
@@ -234,12 +236,11 @@ def prepare_context(place):
     if isinstance(place, core.CUDAPlace):
         communicator_prog = framework.Program()
-        initCommunicator(communicator_prog, strategy.local_rank, strategy.nranks, True,
+        init_communicator(communicator_prog, strategy.local_rank, strategy.nranks, True,
                          strategy.current_endpoint, strategy.trainer_endpoints)
         exe = fluid.Executor(place)
         exe.run(communicator_prog)
     else:
-        # TODO(Yancey1989): add Gloo Parallel Context to support CPU parallel computation
        assert ("Only support CUDAPlace for now.")
     return strategy
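
prepare_context and init_communicator above build and run an NCCL initialization program before dygraph data-parallel training. A hypothetical sketch of how the pieces in this file might be wired together; the call pattern and the model layer are assumptions, only the names come from the hunks above:

    import paddle.fluid as fluid
    from paddle.fluid.dygraph.parallel import Env

    place = fluid.CUDAPlace(Env().dev_id)
    with fluid.dygraph.guard(place):
        strategy = prepare_context(place)      # runs init_communicator when nranks >= 2
        model = SomeLayer()                    # hypothetical fluid.dygraph.Layer subclass
        if strategy.nranks > 1:
            model = DistributedDataParallel(model, strategy)
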
@@ -273,10 +274,6 @@ class DistributedDataParallel(DataParallel):
                 assert g_var not in grad_var_set
                 grad_var_set.add(g_var)
 
-        # FIXME(zcd): the type of the var should be LoDTensor, i.e
-        # the gradients should be dense, otherwise, the following
-        # logic should be updated.
-        # 128 MB as a group
         mega_bytes = 128 * 1024 * 1024
         group_idx = 0
         memory_counter = 0
...
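
The comment block removed above described what mega_bytes / group_idx / memory_counter still do: gradients are packed into coalescing groups of roughly 128 MB so collective ops run per group instead of per tensor. A simplified sketch of that grouping heuristic with made-up gradient sizes (not the exact DataParallel code):

    mega_bytes = 128 * 1024 * 1024
    group_idx, memory_counter, groups = 0, 0, {}

    grad_sizes = [('fc_0.w_0@GRAD', 96 << 20),    # hypothetical gradient byte sizes
                  ('fc_0.b_0@GRAD', 4 << 10),
                  ('fc_1.w_0@GRAD', 64 << 20)]
    for name, nbytes in grad_sizes:
        if memory_counter + nbytes > mega_bytes:  # budget exceeded: start a new group
            group_idx += 1
            memory_counter = 0
        memory_counter += nbytes
        groups.setdefault(group_idx, []).append(name)
    # groups -> {0: ['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD'], 1: ['fc_1.w_0@GRAD']}
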
@@ -134,17 +134,6 @@ def main():
     if not os.path.exists('mnist_checkpoints'):
         os.mkdir('mnist_checkpoints')
 
-    # train_loader = fluid.io.xmap_readers(
-    #     lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
-    #                np.array([x[1] for x in b]).reshape(-1, 1)],
-    #     paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 6e4),
-    #                  batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
-    # val_loader = fluid.io.xmap_readers(
-    #     lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
-    #                np.array([x[1] for x in b]).reshape(-1, 1)],
-    #     paddle.batch(paddle.dataset.mnist.test(),
-    #                  batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
     with guard:
         train_dataset = CustromMnistDataset(mode='train')
...
@@ -18,6 +18,8 @@ import inspect
 import os
 import pickle
 import numpy as np
+import six
+import warnings
 
 from collections import Iterable
 from collections import OrderedDict
@@ -167,7 +169,7 @@ class StaticGraphAdapter(object):
         return self._run(inputs, None)
 
     def parameters(self, *args, **kwargs):
-        return None
+        return super(Model, self.model).parameters(*args, **kwargs)
 
     def save(self, path):
         def _save(state, path):
@@ -201,39 +203,23 @@ class StaticGraphAdapter(object):
         _save(optim, optim_path)
 
-    def load(self, path):
-        def _load(path):
-            if not os.path.exists(path):
-                return
-            with open(path, 'rb') as f:
-                return pickle.load(f)
-
-        param_path = path + ".pdparams"
-        param_state = _load(param_path)
-        assert param_state, "failed to load parameters, please check path"
-
+    def load(self, param_state_pairs, optim_state):
         if self._executor is None:
             executor = fluid.Executor(fluid.CPUPlace())._default_executor
         else:
             executor = self._executor._default_executor
 
+        # restore parameter states
         fluid.core._create_loaded_parameter(
-            list(self.model.state_dict().values()), global_scope(), executor)
-
-        for key, var in self.model.state_dict().items():
-            assert key in param_state, \
-                "parameter [{}] is not found in model file [{}]".format(
-                    key, param_path)
-            self._set_var(var, param_state[key])
-
+            [param for param, state in param_state_pairs],
+            global_scope(), executor)
+        for param, state in param_state_pairs:
+            self._set_var(param, state)
+
+        # restore optimizer states
         # FIXME what if a different optimizer is used?
-        if not self.model._optimizer:
-            return
-
-        optim_path = path + ".pdopt"
-        optim_state = _load(optim_path)
-        if optim_state is None:
+        if not self.model._optimizer or not optim_state:
             return
 
         self._load_optimizer(optim_state, executor)
 
     def _load_optimizer(self, state, executor):
@@ -361,7 +347,8 @@ class StaticGraphAdapter(object):
         metrics = []
         for metric, state in zip(self.model._metrics, metric_states):
             # cut off padding size
-            if self.mode != 'train' and self.model._test_dataloader is not None and self._nranks > 1:
+            if self.mode != 'train' and self.model._test_dataloader is not None \
+                    and self._nranks > 1:
                 total_size = len(self.model._test_dataloader.dataset)
                 samples = state[0].shape[0]
                 current_count = self._merge_count.get(self.mode, 0)
@@ -425,7 +412,8 @@ class StaticGraphAdapter(object):
             dist_strategy = DistributedStrategy()
             dist_strategy.mode = "collective"
             dist_strategy.collective_mode = "grad_allreduce"
-            self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer, strategy=dist_strategy)
+            self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer,
+                                                                strategy=dist_strategy)
 
         self.model._optimizer.minimize(self._loss_endpoint)
         if self._nranks > 1 and mode != 'train' and self.model._test_dataloader is not None:
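
The wrapped line above configures collective training through fleet: the distributed optimizer all-reduces gradients across ranks. A hedged fragment restating that configuration outside the adapter; the import path follows the Paddle 1.x incubate fleet API, and optimizer / avg_loss are hypothetical placeholders that would come from the user's program:

    from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy

    dist_strategy = DistributedStrategy()
    dist_strategy.mode = "collective"
    dist_strategy.collective_mode = "grad_allreduce"   # all-reduce gradients across ranks
    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
    optimizer.minimize(avg_loss)                       # minimize through the wrapped optimizer
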
@@ -477,7 +465,8 @@ class StaticGraphAdapter(object):
         uninitialized = []
         for var_py in self._startup_prog.list_vars():
             var = fluid.global_scope().find_var(var_py.name)
-            if not var_py.name.startswith('nccl_id') and var and var.get_tensor()._is_initialized():
+            if not var_py.name.startswith('nccl_id') and var and \
+                    var.get_tensor()._is_initialized():
                 continue
 
             uninitialized.append(var_py)
@@ -549,9 +538,7 @@ class DynamicGraphAdapter(object):
         return ([to_numpy(l) for l in losses], metrics) \
             if len(metrics) > 0 else [to_numpy(l) for l in losses]
 
-    def eval(self, inputs, labels, device='CPU', device_ids=None):
-        assert self.model._loss_function, \
-            "model not ready, please call `model.prepare()` first"
+    def eval(self, inputs, labels=None):
         super(Model, self.model).eval()
         self.mode = 'eval'
         inputs = to_list(inputs)
@@ -609,10 +596,13 @@ class DynamicGraphAdapter(object):
             optim = self.model._optimizer.state_dict()
             fluid.save_dygraph(optim, path)
 
-    def load(self, path):
-        params, optim = fluid.load_dygraph(path)
-        self.model.set_dict(params)
-        if self.model._optimizer is None or optim is None:
+    def load(self, param_state_pairs, optim_state):
+        # restore parameter states
+        for param, state in param_state_pairs:
+            param.set_value(state)
+
+        # resotre optimizer states
+        if not self.model._optimizer or not optim_state:
             return
 
         # If optimizer performs set_dict when state vars haven't been created,
@@ -621,13 +611,13 @@ class DynamicGraphAdapter(object):
         # To contrive this when loading from static-graph saved states, extend
         # state dict to include keys named accoring to dygraph naming rules.
         # TODO: if len(self.model._optimizer._accumulators) > 0
-        converted_state = dict(optim)
+        converted_state = dict(optim_state)
        opt_unq_name = self.model._optimizer._name
         opt_cls_name = self.model._optimizer.__class__.__name__
         opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
         param_names = [param.name for param in self.model.parameters()]
         for var_name, state_var in sorted(
-                optim.items(), key=lambda x: len(x[0]), reverse=True):
+                optim_state.items(), key=lambda x: len(x[0]), reverse=True):
             if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                 # NOTE: dygraph saved global_step is 1 larger than that in
                 # static-graph, since the time of global_step to increase is
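
The converted_state handling above rewrites static-graph optimizer state keys so they match dygraph accumulator names; part of that is trimming the trailing index from the optimizer's unique name. A tiny illustration with a made-up name:

    opt_unq_name = "momentum_0"                         # hypothetical dygraph unique name
    opt_name = opt_unq_name[:opt_unq_name.rfind("_")]   # strip the "_0" suffix
    assert opt_name == "momentum"
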
@@ -697,8 +687,71 @@ class Model(fluid.dygraph.Layer):
         if distributed.get_local_rank() == 0:
             return self._adapter.save(*args, **kwargs)
 
-    def load(self, *args, **kwargs):
-        return self._adapter.load(*args, **kwargs)
+    def load(self, path, skip_mismatch=False, reset_optimizer=False):
+        """
+        Load from files storing the model states and optimizer states. The file
+        for optimizer states is not necessary if no need to restore the optimizer.
+
+        NOTE: parameters are retrieved out from the file storing model states
+        accoring to their structured names.
+
+        For fine-tuning or transfer-learning models where some of the layers have
+        changed, keep parameters needed to restore have same structured names in
+        the pre-trained model and fine-tuning model.
+
+        Args:
+            path (str): The prefix of files storing the model states and
+                optimizer states. The files would be `path.pdparams` and
+                `path.pdopt` separately, and the latter is not necessary
+                when no need to restore.
+            skip_mismatch (bool): Whether to skip the loading of mismatch
+                parameter or raise an error when mismatch happens (not found
+                the parameter in file storing model states of or receives a
+                mismatch shape).
+            reset_optimizer (bool): If True, ignore the providing file storing
+                optimizer states and initialize optimizer states from scratch.
+                Otherwise, restore optimizer states from `path.pdopt` if
+                a optimizer has been set to the model. Default False.
+        """
+
+        def _load_state_from_path(path):
+            if not os.path.exists(path):
+                return
+            with open(path, 'rb') as f:
+                return pickle.load(f) if six.PY2 else pickle.load(
+                    f, encoding='latin1')
+
+        def _check_match(key, param):
+            state = param_state.get(key, None)
+            if state is None:
+                raise ValueError(
+                    "{} is not found in the providing file.".format(key))
+            if list(state.shape) != list(param.shape):
+                raise ValueError(
+                    "{} receives a shape {}, but the expected shape is {}.".
+                    format(key, list(state.shape), list(param.shape)))
+            return param, state
+
+        param_state = _load_state_from_path(path + ".pdparams")
+        assert param_state, "Failed to load parameters, please check path."
+
+        matched_param_state = []
+        for key, param in self.state_dict().items():
+            try:
+                match_res = _check_match(key, param)
+            except ValueError as err:
+                if skip_mismatch:
+                    warnings.warn(
+                        ("Skip loading for {}. ".format(key) + err.message))
+                    # reset optimizer when mismatch happens
+                    reset_optimizer = True
+                else:
+                    raise err
+            matched_param_state.append(match_res)
+
+        optim_state = None if reset_optimizer else _load_state_from_path(
+            path + ".pdopt")
+        return self._adapter.load(matched_param_state, optim_state)
 
     def parameters(self, *args, **kwargs):
         return self._adapter.parameters(*args, **kwargs)
...
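
As documented in the new load docstring above, a checkpoint is a pair of files sharing one prefix. A hedged usage sketch; the model class, prepare arguments and checkpoint prefix are hypothetical:

    model = MNIST()                                   # some hapi Model subclass
    model.prepare(optimizer, loss_function)           # arguments depend on the task
    model.save('mnist_checkpoints/final')             # writes final.pdparams and final.pdopt

    # later, e.g. fine-tuning after changing some layers:
    model.load('mnist_checkpoints/final',
               skip_mismatch=True,                    # warn and skip missing / mis-shaped params
               reset_optimizer=True)                  # ignore final.pdopt, fresh optimizer state
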
@@ -2,7 +2,6 @@ import sys
 import time
 import numpy as np
 
-from distributed import get_local_rank
 
 class ProgressBar(object):
     """progress bar """
@@ -60,106 +59,105 @@ class ProgressBar(object):
         else:
             fps = ' - %.0fus/%s' % (time_per_unit * 1e6, 'step')
 
-        if get_local_rank() == 0:
-            info = ''
-            if self._verbose == 1:
-                prev_total_width = self._total_width
-                if self._dynamic_display:
-                    sys.stdout.write('\b' * prev_total_width)
-                    sys.stdout.write('\r')
-                else:
-                    sys.stdout.write('\n')
-                if self._num is not None:
-                    numdigits = int(np.log10(self._num)) + 1
-                    bar_chars = ('step %' + str(numdigits) + 'd/%d [') % (
-                        current_num, self._num)
-                    prog = float(current_num) / self._num
-                    prog_width = int(self._width * prog)
-                    if prog_width > 0:
-                        bar_chars += ('=' * (prog_width - 1))
-                        if current_num < self._num:
-                            bar_chars += '>'
-                        else:
-                            bar_chars += '='
-                    bar_chars += ('.' * (self._width - prog_width))
-                    bar_chars += ']'
-                else:
-                    bar_chars = 'step %3d' % current_num
-                self._total_width = len(bar_chars)
-                sys.stdout.write(bar_chars)
-                for k, val in values:
-                    info += ' - %s:' % k
-                    val = val if isinstance(val, list) else [val]
-                    for i, v in enumerate(val):
-                        if isinstance(v, (float, np.float32, np.float64)):
-                            if abs(v) > 1e-3:
-                                info += ' %.4f' % v
-                            else:
-                                info += ' %.4e' % v
-                        else:
-                            info += ' %s' % v
-                if self._num is not None and current_num < self._num:
-                    eta = time_per_unit * (self._num - current_num)
-                    if eta > 3600:
-                        eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
-                                                       60, eta % 60)
-                    elif eta > 60:
-                        eta_format = '%d:%02d' % (eta // 60, eta % 60)
-                    else:
-                        eta_format = '%ds' % eta
-                    info += ' - ETA: %s' % eta_format
-                info += fps
-                self._total_width += len(info)
-                if prev_total_width > self._total_width:
-                    info += (' ' * (prev_total_width - self._total_width))
-                # newline for another epoch
-                if self._num is not None and current_num >= self._num:
-                    info += '\n'
-                if self._num is None:
-                    info += '\n'
-                sys.stdout.write(info)
-                sys.stdout.flush()
-                self._last_update = now
-            elif self._verbose == 2:
-                if self._num:
-                    numdigits = int(np.log10(self._num)) + 1
-                    count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
-                                                                    self._num)
-                else:
-                    count = 'step %3d' % current_num
-                info = count + info
-                for k, val in values:
-                    info += ' - %s:' % k
-                    val = val if isinstance(val, list) else [val]
-                    for v in val:
-                        if isinstance(v, (float, np.float32, np.float64)):
-                            if abs(v) > 1e-3:
-                                info += ' %.4f' % v
-                            else:
-                                info += ' %.4e' % v
-                        elif isinstance(v, np.ndarray) and \
-                                isinstance(v.size, 1) and \
-                                isinstance(v.dtype, (np.float32, np.float64)):
-                            if abs(v[0]) > 1e-3:
-                                info += ' %.4f' % v[0]
-                            else:
-                                info += ' %.4e' % v[0]
-                        else:
-                            info += ' %s' % v
-                info += fps
-                info += '\n'
-                sys.stdout.write(info)
-                sys.stdout.flush()
+        info = ''
+        if self._verbose == 1:
+            prev_total_width = self._total_width
+            if self._dynamic_display:
+                sys.stdout.write('\b' * prev_total_width)
+                sys.stdout.write('\r')
+            else:
+                sys.stdout.write('\n')
+            if self._num is not None:
+                numdigits = int(np.log10(self._num)) + 1
+                bar_chars = ('step %' + str(numdigits) + 'd/%d [') % (
+                    current_num, self._num)
+                prog = float(current_num) / self._num
+                prog_width = int(self._width * prog)
+                if prog_width > 0:
+                    bar_chars += ('=' * (prog_width - 1))
+                    if current_num < self._num:
+                        bar_chars += '>'
+                    else:
+                        bar_chars += '='
+                bar_chars += ('.' * (self._width - prog_width))
+                bar_chars += ']'
+            else:
+                bar_chars = 'step %3d' % current_num
+            self._total_width = len(bar_chars)
+            sys.stdout.write(bar_chars)
+            for k, val in values:
+                info += ' - %s:' % k
+                val = val if isinstance(val, list) else [val]
+                for i, v in enumerate(val):
+                    if isinstance(v, (float, np.float32, np.float64)):
+                        if abs(v) > 1e-3:
+                            info += ' %.4f' % v
+                        else:
+                            info += ' %.4e' % v
+                    else:
+                        info += ' %s' % v
+            if self._num is not None and current_num < self._num:
+                eta = time_per_unit * (self._num - current_num)
+                if eta > 3600:
+                    eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
+                                                   60, eta % 60)
+                elif eta > 60:
+                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
+                else:
+                    eta_format = '%ds' % eta
+                info += ' - ETA: %s' % eta_format
+            info += fps
+            self._total_width += len(info)
+            if prev_total_width > self._total_width:
+                info += (' ' * (prev_total_width - self._total_width))
+            # newline for another epoch
+            if self._num is not None and current_num >= self._num:
+                info += '\n'
+            if self._num is None:
+                info += '\n'
+            sys.stdout.write(info)
+            sys.stdout.flush()
+            self._last_update = now
+        elif self._verbose == 2:
+            if self._num:
+                numdigits = int(np.log10(self._num)) + 1
+                count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
+                                                                self._num)
+            else:
+                count = 'step %3d' % current_num
+            info = count + info
+            for k, val in values:
+                info += ' - %s:' % k
+                val = val if isinstance(val, list) else [val]
+                for v in val:
+                    if isinstance(v, (float, np.float32, np.float64)):
+                        if abs(v) > 1e-3:
+                            info += ' %.4f' % v
+                        else:
+                            info += ' %.4e' % v
+                    elif isinstance(v, np.ndarray) and \
+                            isinstance(v.size, 1) and \
+                            isinstance(v.dtype, (np.float32, np.float64)):
+                        if abs(v[0]) > 1e-3:
+                            info += ' %.4f' % v[0]
+                        else:
+                            info += ' %.4e' % v[0]
+                    else:
+                        info += ' %s' % v
+            info += fps
+            info += '\n'
+            sys.stdout.write(info)
+            sys.stdout.flush()
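
The ETA branch of ProgressBar above renders remaining time as h:mm:ss past an hour, m:ss past a minute, and plain seconds otherwise. A small worked check of that formatting; the helper name format_eta is made up:

    def format_eta(eta):
        # mirrors the three ETA branches in ProgressBar above
        if eta > 3600:
            return '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60)
        elif eta > 60:
            return '%d:%02d' % (eta // 60, eta % 60)
        return '%ds' % eta

    assert format_eta(3725) == '1:02:05'
    assert format_eta(95) == '1:35'
    assert format_eta(42) == '42s'
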