Commit 6e59472d authored by Q qingqing01

Do not save checkpoints if save_dir is not set

Parent 839db7b1
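The gist of the change, shown as a framework-free sketch (the helper below is illustrative and not part of the repository): ModelCheckpoint now builds its target path from save_dir, and when save_dir is left as None it skips saving entirely instead of writing under a hard-coded 'output' directory.

# Illustrative only; mirrors the path format strings used by ModelCheckpoint below.
def checkpoint_path(save_dir, epoch=None):
    """Return the path a checkpoint would be written to, or None if saving is disabled."""
    if not save_dir:  # new behavior: no save_dir means no checkpoint at all
        return None
    return '{}/{}'.format(save_dir, 'final' if epoch is None else epoch)

assert checkpoint_path(None, epoch=3) is None  # nothing saved
assert checkpoint_path('mnist_checkpoint', epoch=3) == 'mnist_checkpoint/3'
assert checkpoint_path('mnist_checkpoint') == 'mnist_checkpoint/final'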
@@ -18,6 +18,7 @@ import copy
 from progressbar import ProgressBar
 from distributed import get_local_rank


 def config_callbacks(callbacks=None,
                      model=None,
                      batch_size=None,
@@ -26,6 +27,7 @@ def config_callbacks(callbacks=None,
                      log_freq=2,
                      verbose=2,
                      save_freq=1,
+                     save_dir=None,
                      metrics=None,
                      mode='train'):
     cbks = callbacks or []
@@ -34,7 +36,7 @@ def config_callbacks(callbacks=None,
         cbks = cbks + [ProgBarLogger(log_freq, verbose=verbose)]

     if not any(isinstance(k, ModelCheckpoint) for k in cbks):
-        cbks = cbks + [ModelCheckpoint(save_freq)]
+        cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]

     cbk_list = CallbackList(cbks)
     cbk_list.set_model(model)
@@ -209,9 +211,10 @@ class ProgBarLogger(Callback):
     def on_train_batch_end(self, step, logs=None):
         logs = logs or {}
-        self.train_step = step
+        self.train_step += 1

-        if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank() == 0:
+        if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank(
+        ) == 0:
             # if steps is not None, last step will update in on_epoch_end
             if self.steps and self.train_step < self.steps:
                 self._updates(logs, 'train')
@@ -247,21 +250,24 @@ class ProgBarLogger(Callback):
 class ModelCheckpoint(Callback):
-    def __init__(self, save_freq=1, save_file='output'):
+    def __init__(self, save_freq=1, save_dir=None):
         self.save_freq = save_freq
-        self.save_file = save_file
+        self.save_dir = save_dir

     def on_epoch_begin(self, epoch=None, logs=None):
         self.epoch = epoch

+    def _is_save(self):
+        return self.model and self.save_dir and get_local_rank() == 0
+
     def on_epoch_end(self, epoch, logs=None):
-        if self.model and self.epoch % self.save_freq == 0 and get_local_rank() == 0:
-            path = '{}/{}'.format(self.save_file, epoch)
+        if self._is_save() and self.epoch % self.save_freq == 0:
+            path = '{}/{}'.format(self.save_dir, epoch)
             print('save checkpoint at {}'.format(path))
             self.model.save(path)

     def on_train_end(self, logs=None):
-        if self.model and get_local_rank() == 0:
-            path = '{}/final'.format(self.save_file)
+        if self._is_save():
+            path = '{}/final'.format(self.save_dir)
             print('save checkpoint at {}'.format(path))
             self.model.save(path)
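A self-contained simulation of the schedule the rewritten ModelCheckpoint follows (DummyModel and run_training are hypothetical stand-ins, and the local rank is assumed to be 0, i.e. a single-process run): it shows which paths are written for a given save_freq, and that nothing is written when save_dir is None.

class DummyModel:
    """Stand-in that records where it would have saved."""

    def __init__(self):
        self.saved = []

    def save(self, path):
        self.saved.append(path)


def run_training(model, epochs, save_freq=1, save_dir=None, local_rank=0):
    """Replay the on_epoch_end / on_train_end schedule of ModelCheckpoint."""

    def is_save():
        return bool(model) and bool(save_dir) and local_rank == 0

    for epoch in range(epochs):
        if is_save() and epoch % save_freq == 0:
            model.save('{}/{}'.format(save_dir, epoch))
    if is_save():
        model.save('{}/final'.format(save_dir))


m = DummyModel()
run_training(m, epochs=5, save_freq=2, save_dir='mnist_checkpoint')
print(m.saved)  # ['mnist_checkpoint/0', 'mnist_checkpoint/2', 'mnist_checkpoint/4', 'mnist_checkpoint/final']

m = DummyModel()
run_training(m, epochs=5, save_freq=2, save_dir=None)
print(m.saved)  # [] -- no save_dir, so no checkpoints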
@@ -107,24 +107,26 @@ class MNIST(Model):
 def main():
     init_context('dynamic' if FLAGS.dynamic else 'static')

     train_dataset = MnistDataset(mode='train')
     val_dataset = MnistDataset(mode='test')

     inputs = [Input([None, 784], 'float32', name='image')]
     labels = [Input([None, 1], 'int64', name='label')]

     model = MNIST()
     optim = Momentum(
-        learning_rate=FLAGS.lr,
-        momentum=.9,
-        parameter_list=model.parameters())
+        learning_rate=FLAGS.lr, momentum=.9, parameter_list=model.parameters())
     model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)

     if FLAGS.resume is not None:
         model.load(FLAGS.resume)

-    model.fit(train_dataset, val_dataset, epochs=FLAGS.epoch, batch_size=FLAGS.batch_size)
+    model.fit(train_dataset,
+              val_dataset,
+              epochs=FLAGS.epoch,
+              batch_size=FLAGS.batch_size,
+              save_dir='mnist_checkpoint')


 if __name__ == '__main__':
......
@@ -38,7 +38,6 @@ from paddle.fluid.io import DataLoader
 from metrics import Metric
 from callbacks import config_callbacks

 __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
@@ -87,7 +86,7 @@ def extract_args(func):
 def init_context(backend):
     assert isinstance(backend, str) and backend.lower() in ['dynamic', 'static'], \
         "Expected backend in ['dynamic', 'static'], but got {}".format(backend)
     place = fluid.CUDAPlace(distributed.Env().dev_id) if \
         distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
@@ -155,9 +154,13 @@ class StaticGraphAdapter(object):
         self._progs = {}
         self._compiled_progs = {}
-        self._merge_count = {'eval_total': 0, 'test_total': 0,
-                             'eval_batch': 0, 'test_batch': 0}
+        self._merge_count = {
+            'eval_total': 0,
+            'test_total': 0,
+            'eval_batch': 0,
+            'test_batch': 0
+        }

         self._nranks = distributed.Env().nranks
         self._local_rank = distributed.Env().local_rank
@@ -370,9 +373,12 @@ class StaticGraphAdapter(object):
             samples = state[0].shape[0]
             current_count = self._merge_count.get(self.mode + '_total', 0)
             if current_count + samples >= total_size:
-                state = [s[:total_size - current_count, ...] for s in state]
+                state = [
+                    s[:total_size - current_count, ...] for s in state
+                ]
                 self._merge_count[self.mode + '_total'] = 0
-                self._merge_count[self.mode + '_batch'] = total_size - current_count
+                self._merge_count[self.mode +
+                                  '_batch'] = total_size - current_count
             else:
                 self._merge_count[self.mode + '_total'] += samples
                 self._merge_count[self.mode + '_batch'] = samples
@@ -405,7 +411,7 @@ class StaticGraphAdapter(object):
         # HACK workaround learning rate map issue
         lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
         self.model._optimizer._learning_rate_map[prog] = lr_var

         losses = []
         metrics = []
         with fluid.program_guard(prog, self._startup_prog):
@@ -421,16 +427,22 @@ class StaticGraphAdapter(object):
             outputs = to_list(self.model.forward(*inputs))
             if mode != 'test' and self.model._loss_function:
                 losses = self.model._loss_function(outputs, labels)

             if self._nranks > 1 and mode != 'train':
-                outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
+                outputs = [
+                    distributed._all_gather(o, self._nranks) for o in outputs
+                ]
                 if mode != 'test':
-                    labels = [distributed._all_gather(l, self._nranks) for l in labels]
+                    labels = [
+                        distributed._all_gather(l, self._nranks)
+                        for l in labels
+                    ]

             if mode != 'test':
                 for metric in self.model._metrics:
-                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))
+                    metrics.append(
+                        to_list(metric.add_metric_op(outputs, labels)))

             if mode == 'train' and self.model._optimizer:
                 self._loss_endpoint = fluid.layers.sum(losses)
@@ -440,16 +452,16 @@ class StaticGraphAdapter(object):
                 dist_strategy = DistributedStrategy()
                 dist_strategy.mode = "collective"
                 dist_strategy.collective_mode = "grad_allreduce"
-                self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer,
-                                                                    strategy=dist_strategy)
+                self.model._optimizer = fleet.distributed_optimizer(
+                    self.model._optimizer, strategy=dist_strategy)

                 self.model._optimizer.minimize(self._loss_endpoint)

         if mode != 'train':  # clone again to put it in test mode
             prog = prog.clone(for_test=True)

         self._input_vars[mode] = inputs
         self._progs[mode] = prog
         self._endpoints[mode] = {
             "output": outputs,
@@ -457,7 +469,6 @@ class StaticGraphAdapter(object):
             "metric": metrics
         }

     def _compile_and_initialize(self, prog, mode):
         compiled_prog = self._compiled_progs.get(mode, None)
         if compiled_prog is not None:
@@ -477,7 +488,8 @@ class StaticGraphAdapter(object):
         if self._executor is None:
             if self._nranks > 1 and device.lower() == 'gpu':
                 gpu_id = int(distributed.Env().dev_id)
-                place = fluid.CUDAPlace(gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace()
+                place = fluid.CUDAPlace(gpu_id) if device.lower(
+                ) == 'gpu' else fluid.CPUPlace()
             else:
                 place = places[0]
             self._executor = fluid.Executor(place)
@@ -497,7 +509,7 @@ class StaticGraphAdapter(object):
         if self._nranks < 2:
             compiled_prog = fluid.CompiledProgram(prog)
         else:
-            compiled_prog = prog#fleet.main_program
+            compiled_prog = prog  #fleet.main_program

         if len(places) > 1:
             loss_name = None
@@ -514,8 +526,12 @@ class DynamicGraphAdapter(object):
         self.model = model
         self._nranks = distributed.Env().nranks
         self._local_rank = distributed.Env().local_rank
-        self._merge_count = {'eval_total': 0, 'test_total': 0,
-                             'eval_batch': 0, 'test_batch': 0}
+        self._merge_count = {
+            'eval_total': 0,
+            'test_total': 0,
+            'eval_batch': 0,
+            'test_batch': 0
+        }

         if self._nranks > 1:
             self.ddp_model = distributed.DistributedDataParallel(self.model)
@@ -554,7 +570,8 @@ class DynamicGraphAdapter(object):
         self.model.clear_gradients()
         metrics = []
         for metric in self.model._metrics:
-            metric_outs = metric.add_metric_op(to_list(outputs), to_list(labels))
+            metric_outs = metric.add_metric_op(
+                to_list(outputs), to_list(labels))
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
@@ -573,7 +590,10 @@ class DynamicGraphAdapter(object):
         else:
             losses = []
         if self._nranks > 1:
-            outputs = [distributed._all_gather(o, self._nranks) for o in to_list(outputs)]
+            outputs = [
+                distributed._all_gather(o, self._nranks)
+                for o in to_list(outputs)
+            ]
             labels = [distributed._all_gather(l, self._nranks) for l in labels]
         metrics = []
         for metric in self.model._metrics:
@@ -584,15 +604,17 @@ class DynamicGraphAdapter(object):
                 samples = outputs[0].shape[0]
                 current_count = self._merge_count.get(self.mode + '_total', 0)
                 if current_count + samples >= total_size:
-                    outputs = [o[:total_size - metric.count[0]] for o in outputs]
+                    outputs = [
+                        o[:total_size - metric.count[0]] for o in outputs
+                    ]
                     labels = [l[:total_size - metric.count[0]] for l in labels]
                     self._merge_count[self.mode + '_total'] = 0
-                    self._merge_count[self.mode + '_batch'] = total_size - current_count
+                    self._merge_count[self.mode +
+                                      '_batch'] = total_size - current_count
                 else:
                     self._merge_count[self.mode + '_total'] += samples
                     self._merge_count[self.mode + '_batch'] = samples

             metric_outs = metric.add_metric_op(to_list(outputs), labels)
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
@@ -608,7 +630,10 @@ class DynamicGraphAdapter(object):
         inputs = [to_variable(x) for x in to_list(inputs)]
         outputs = self.model.forward(*inputs)
         if self._nranks > 2:
-            outputs = [distributed._all_gather(o, self._nranks) for o in to_list(outputs)]
+            outputs = [
+                distributed._all_gather(o, self._nranks)
+                for o in to_list(outputs)
+            ]
         return [to_numpy(o) for o in to_list(outputs)]

     def parameters(self, *args, **kwargs):
@@ -829,7 +854,7 @@ class Model(fluid.dygraph.Layer):
                 the variable to the environment variable and set its value to 1.
                 The default is None.
         """
         self._optimizer = optimizer
         if loss_function:
             if not isinstance(loss_function, Loss):
@@ -852,7 +877,7 @@ class Model(fluid.dygraph.Layer):
         self._inputs = inputs
         self._labels = labels
         self._device = device

         if device is None:
             self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
         self._device_ids = device_ids
@@ -869,6 +894,7 @@ class Model(fluid.dygraph.Layer):
             epochs=1,
             eval_freq=1,
             log_freq=10,
+            save_dir=None,
             save_freq=1,
             verbose=2,
             drop_last=False,
@@ -878,17 +904,22 @@ class Model(fluid.dygraph.Layer):
         """
         FIXME: add more comments and usage
         Args:
-            train_loader (DataLoader): an iterable data loader is used for train.
-            eval_loader (DataLoader): an iterable data loader is used for
+            train_loader (DataLoader): An iterable data loader is used for train.
+            eval_loader (DataLoader): An iterable data loader is used for
                 evaluation at the end of epoch. If None, will not do evaluation.
-            epochs (int): number of epochs to train the model.
-            eval_freq (int): evaluation frequency in epoch.
-            log_freq (int): frequency to print log during training.
-            save_freq (int): frequency to save checkpoint during training.
-            verbose (int): verbosity mode, should be 0, 1, or 2.
+            epochs (int): The number of epochs to train the model.
+            eval_freq (int): The frequency, in number of epochs, at which an
+                evaluation is performed.
+            log_freq (int): The frequency, in number of steps, at which the
+                training logs are printed.
+            save_dir (str|None): The directory to save checkpoints to during
+                training. If None, no checkpoints are saved.
+            save_freq (int): The frequency, in number of epochs, at which
+                checkpoints are saved.
+            verbose (int): The verbosity mode, should be 0, 1, or 2.
                 0 = silent, 1 = progress bar, 2 = one line per epoch.
-            callbacks (Callback|None): list of `Callback` instances to apply
-                during training.
+            callbacks (Callback|None): A list of `Callback` instances to apply
+                during training. If None, `ProgBarLogger` and `ModelCheckpoint`
+                are automatically inserted.
         """

         assert train_dataset is not None or train_loader is not None, \
@@ -904,37 +935,42 @@ class Model(fluid.dygraph.Layer):
         feed_list = [x.forward() for x in self._inputs + self._labels]

         if train_loader is None:
-            train_sampler = DistributedBatchSampler(train_dataset,
-                                                    batch_size=batch_size,
-                                                    shuffle=shuffle,
-                                                    drop_last=drop_last)
-            train_loader = DataLoader(train_dataset,
-                                      batch_sampler=train_sampler,
-                                      places=self._place,
-                                      feed_list=feed_list,
-                                      num_workers=num_workers,
-                                      return_list=True)
+            train_sampler = DistributedBatchSampler(
+                train_dataset,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                drop_last=drop_last)
+            train_loader = DataLoader(
+                train_dataset,
+                batch_sampler=train_sampler,
+                places=self._place,
+                feed_list=feed_list,
+                num_workers=num_workers,
+                return_list=True)

         if eval_loader is None and eval_dataset is not None:
-            eval_sampler = DistributedBatchSampler(eval_dataset,
-                                                   batch_size=batch_size)
-            eval_loader = DataLoader(eval_dataset,
-                                     batch_sampler=eval_sampler,
-                                     places=self._place,
-                                     feed_list=feed_list,
-                                     num_workers=num_workers,
-                                     return_list=True)
+            eval_sampler = DistributedBatchSampler(
+                eval_dataset, batch_size=batch_size)
+            eval_loader = DataLoader(
+                eval_dataset,
+                batch_sampler=eval_sampler,
+                places=self._place,
+                feed_list=feed_list,
+                num_workers=num_workers,
+                return_list=True)

         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader
         metrics_name = self._metrics_name()
+        steps = len(train_loader) if hasattr(train_loader, '__len__') else None
         cbks = config_callbacks(
             callbacks,
             model=self,
             epochs=epochs,
-            steps=None,
+            steps=steps,
             log_freq=log_freq,
             save_freq=save_freq,
+            save_dir=save_dir,
             verbose=verbose,
             metrics=self._metrics_name(), )
@@ -965,16 +1001,18 @@ class Model(fluid.dygraph.Layer):
             for metric in self._metrics:
                 res = metric.accumulate()
                 metrics.extend(to_list(res))

             assert len(metrics_name) == len(metrics)
             for k, v in zip(metrics_name, metrics):
                 logs[k] = v

             logs['step'] = step
-            if mode == 'train' or self._adapter._merge_count.get(mode + '_batch', 0) <= 0:
+            if mode == 'train' or self._adapter._merge_count.get(
+                    mode + '_batch', 0) <= 0:
                 logs['batch_size'] = batch_size * distributed.Env().nranks
             else:
-                logs['batch_size'] = self._adapter._merge_count[mode + '_batch']
+                logs['batch_size'] = self._adapter._merge_count[mode +
+                                                                '_batch']

             cbks.on_batch_end(mode, step, logs)
         self._reset_metrics()
......
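One more change in Model.fit above is worth calling out: a steps value is now derived from the training loader and passed to config_callbacks, so ProgBarLogger knows the last step of each epoch. Below is a small sketch of that detection; the loaders are plain Python stand-ins, not Paddle DataLoader objects.

def infer_steps(loader):
    """Mirror the fit() logic: use len() when the loader is sized, otherwise None."""
    return len(loader) if hasattr(loader, '__len__') else None


batched = [('x0', 'y0'), ('x1', 'y1'), ('x2', 'y2')]  # sized loader -> 3 steps per epoch
streaming = (batch for batch in batched)              # generator has no __len__

print(infer_steps(batched))    # 3
print(infer_steps(streaming))  # None -> the last step is updated in on_epoch_end instead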
@@ -151,16 +151,18 @@ class TestModel(unittest.TestCase):
         train_dataset = MnistDataset(mode='train')
         val_dataset = MnistDataset(mode='test')

         model = MNIST() if not is_mlp else MLP()
         optim = fluid.optimizer.Momentum(
-            learning_rate=0.01,
-            momentum=.9,
-            parameter_list=model.parameters())
+            learning_rate=0.01, momentum=.9, parameter_list=model.parameters())
         loss = CrossEntropy() if not is_mlp else MyCrossEntropy()
         model.prepare(optim, loss, Accuracy(), inputs, labels)
         cbk = ProgBarLogger(50)
-        model.fit(train_dataset, val_dataset, epochs=2, batch_size=batch_size, callbacks=cbk)
+        model.fit(train_dataset,
+                  val_dataset,
+                  epochs=2,
+                  batch_size=batch_size,
+                  callbacks=cbk)

     def test_fit_static(self):
         self.fit(False)
......