提交 a501b8c8 编写于 作者: L LielinJiang

refine code

...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
from progressbar import ProgressBar from progressbar import ProgressBar
from paddle.fluid.dygraph.parallel import Env from paddle.fluid.dygraph.parallel import Env
def config_callbacks(callbacks=None, def config_callbacks(callbacks=None,
model=None, model=None,
batch_size=None, batch_size=None,
...@@ -26,6 +27,7 @@ def config_callbacks(callbacks=None, ...@@ -26,6 +27,7 @@ def config_callbacks(callbacks=None,
log_freq=2, log_freq=2,
verbose=2, verbose=2,
save_freq=1, save_freq=1,
save_dir=None,
metrics=None, metrics=None,
mode='train'): mode='train'):
cbks = callbacks or [] cbks = callbacks or []
...@@ -34,7 +36,7 @@ def config_callbacks(callbacks=None, ...@@ -34,7 +36,7 @@ def config_callbacks(callbacks=None,
cbks = cbks + [ProgBarLogger(log_freq, verbose=verbose)] cbks = cbks + [ProgBarLogger(log_freq, verbose=verbose)]
if not any(isinstance(k, ModelCheckpoint) for k in cbks): if not any(isinstance(k, ModelCheckpoint) for k in cbks):
cbks = cbks + [ModelCheckpoint(save_freq)] cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]
cbk_list = CallbackList(cbks) cbk_list = CallbackList(cbks)
cbk_list.set_model(model) cbk_list.set_model(model)
...@@ -209,7 +211,7 @@ class ProgBarLogger(Callback): ...@@ -209,7 +211,7 @@ class ProgBarLogger(Callback):
def on_train_batch_end(self, step, logs=None): def on_train_batch_end(self, step, logs=None):
logs = logs or {} logs = logs or {}
self.train_step = step self.train_step += 1
if self.train_step % self.log_freq == 0 and self.verbose and Env().local_rank == 0: if self.train_step % self.log_freq == 0 and self.verbose and Env().local_rank == 0:
# if steps is not None, last step will update in on_epoch_end # if steps is not None, last step will update in on_epoch_end
...@@ -247,21 +249,24 @@ class ProgBarLogger(Callback): ...@@ -247,21 +249,24 @@ class ProgBarLogger(Callback):
class ModelCheckpoint(Callback): class ModelCheckpoint(Callback):
def __init__(self, save_freq=1, save_file='output'): def __init__(self, save_freq=1, save_dir=None):
self.save_freq = save_freq self.save_freq = save_freq
self.save_file = save_file self.save_dir = save_dir
def on_epoch_begin(self, epoch=None, logs=None): def on_epoch_begin(self, epoch=None, logs=None):
self.epoch = epoch self.epoch = epoch
def _is_save(self):
return self.model and self.save_dir and Env().local_rank == 0
def on_epoch_end(self, epoch, logs=None): def on_epoch_end(self, epoch, logs=None):
if self.model and self.epoch % self.save_freq == 0 and Env().local_rank == 0: if self._is_save() and self.epoch % self.save_freq == 0:
path = '{}/{}'.format(self.save_file, epoch) path = '{}/{}'.format(self.save_dir, epoch)
print('save checkpoint at {}'.format(path)) print('save checkpoint at {}'.format(path))
self.model.save(path) self.model.save(path)
def on_train_end(self, logs=None): def on_train_end(self, logs=None):
if self.model and Env().local_rank == 0: if self._is_save():
path = '{}/final'.format(self.save_file) path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(path)) print('save checkpoint at {}'.format(path))
self.model.save(path) self.model.save(path)
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
import six
import time import time
import math import math
import socket import socket
...@@ -21,10 +22,13 @@ import numpy as np ...@@ -21,10 +22,13 @@ import numpy as np
from paddle import fluid from paddle import fluid
from paddle.fluid.layers import collective from paddle.fluid.layers import collective
from paddle.fluid.dygraph.parallel import Env from paddle.fluid.dygraph.parallel import Env, ParallelStrategy
from paddle.fluid.io import BatchSampler from paddle.fluid.io import BatchSampler
_parallel_context_initialized = False
class DistributedBatchSampler(BatchSampler): class DistributedBatchSampler(BatchSampler):
"""Sampler that restricts data loading to a subset of the dataset. """Sampler that restricts data loading to a subset of the dataset.
...@@ -100,3 +104,97 @@ class DistributedBatchSampler(BatchSampler): ...@@ -100,3 +104,97 @@ class DistributedBatchSampler(BatchSampler):
def _all_gather(x, nranks, ring_id=0, use_calc_stream=True): def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
return collective._c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream) return collective._c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
def wait_server_ready(endpoints):
assert not isinstance(endpoints, six.string_types)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with contextlib.closing(
socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
time.sleep(3)
else:
break
def init_communicator(program, rank, nranks, wait_port,
current_endpoint, endpoints):
if nranks < 2:
return
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=fluid.unique_name.generate('nccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
})
def prepare_distributed_context(place=None):
if place is None:
place = fluid.CUDAPlace(Env().dev_id) if Env().nranks > 1 \
else fluid.CUDAPlace(0)
strategy = ParallelStrategy()
strategy.nranks = Env().nranks
strategy.local_rank = Env().local_rank
strategy.trainer_endpoints = Env().trainer_endpoints
strategy.current_endpoint = Env().current_endpoint
if strategy.nranks < 2:
return
global _parallel_context_initialized
if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace):
def _init_context():
communicator_prog = fluid.Program()
init_communicator(communicator_prog, strategy.local_rank, strategy.nranks,
True, strategy.current_endpoint, strategy.trainer_endpoints)
exe = fluid.Executor(place)
exe.run(communicator_prog)
if fluid.in_dygraph_mode():
fluid.disable_dygraph()
_init_context()
fluid.enable_dygraph(place)
else:
_init_context()
else:
assert ("Only support CUDAPlace for now.")
_parallel_context_initialized = True
return strategy
\ No newline at end of file
...@@ -116,15 +116,17 @@ def main(): ...@@ -116,15 +116,17 @@ def main():
model = MNIST() model = MNIST()
optim = Momentum( optim = Momentum(
learning_rate=FLAGS.lr, learning_rate=FLAGS.lr, momentum=.9, parameter_list=model.parameters())
momentum=.9,
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels) model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
if FLAGS.resume is not None: if FLAGS.resume is not None:
model.load(FLAGS.resume) model.load(FLAGS.resume)
model.fit(train_dataset, val_dataset, epochs=FLAGS.epoch, batch_size=FLAGS.batch_size) model.fit(train_dataset,
val_dataset,
epochs=FLAGS.epoch,
batch_size=FLAGS.batch_size,
save_dir='mnist_checkpoint')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import six import six
import warnings import warnings
from collections import Iterable, OrderedDict from collections import Iterable
from paddle import fluid from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
...@@ -32,14 +32,12 @@ from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy ...@@ -32,14 +32,12 @@ from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from distributed import DistributedBatchSampler, _all_gather from distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
from metrics import Metric from metrics import Metric
from callbacks import config_callbacks from callbacks import config_callbacks
__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input'] __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
_parallel_context_inited = False
def to_list(value): def to_list(value):
if value is None: if value is None:
...@@ -142,8 +140,12 @@ class StaticGraphAdapter(object): ...@@ -142,8 +140,12 @@ class StaticGraphAdapter(object):
self._progs = {} self._progs = {}
self._compiled_progs = {} self._compiled_progs = {}
self._merge_count = {'eval_total': 0, 'test_total': 0, self._merge_count = {
'eval_batch': 0, 'test_batch': 0} 'eval_total': 0,
'test_total': 0,
'eval_batch': 0,
'test_batch': 0
}
self._nranks = Env().nranks self._nranks = Env().nranks
self._local_rank = Env().local_rank self._local_rank = Env().local_rank
...@@ -251,7 +253,8 @@ class StaticGraphAdapter(object): ...@@ -251,7 +253,8 @@ class StaticGraphAdapter(object):
# When using static learning rate, static-graph would make it # When using static learning rate, static-graph would make it
# a persistable var named 'unique_name.generate("learning_rate")', # a persistable var named 'unique_name.generate("learning_rate")',
# However, dygraph wouldn't save it. # However, dygraph wouldn't save it.
if var.name not in state: continue if var.name not in state:
continue
else: else:
# moment and other accumulators # moment and other accumulators
if var.name not in converted_state: if var.name not in converted_state:
...@@ -357,9 +360,12 @@ class StaticGraphAdapter(object): ...@@ -357,9 +360,12 @@ class StaticGraphAdapter(object):
samples = state[0].shape[0] samples = state[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples >= total_size: if current_count + samples >= total_size:
state = [s[:total_size - current_count, ...] for s in state] state = [
s[:total_size - current_count, ...] for s in state
]
self._merge_count[self.mode + '_total'] = 0 self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count self._merge_count[self.mode +
'_batch'] = total_size - current_count
else: else:
self._merge_count[self.mode + '_total'] += samples self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples self._merge_count[self.mode + '_batch'] = samples
...@@ -397,7 +403,7 @@ class StaticGraphAdapter(object): ...@@ -397,7 +403,7 @@ class StaticGraphAdapter(object):
metrics = [] metrics = []
with fluid.program_guard(prog, self._startup_prog): with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict): if isinstance(self.model._inputs, dict):
ins = [self.model._inputs[n] \ ins = [self.model._inputs[n]
for n in extract_args(self.model.forward) if n != 'self'] for n in extract_args(self.model.forward) if n != 'self']
else: else:
ins = self.model._inputs ins = self.model._inputs
...@@ -417,7 +423,8 @@ class StaticGraphAdapter(object): ...@@ -417,7 +423,8 @@ class StaticGraphAdapter(object):
if mode != 'test': if mode != 'test':
for metric in self.model._metrics: for metric in self.model._metrics:
metrics.append(to_list(metric.add_metric_op(outputs, labels))) metrics.append(
to_list(metric.add_metric_op(outputs, labels)))
if mode == 'train' and self.model._optimizer: if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses) self._loss_endpoint = fluid.layers.sum(losses)
...@@ -427,8 +434,8 @@ class StaticGraphAdapter(object): ...@@ -427,8 +434,8 @@ class StaticGraphAdapter(object):
dist_strategy = DistributedStrategy() dist_strategy = DistributedStrategy()
dist_strategy.mode = "collective" dist_strategy.mode = "collective"
dist_strategy.collective_mode = "grad_allreduce" dist_strategy.collective_mode = "grad_allreduce"
self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer, self.model._optimizer = fleet.distributed_optimizer(
strategy=dist_strategy) self.model._optimizer, strategy=dist_strategy)
self.model._optimizer.minimize(self._loss_endpoint) self.model._optimizer.minimize(self._loss_endpoint)
...@@ -444,7 +451,6 @@ class StaticGraphAdapter(object): ...@@ -444,7 +451,6 @@ class StaticGraphAdapter(object):
"metric": metrics "metric": metrics
} }
def _compile_and_initialize(self, prog, mode): def _compile_and_initialize(self, prog, mode):
compiled_prog = self._compiled_progs.get(mode, None) compiled_prog = self._compiled_progs.get(mode, None)
if compiled_prog is not None: if compiled_prog is not None:
...@@ -464,7 +470,8 @@ class StaticGraphAdapter(object): ...@@ -464,7 +470,8 @@ class StaticGraphAdapter(object):
if self._executor is None: if self._executor is None:
if self._nranks > 1 and device.lower() == 'gpu': if self._nranks > 1 and device.lower() == 'gpu':
gpu_id = int(Env().dev_id) gpu_id = int(Env().dev_id)
place = fluid.CUDAPlace(gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace() place = fluid.CUDAPlace(
gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace()
else: else:
place = places[0] place = places[0]
self._executor = fluid.Executor(place) self._executor = fluid.Executor(place)
...@@ -484,7 +491,7 @@ class StaticGraphAdapter(object): ...@@ -484,7 +491,7 @@ class StaticGraphAdapter(object):
if self._nranks < 2: if self._nranks < 2:
compiled_prog = fluid.CompiledProgram(prog) compiled_prog = fluid.CompiledProgram(prog)
else: else:
compiled_prog = prog#fleet.main_program compiled_prog = prog
if len(places) > 1: if len(places) > 1:
loss_name = None loss_name = None
...@@ -501,8 +508,12 @@ class DynamicGraphAdapter(object): ...@@ -501,8 +508,12 @@ class DynamicGraphAdapter(object):
self.model = model self.model = model
self._nranks = Env().nranks self._nranks = Env().nranks
self._local_rank = Env().local_rank self._local_rank = Env().local_rank
self._merge_count = {'eval_total': 0, 'test_total': 0, self._merge_count = {
'eval_batch': 0, 'test_batch': 0} 'eval_total': 0,
'test_total': 0,
'eval_batch': 0,
'test_batch': 0
}
if self._nranks > 1: if self._nranks > 1:
stradegy = fluid.dygraph.parallel.ParallelStrategy() stradegy = fluid.dygraph.parallel.ParallelStrategy()
...@@ -510,7 +521,8 @@ class DynamicGraphAdapter(object): ...@@ -510,7 +521,8 @@ class DynamicGraphAdapter(object):
stradegy.local_rank = Env().local_rank stradegy.local_rank = Env().local_rank
stradegy.trainer_endpoints = Env().trainer_endpoints stradegy.trainer_endpoints = Env().trainer_endpoints
stradegy.current_endpoint = Env().current_endpoint stradegy.current_endpoint = Env().current_endpoint
self.ddp_model = fluid.dygraph.parallel.DataParallel(self.model, stradegy) self.ddp_model = fluid.dygraph.parallel.DataParallel(
self.model, stradegy)
@property @property
def mode(self): def mode(self):
...@@ -546,7 +558,8 @@ class DynamicGraphAdapter(object): ...@@ -546,7 +558,8 @@ class DynamicGraphAdapter(object):
self.model.clear_gradients() self.model.clear_gradients()
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
metric_outs = metric.add_metric_op(to_list(outputs), to_list(labels)) metric_outs = metric.add_metric_op(
to_list(outputs), to_list(labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
...@@ -576,15 +589,17 @@ class DynamicGraphAdapter(object): ...@@ -576,15 +589,17 @@ class DynamicGraphAdapter(object):
samples = outputs[0].shape[0] samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples >= total_size: if current_count + samples >= total_size:
outputs = [o[:total_size - metric.count[0]] for o in outputs] outputs = [
o[:total_size - metric.count[0]] for o in outputs
]
labels = [l[:total_size - metric.count[0]] for l in labels] labels = [l[:total_size - metric.count[0]] for l in labels]
self._merge_count[self.mode + '_total'] = 0 self._merge_count[self.mode + '_total'] = 0
self._merge_count[self.mode + '_batch'] = total_size - current_count self._merge_count[self.mode +
'_batch'] = total_size - current_count
else: else:
self._merge_count[self.mode + '_total'] += samples self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(to_list(outputs), labels) metric_outs = metric.add_metric_op(to_list(outputs), labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
...@@ -691,17 +706,16 @@ class Model(fluid.dygraph.Layer): ...@@ -691,17 +706,16 @@ class Model(fluid.dygraph.Layer):
self._place = fluid.CUDAPlace(Env().dev_id) \ self._place = fluid.CUDAPlace(Env().dev_id) \
if Env().nranks > 1 else fluid.CUDAPlace(0) if Env().nranks > 1 else fluid.CUDAPlace(0)
global _parallel_context_inited global _parallel_context_initialized
if Env().nranks > 1 and not _parallel_context_inited: if Env().nranks > 1 and not _parallel_context_initialized:
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
fluid.disable_dygraph() fluid.disable_dygraph()
fluid.enable_dygraph(self._place) fluid.enable_dygraph(self._place)
fluid.dygraph.parallel.prepare_context() fluid.dygraph.parallel.prepare_context()
else: else:
fluid.enable_dygraph(self._place) prepare_distributed_context(self._place)
fluid.dygraph.parallel.prepare_context()
fluid.disable_dygraph() _parallel_context_initialized = True
_parallel_context_inited = True
# init backend # init backend
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
...@@ -850,7 +864,8 @@ class Model(fluid.dygraph.Layer): ...@@ -850,7 +864,8 @@ class Model(fluid.dygraph.Layer):
metrics = metrics or [] metrics = metrics or []
for metric in to_list(metrics): for metric in to_list(metrics):
assert isinstance(metric, Metric), \ assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(metric.__class__.__name__) "{} is not sub class of Metric".format(
metric.__class__.__name__)
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
self._inputs = inputs self._inputs = inputs
...@@ -873,6 +888,7 @@ class Model(fluid.dygraph.Layer): ...@@ -873,6 +888,7 @@ class Model(fluid.dygraph.Layer):
epochs=1, epochs=1,
eval_freq=1, eval_freq=1,
log_freq=10, log_freq=10,
save_dir=None,
save_freq=1, save_freq=1,
verbose=2, verbose=2,
drop_last=False, drop_last=False,
...@@ -882,17 +898,24 @@ class Model(fluid.dygraph.Layer): ...@@ -882,17 +898,24 @@ class Model(fluid.dygraph.Layer):
""" """
FIXME: add more comments and usage FIXME: add more comments and usage
Args: Args:
train_loader (DataLoader): an iterable data loader is used for train. train_dataset (Dataset): An instance of paddle.fluid.io.Dataset.
eval_loader (DataLoader): an iterable data loader is used for eval_dataset (Dataset): An instance of paddle.fluid.io.Dataset.
train_loader (DataLoader): An iterable data loader is used for train.
eval_loader (DataLoader): An iterable data loader is used for
evaluation at the end of epoch. If None, will not do evaluation. evaluation at the end of epoch. If None, will not do evaluation.
epochs (int): number of epochs to train the model. epochs (int): Integer number. The number of epochs to train the model.
eval_freq (int): evaluation frequency in epoch. eval_freq (int): The frequency, in number of epochs, an evalutation
log_freq (int): frequency to print log during training. is performed.
save_freq (int): frequency to save checkpoint during training. log_freq (int): The frequency, in number of steps, the training logs
verbose (int): verbosity mode, should be 0, 1, or 2. is printed.
save_dir(str|None): The directory to save checkpoint during training.
If None, will not save checkpoint.
save_freq (int): The frequency, in number of epochs, to save checkpoint.
verbose (int): The verbosity mode, should be 0, 1, or 2.
0 = silent, 1 = progress bar, 2 = one line per epoch. 0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks (Callback|None): list of `Callback` instances to apply callbacks (Callback|None): A list of `Callback` instances to apply
during training. during training. If None, `ProgBarLogger` and `ModelCheckpoint`
are automatically inserted.
""" """
assert train_dataset is not None or train_loader is not None, \ assert train_dataset is not None or train_loader is not None, \
...@@ -908,11 +931,13 @@ class Model(fluid.dygraph.Layer): ...@@ -908,11 +931,13 @@ class Model(fluid.dygraph.Layer):
feed_list = [x.forward() for x in self._inputs + self._labels] feed_list = [x.forward() for x in self._inputs + self._labels]
if train_loader is None: if train_loader is None:
train_sampler = DistributedBatchSampler(train_dataset, train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
train_loader = DataLoader(train_dataset, train_loader = DataLoader(
train_dataset,
batch_sampler=train_sampler, batch_sampler=train_sampler,
places=self._place, places=self._place,
feed_list=feed_list, feed_list=feed_list,
...@@ -920,9 +945,10 @@ class Model(fluid.dygraph.Layer): ...@@ -920,9 +945,10 @@ class Model(fluid.dygraph.Layer):
return_list=True) return_list=True)
if eval_loader is None and eval_dataset is not None: if eval_loader is None and eval_dataset is not None:
eval_sampler = DistributedBatchSampler(eval_dataset, eval_sampler = DistributedBatchSampler(
batch_size=batch_size) eval_dataset, batch_size=batch_size)
eval_loader = DataLoader(eval_dataset, eval_loader = DataLoader(
eval_dataset,
batch_sampler=eval_sampler, batch_sampler=eval_sampler,
places=self._place, places=self._place,
feed_list=feed_list, feed_list=feed_list,
...@@ -932,18 +958,21 @@ class Model(fluid.dygraph.Layer): ...@@ -932,18 +958,21 @@ class Model(fluid.dygraph.Layer):
do_eval = eval_loader is not None do_eval = eval_loader is not None
self._test_dataloader = eval_loader self._test_dataloader = eval_loader
metrics_name = self._metrics_name() metrics_name = self._metrics_name()
steps = len(train_loader) if hasattr(train_loader, '__len__') else None
cbks = config_callbacks( cbks = config_callbacks(
callbacks, callbacks,
model=self, model=self,
epochs=epochs, epochs=epochs,
steps=None, steps=steps,
log_freq=log_freq, log_freq=log_freq,
save_freq=save_freq, save_freq=save_freq,
save_dir=save_dir,
verbose=verbose, verbose=verbose,
metrics=self._metrics_name(), ) metrics=self._metrics_name(), )
def _run_one_epoch(data_loader, callbacks, mode): def _run_one_epoch(data_loader, callbacks, mode):
size = data_loader.size if hasattr(data_loader, 'size') else None size = len(data_loader) if hasattr(data_loader,
'__len__') else None
logs = { logs = {
'steps': size, 'steps': size,
'metrics_name': metrics_name, 'metrics_name': metrics_name,
...@@ -978,7 +1007,8 @@ class Model(fluid.dygraph.Layer): ...@@ -978,7 +1007,8 @@ class Model(fluid.dygraph.Layer):
if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0: if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
logs['batch_size'] = batch_size * Env().nranks logs['batch_size'] = batch_size * Env().nranks
else: else:
logs['batch_size'] = self._adapter._merge_count[mode + '_batch'] logs['batch_size'] = self._adapter._merge_count[mode +
'_batch']
cbks.on_batch_end(mode, step, logs) cbks.on_batch_end(mode, step, logs)
self._reset_metrics() self._reset_metrics()
...@@ -1000,7 +1030,7 @@ class Model(fluid.dygraph.Layer): ...@@ -1000,7 +1030,7 @@ class Model(fluid.dygraph.Layer):
loader = eval_loader loader = eval_loader
if not isinstance(eval_loader, Iterable): if not isinstance(eval_loader, Iterable):
loader = eval_loader() loader = eval_loader()
logs = _run_one_epoch(eval_loader(), cbks, 'eval') logs = _run_one_epoch(eval_loader, cbks, 'eval')
cbks.on_end('eval', logs) cbks.on_end('eval', logs)
cbks.on_end('train', logs) cbks.on_end('train', logs)
......
...@@ -154,13 +154,15 @@ class TestModel(unittest.TestCase): ...@@ -154,13 +154,15 @@ class TestModel(unittest.TestCase):
model = MNIST() if not is_mlp else MLP() model = MNIST() if not is_mlp else MLP()
optim = fluid.optimizer.Momentum( optim = fluid.optimizer.Momentum(
learning_rate=0.01, learning_rate=0.01, momentum=.9, parameter_list=model.parameters())
momentum=.9,
parameter_list=model.parameters())
loss = CrossEntropy() if not is_mlp else MyCrossEntropy() loss = CrossEntropy() if not is_mlp else MyCrossEntropy()
model.prepare(optim, loss, Accuracy(), inputs, labels) model.prepare(optim, loss, Accuracy(), inputs, labels)
cbk = ProgBarLogger(50) cbk = ProgBarLogger(50)
model.fit(train_dataset, val_dataset, epochs=2, batch_size=batch_size, callbacks=cbk) model.fit(train_dataset,
val_dataset,
epochs=2,
batch_size=batch_size,
callbacks=cbk)
def test_fit_static(self): def test_fit_static(self):
self.fit(False) self.fit(False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册