Commit 94bb6cdc authored by M michaelowenliu

using callbacks in core/train

Parent 19872341

paddleseg/core/train.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.io import DistributedBatchSampler
import paddle.nn.functional as F
import paddleseg.utils.logger as logger
from paddleseg.utils import load_pretrained_model
from paddleseg.utils import resume
from paddleseg.utils import Timer, calculate_eta
from paddleseg.core.val import evaluate
from paddleseg.cvlibs import callbacks
def check_logits_losses(logits, losses):
len_logits = len(logits)
len_losses = len(losses['types'])
if len_logits != len_losses:
        raise RuntimeError(
            'The length of logits should equal the number of loss types in '
            'the loss config: {} != {}.'.format(len_logits, len_losses))
def loss_computation(logits, label, losses):
check_logits_losses(logits, losses)
loss = 0
for i in range(len(logits)):
logit = logits[i]
if logit.shape[-2:] != label.shape[-2:]:
logit = F.resize_bilinear(logit, label.shape[-2:])
loss_i = losses['types'][i](logit, label)
loss += losses['coef'][i] * loss_i
return loss
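
# Illustrative example (not part of this commit): for a model whose forward
# returns a main logit plus one auxiliary logit, a matching `losses` config
# pairs one loss callable per logit with a weighting coefficient. The names
# below are hypothetical (logit, label) callables:
#
#   losses = {
#       'types': [main_loss_fn, aux_loss_fn],
#       'coef': [1.0, 0.4],
#   }
#
# loss_computation then returns 1.0 * main_loss + 0.4 * aux_loss.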
def seg_train(model,
train_dataset,
places=None,
val_dataset=None,
losses=None,
optimizer=None,
save_dir='output',
iters=10000,
batch_size=2,
resume_model=None,
save_interval_iters=1000,
log_iters=10,
num_workers=8):
nranks = ParallelEnv().nranks
start_iter = 0
if resume_model is not None:
start_iter = resume(model, optimizer, resume_model)
if nranks > 1:
strategy = fluid.dygraph.prepare_context()
ddp_model = fluid.dygraph.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
places=places,
num_workers=num_workers,
return_list=True,
)
out_labels = ["loss", "reader_cost", "batch_cost"]
base_logger = callbacks.BaseLogger(period=log_iters)
train_logger = callbacks.TrainLogger(log_freq=log_iters)
    model_ckpt = callbacks.ModelCheckpoint(
        save_dir, save_params_only=False, period=save_interval_iters)
vdl = callbacks.VisualDL(log_dir=os.path.join(save_dir, "log"))
cbks_list = [base_logger, train_logger, model_ckpt, vdl]
cbks = callbacks.CallbackList(cbks_list)
cbks.set_model(model)
cbks.set_optimizer(optimizer)
cbks.set_params({
"batch_size": batch_size,
"total_iters": iters,
"log_iters": log_iters,
"verbose": 1,
"do_validation": True,
"metrics": out_labels,
"iters_per_epoch": len(batch_sampler)
})
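    # The shared params dict is how callbacks learn about the run; each one
    # reads only the keys it needs (TrainLogger: "total_iters" and
    # "iters_per_epoch"; ModelCheckpoint: "verbose"; ProgbarLogger: "metrics").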
    logs = {key: 0.0 for key in out_labels}
timer = Timer()
timer.start()
############## 1 ################
cbks.on_train_begin(logs)
#################################
iter = start_iter
while iter < iters:
for data in loader:
iter += 1
if iter > iters:
break
logs["reader_cost"] = timer.elapsed_time()
############## 2 ################
cbks.on_iter_begin(iter, logs)
#################################
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
logits = ddp_model(images)
loss = loss_computation(logits, labels, losses)
                # apply_collective_grads sums gradients over multiple GPUs.
loss = ddp_model.scale_loss(loss)
loss.backward()
ddp_model.apply_collective_grads()
else:
logits = model(images)
loss = loss_computation(logits, labels, losses)
loss.backward()
optimizer.step()
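            # Step the LR scheduler once per iteration; this Paddle version
            # exposes it only via the optimizer's private _learning_rate.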
optimizer._learning_rate.step()
model.clear_gradients()
logs['loss'] = loss.numpy()[0]
logs["batch_cost"] = timer.elapsed_time()
############## 3 ################
cbks.on_iter_end(iter, logs)
#################################
timer.restart()
############### 4 ###############
cbks.on_train_end(logs)
#################################
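
# A minimal sketch of how seg_train might be driven (all names below are
# placeholders, not part of this commit):
#
#   model = MySegModel()              # nn.Layer returning a list of logits
#   dataset = MyTrainDataset()        # yields (image, label) pairs
#   losses = {'types': [my_loss_fn], 'coef': [1.0]}
#   optimizer = paddle.optimizer.SGD(learning_rate=0.01,
#                                    parameters=model.parameters())
#   seg_train(model, dataset, losses=losses, optimizer=optimizer,
#             iters=1000, batch_size=2, save_dir='output')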
paddleseg/cvlibs/callbacks.py
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import numpy as np
import paddle
from paddle.distributed.parallel import ParallelEnv
from visualdl import LogWriter
from paddleseg.utils.progbar import Progbar
import paddleseg.utils.logger as logger
class CallbackList(object):
"""Container abstracting a list of callbacks.
# Arguments
callbacks: List of `Callback` instances.
"""
def __init__(self, callbacks=None):
callbacks = callbacks or []
self.callbacks = [c for c in callbacks]
def append(self, callback):
self.callbacks.append(callback)
def set_params(self, params):
for callback in self.callbacks:
callback.set_params(params)
def set_model(self, model):
for callback in self.callbacks:
callback.set_model(model)
def set_optimizer(self, optimizer):
for callback in self.callbacks:
callback.set_optimizer(optimizer)
def on_iter_begin(self, iter, logs=None):
"""Called right before processing a batch.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_iter_begin(iter, logs)
self._t_enter_iter = time.time()
def on_iter_end(self, iter, logs=None):
"""Called at the end of a batch.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_iter_end(iter, logs)
self._t_exit_iter = time.time()
def on_train_begin(self, logs=None):
"""Called at the beginning of training.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_train_begin(logs)
def on_train_end(self, logs=None):
"""Called at the end of training.
"""
logs = logs or {}
for callback in self.callbacks:
callback.on_train_end(logs)
def __iter__(self):
return iter(self.callbacks)
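
# Usage sketch (illustrative; the __main__ block at the bottom of this file
# is a runnable smoke test):
#
#   cbks = CallbackList([BaseLogger(), TrainLogger()])
#   cbks.set_model(model)          # `model`: any paddle.nn.Layer (placeholder)
#   cbks.set_optimizer(optimizer)  # `optimizer`: a paddle optimizer (placeholder)
#   cbks.on_train_begin()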
class Callback(object):
"""Abstract base class used to build new callbacks.
"""
def __init__(self):
self.validation_data = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
def set_optimizer(self, optimizer):
self.optimizer = optimizer
def on_iter_begin(self, iter, logs=None):
pass
def on_iter_end(self, iter, logs=None):
pass
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
class BaseLogger(Callback):
def __init__(self, period=10):
super(BaseLogger, self).__init__()
self.period = period
def _reset(self):
self.totals = {}
def on_train_begin(self, logs=None):
self.totals = {}
def on_iter_end(self, iter, logs=None):
logs = logs or {}
for k, v in logs.items():
if k in self.totals.keys():
self.totals[k] += v
else:
self.totals[k] = v
if iter % self.period == 0 and ParallelEnv().local_rank == 0:
for k in self.totals:
logs[k] = self.totals[k] / self.period
self._reset()
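
# Note: BaseLogger accumulates each metric and, every `period` iterations,
# writes the period-average back into `logs` before resetting, so callbacks
# that run after it (TrainLogger, VisualDL) see averaged values.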
class TrainLogger(Callback):
    def __init__(self, log_freq=10):
        super(TrainLogger, self).__init__()
        self.log_freq = log_freq
def _calculate_eta(self, remaining_iters, speed):
if remaining_iters < 0:
remaining_iters = 0
remaining_time = int(remaining_iters * speed)
result = "{:0>2}:{:0>2}:{:0>2}"
arr = []
for i in range(2, -1, -1):
arr.append(int(remaining_time / 60**i))
remaining_time %= 60**i
return result.format(*arr)
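    # Example: _calculate_eta(100, 0.5) -> 50 s remaining -> "00:00:50".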
def on_iter_end(self, iter, logs=None):
if iter % self.log_freq == 0 and ParallelEnv().local_rank == 0:
total_iters = self.params["total_iters"]
iters_per_epoch = self.params["iters_per_epoch"]
remaining_iters = total_iters - iter
eta = self._calculate_eta(remaining_iters, logs["batch_cost"])
current_epoch = (iter - 1) // self.params["iters_per_epoch"] + 1
loss = logs["loss"]
lr = self.optimizer.get_lr()
batch_cost = logs["batch_cost"]
reader_cost = logs["reader_cost"]
logger.info(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}".
format(current_epoch, iter, total_iters,
loss, lr, batch_cost, reader_cost, eta))
class ProgbarLogger(Callback):
def __init__(self):
super(ProgbarLogger, self).__init__()
def on_train_begin(self, logs=None):
self.verbose = self.params["verbose"]
self.total_iters = self.params["total_iters"]
self.target = self.params["total_iters"]
self.progbar = Progbar(target=self.target, verbose=self.verbose)
self.seen = 0
self.log_values = []
def on_iter_begin(self, iter, logs=None):
if self.seen < self.target:
self.log_values = []
def on_iter_end(self, iter, logs=None):
logs = logs or {}
self.seen += 1
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if self.seen < self.target:
self.progbar.update(self.seen, self.log_values)
class ModelCheckpoint(Callback):
def __init__(self, save_dir, monitor="miou",
save_best_only=False, save_params_only=True,
mode="max", period=1):
super(ModelCheckpoint, self).__init__()
self.monitor = monitor
self.save_dir = save_dir
self.save_best_only = save_best_only
self.save_params_only = save_params_only
self.period = period
self.iters_since_last_save = 0
if mode == "min":
self.monitor_op = np.less
self.best = np.Inf
elif mode == "max":
self.monitor_op = np.greater
self.best = -np.Inf
else:
raise RuntimeError("mode is not either \"min\" or \"max\"!")
def on_train_begin(self, logs=None):
self.verbose = self.params["verbose"]
save_dir = self.save_dir
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
def on_iter_end(self, iter, logs=None):
logs = logs or {}
self.iters_since_last_save += 1
current_save_dir = os.path.join(self.save_dir, "iter_{}".format(iter))
current_save_dir = os.path.abspath(current_save_dir)
if iter % self.period == 0 and ParallelEnv().local_rank == 0:
if self.verbose > 0:
print("iter {iter_num}: saving model to {path}".format(
iter_num=iter, path=current_save_dir))
filepath = os.path.join(current_save_dir, 'model')
paddle.save(self.model.state_dict(), filepath)
            if not self.save_params_only:
                # NOTE: assumes paddle.save in this Paddle version derives
                # distinct .pdparams/.pdopt filenames from the object type;
                # with a plain path this would overwrite the model file above.
                paddle.save(self.optimizer.state_dict(), filepath)
class VisualDL(Callback):
def __init__(self, log_dir="./log", freq=1):
super(VisualDL, self).__init__()
self.log_dir = log_dir
self.freq = freq
def on_train_begin(self, logs=None):
self.writer = LogWriter(self.log_dir)
def on_iter_end(self, iter, logs=None):
logs = logs or {}
if iter % self.freq == 0 and ParallelEnv().local_rank == 0:
for k, v in logs.items():
self.writer.add_scalar("Train/{}".format(k), v, iter)
self.writer.flush()
def on_train_end(self, logs=None):
self.writer.close()
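
# Scalars are tagged "Train/<metric>" for every key in `logs` (e.g.
# Train/loss, Train/batch_cost) and can be viewed with:
#   visualdl --logdir <save_dir>/log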
if __name__ == "__main__":
paddle.disable_static()
x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
linear = paddle.nn.Linear(10, 10)
scheduler = paddle.optimizer.lr_scheduler.PolynomialLR(learning_rate=0.5, decay_steps=20)
sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
    callbacks1 = ModelCheckpoint(
        save_dir="/mnt/liuyi22/PaddlePaddle/PaddleSeg/dygraph/1", period=10)
    callback_list = CallbackList([callbacks1])
    callback_list.set_model(linear)
    callback_list.set_optimizer(sgd)
    # ModelCheckpoint reads `verbose` from the shared params and creates the
    # save directory in on_train_begin, so both calls must precede the loop.
    callback_list.set_params({"verbose": 1})
    callback_list.on_train_begin()
    for iter in range(1, 101):
        callback_list.on_iter_end(iter)
paddleseg/utils/progbar.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import numpy as np
class Progbar(object):
"""Displays a progress bar.
    Adapted from https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py
Arguments:
target: Total number of steps expected, None if unknown.
width: Progress bar width on screen.
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics: Iterable of string names of metrics that should *not* be
averaged over time. Metrics in this list will be displayed as-is. All
others will be averaged by the progbar before display.
interval: Minimum visual progress update interval (in seconds).
unit_name: Display name for step counts (usually "step" or "sample").
"""
def __init__(self,
target,
width=30,
verbose=1,
interval=0.05,
stateful_metrics=None,
unit_name='step'):
self.target = target
self.width = width
self.verbose = verbose
self.interval = interval
self.unit_name = unit_name
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
sys.stdout.isatty()) or
'ipykernel' in sys.modules or
'posix' in sys.modules or
'PYCHARM_HOSTED' in os.environ)
self._total_width = 0
self._seen_so_far = 0
# We use a dict + list to avoid garbage collection
# issues found in OrderedDict
self._values = {}
self._values_order = []
self._start = time.time()
self._last_update = 0
def update(self, current, values=None, finalize=None):
"""Updates the progress bar.
Arguments:
current: Index of current step.
values: List of tuples: `(name, value_for_last_step)`. If `name` is in
`stateful_metrics`, `value_for_last_step` will be displayed as-is.
Else, an average of the metric over time will be displayed.
finalize: Whether this is the last update for the progress bar. If
`None`, defaults to `current >= self.target`.
"""
if finalize is None:
if self.target is None:
finalize = False
else:
finalize = current >= self.target
values = values or []
for k, v in values:
if k not in self._values_order:
self._values_order.append(k)
if k not in self.stateful_metrics:
# In the case that progress bar doesn't have a target value in the first
# epoch, both on_batch_end and on_epoch_end will be called, which will
# cause 'current' and 'self._seen_so_far' to have the same value. Force
# the minimal value to 1 here, otherwise stateful_metric will be 0s.
value_base = max(current - self._seen_so_far, 1)
if k not in self._values:
self._values[k] = [v * value_base, value_base]
else:
self._values[k][0] += v * value_base
self._values[k][1] += value_base
else:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self._values[k] = [v, 1]
self._seen_so_far = current
now = time.time()
info = ' - %.0fs' % (now - self._start)
if self.verbose == 1:
if now - self._last_update < self.interval and not finalize:
return
prev_total_width = self._total_width
if self._dynamic_display:
sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
if self.target is not None:
numdigits = int(np.log10(self.target)) + 1
bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
prog = float(current) / self.target
prog_width = int(self.width * prog)
if prog_width > 0:
bar += ('=' * (prog_width - 1))
if current < self.target:
bar += '>'
else:
bar += '='
bar += ('.' * (self.width - prog_width))
bar += ']'
else:
bar = '%7d/Unknown' % current
self._total_width = len(bar)
sys.stdout.write(bar)
if current:
time_per_unit = (now - self._start) / current
else:
time_per_unit = 0
if self.target is None or finalize:
if time_per_unit >= 1 or time_per_unit == 0:
info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
elif time_per_unit >= 1e-3:
info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
else:
info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
else:
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600,
(eta % 3600) // 60, eta % 60)
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
for k in self._values_order:
info += ' - %s:' % k
if isinstance(self._values[k], list):
avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
else:
info += ' %s' % self._values[k]
self._total_width += len(info)
if prev_total_width > self._total_width:
info += (' ' * (prev_total_width - self._total_width))
if finalize:
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
elif self.verbose == 2:
if finalize:
numdigits = int(np.log10(self.target)) + 1
count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
info = count + info
for k in self._values_order:
info += ' - %s:' % k
avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
self._last_update = now
def add(self, n, values=None):
self.update(self._seen_so_far + n, values)
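
# Minimal usage sketch (illustrative):
#
#   bar = Progbar(target=100)
#   for step in range(1, 101):
#       bar.update(step, values=[('loss', 0.1)])
#
# add(n, values) advances the bar by a relative count rather than an
# absolute step index.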