Unverified commit 2e6dfa44, authored by littletomatodonkey, committed by GitHub

fix logger (#840)

* fix logger
* fix trainer for int64 on windows
Parent: e4c4ec76
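Note on the second bullet: on Windows, NumPy's default integer type is 32-bit, so labels coming out of the dataloader can arrive as int32 while Paddle's loss and metric ops expect int64 indices. The trainer hunks below therefore cast labels on the Paddle tensor itself instead of round-tripping through NumPy. A minimal sketch of the pattern (standalone, not part of the diff):

import numpy as np
import paddle

# On Windows, np.array([0, 1, 2]) defaults to int32 (C long); on Linux it is
# int64. Casting on the tensor side makes the label dtype platform-independent.
labels = paddle.to_tensor(np.array([0, 1, 2]))
labels = labels.reshape([-1, 1]).astype("int64")
assert labels.dtype == paddle.int64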
@@ -54,13 +54,8 @@ def create_operators(params):

 def build_dataloader(config, mode, device, seed=None):
-    assert mode in [
-        'Train',
-        'Eval',
-        'Test',
-        'Gallery',
-        'Query'
-    ], "Mode should be Train, Eval, Test, Gallery, Query"
+    assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
+                    ], "Mode should be Train, Eval, Test, Gallery, Query"
     # build dataset
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
@@ -72,7 +67,7 @@ def build_dataloader(config, mode, device, seed=None):
     dataset = eval(dataset_name)(**config_dataset)

-    logger.info("build dataset({}) success...".format(dataset))
+    logger.debug("build dataset({}) success...".format(dataset))

     # build sampler
     config_sampler = config[mode]['sampler']
@@ -85,7 +80,7 @@ def build_dataloader(config, mode, device, seed=None):
         sampler_name = config_sampler.pop("name")
         batch_sampler = eval(sampler_name)(dataset, **config_sampler)

-    logger.info("build batch_sampler({}) success...".format(batch_sampler))
+    logger.debug("build batch_sampler({}) success...".format(batch_sampler))

     # build batch operator
     def mix_collate_fn(batch):
@@ -132,5 +127,5 @@ def build_dataloader(config, mode, device, seed=None):
         batch_sampler=batch_sampler,
         collate_fn=batch_collate_fn)

-    logger.info("build data_loader({}) success...".format(data_loader))
+    logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
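All of the per-component "build ... success" messages are demoted from INFO to DEBUG, so a default run no longer prints one line per dataset, sampler, and dataloader. To get them back, a lower level can be passed to the new init_logger (a sketch, using the signature introduced in the logger.py hunk further down):

import logging
from ppcls.utils.logger import init_logger

# DEBUG re-enables the "build ... success" messages this commit demotes.
init_logger(name='root', log_file=None, log_level=logging.DEBUG)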
@@ -30,6 +30,8 @@ import paddle.distributed as dist
 from ppcls.utils.check import check_gpu
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
+from ppcls.utils.logger import init_logger
+from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model
 from ppcls.loss import build_loss
@@ -49,6 +51,11 @@ class Trainer(object):
         self.mode = mode
         self.config = config
         self.output_dir = self.config['Global']['output_dir']
+
+        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
+                                f"{mode}.log")
+        init_logger(name='root', log_file=log_file)
+        print_config(config)
         # set device
         assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
         self.device = paddle.set_device(self.config["Global"]["device"])
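The Trainer now owns logger setup, so every run also writes a per-mode log file next to its checkpoints. A worked example of the path construction above (the output_dir and Arch.name values are illustrative, not from this diff):

import os

output_dir = "output"   # Global.output_dir (example value)
arch_name = "ResNet50"  # Arch.name (example value)
mode = "train"
# Mirrors the log_file expression in Trainer.__init__ above.
print(os.path.join(output_dir, arch_name, f"{mode}.log"))  # output/ResNet50/train.log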
@@ -153,8 +160,8 @@ class Trainer(object):
                     time_info[key].reset()
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
-            batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
-                                        .reshape([-1, 1]))
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             global_step += 1
             # image input
             if not self.is_rec:
@@ -206,8 +213,9 @@ class Trainer(object):
                 eta_msg = "eta: {:s}".format(
                     str(datetime.timedelta(seconds=int(eta_sec))))
                 logger.info(
-                    "[Train][Epoch {}][Iter: {}/{}]{}, {}, {}, {}, {}".
-                    format(epoch_id, iter_id,
+                    "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
+                    format(epoch_id, self.config["Global"][
+                        "epochs"], iter_id,
                            len(self.train_dataloader), lr_msg, metric_msg,
                            time_msg, ips_msg, eta_msg))
             tic = time.time()
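The progress line now carries the total epoch budget alongside the current epoch, which makes logs easier to scan. An illustrative rendering of the new format (all numbers made up):

msg = "[Train][Epoch {}/{}][Iter: {}/{}]{}".format(1, 120, 10, 391,
                                                   "loss: 6.81")
print(msg)  # [Train][Epoch 1/120][Iter: 10/391]loss: 6.81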
@@ -216,8 +224,8 @@ class Trainer(object):
                 "{}: {:.5f}".format(key, output_info[key].avg)
                 for key in output_info
             ])
-            logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
-                                                          metric_msg))
+            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
+                epoch_id, self.config["Global"]["epochs"], metric_msg))
             output_info.clear()

             # eval model and save model if possible
@@ -327,7 +335,7 @@ class Trainer(object):
             time_info["reader_cost"].update(time.time() - tic)
             batch_size = batch[0].shape[0]
             batch[0] = paddle.to_tensor(batch[0]).astype("float32")
-            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             # image input
             if self.is_rec:
                 out = self.model(batch[0], batch[1])
@@ -438,9 +446,11 @@ class Trainer(object):
                 for key in metric_tmp:
                     if key not in metric_dict:
-                        metric_dict[key] = metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
+                            0] / len(query_feas)
                     else:
-                        metric_dict[key] += metric_tmp[key] * block_fea.shape[0] / len(query_feas)
+                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
+                            0] / len(query_feas)

         metric_info_list = []
         for key in metric_dict:
@@ -467,10 +477,10 @@ class Trainer(object):
         for idx, batch in enumerate(dataloader(
         )):  # load is very time-consuming
             batch = [paddle.to_tensor(x) for x in batch]
-            batch[1] = batch[1].reshape([-1, 1])
+            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             if len(batch) == 3:
                 has_unique_id = True
-                batch[2] = batch[2].reshape([-1, 1])
+                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
             out = self.model(batch[0], batch[1])
             batch_feas = out["features"]
@@ -52,5 +52,5 @@ class CombinedLoss(nn.Layer):

 def build_loss(config):
     module_class = CombinedLoss(copy.deepcopy(config))
-    logger.info("build loss {} success.".format(module_class))
+    logger.debug("build loss {} success.".format(module_class))
     return module_class
@@ -45,7 +45,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
-    logger.info("build lr ({}) success..".format(lr))
+    logger.debug("build lr ({}) success..".format(lr))
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
@@ -53,7 +53,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
         reg = getattr(paddle.regularizer, reg_name)(**reg_config)
     else:
         reg = None
-    logger.info("build regularizer ({}) success..".format(reg))
+    logger.debug("build regularizer ({}) success..".format(reg))
     # step3 build optimizer
     optim_name = config.pop('name')
     if 'clip_norm' in config:
@@ -65,5 +65,5 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
                   weight_decay=reg,
                   grad_clip=grad_clip,
                   **config)(parameters=parameters)
-    logger.info("build optimizer ({}) success..".format(optim))
+    logger.debug("build optimizer ({}) success..".format(optim))
     return optim, lr
@@ -67,18 +67,14 @@ def print_dict(d, delimiter=0):
     placeholder = "-" * 60
     for k, v in sorted(d.items()):
         if isinstance(v, dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(k, "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             print_dict(v, delimiter + 4)
         elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
-            logger.info("{}{} : ".format(delimiter * " ",
-                                         logger.coloring(str(k), "HEADER")))
+            logger.info("{}{} : ".format(delimiter * " ", k))
             for value in v:
                 print_dict(value, delimiter + 4)
         else:
-            logger.info("{}{} : {}".format(delimiter * " ",
-                                           logger.coloring(k, "HEADER"),
-                                           logger.coloring(v, "OKGREEN")))
+            logger.info("{}{} : {}".format(delimiter * " ", k, v))

         if k.isupper():
             logger.info(placeholder)
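With coloring gone, print_dict emits plain text through the module-level logger, which now has to be initialized first. A small sketch of what it produces (timestamp prefixes elided; assumes print_dict lives in ppcls.utils.config, as the imports in this diff indicate):

import logging
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_dict

init_logger(name='root')  # print_dict logs via ppcls.utils.logger
print_dict({"Global": {"device": "gpu", "epochs": 120}})
# root INFO: Global :
# root INFO:     device : gpu
# root INFO:     epochs : 120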
@@ -175,7 +171,7 @@ def override_config(config, options=None):
     return config


-def get_config(fname, overrides=None, show=True):
+def get_config(fname, overrides=None, show=False):
     """
     Read config from file
     """
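Note that get_config's show default flips to False because Trainer.__init__ now calls print_config itself (see the trainer hunk above); printing at both sites would dump the config twice per run. Callers can still opt in explicitly; a sketch (the YAML path is an example, not from this diff):

from ppcls.utils import config

# show=False (the new default) leaves the single config dump to Trainer.
cfg = config.get_config("ppcls/configs/example.yaml", overrides=None, show=False)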
@@ -12,70 +12,86 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
-import datetime
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s %(levelname)s: %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S")
-
-
-def time_zone(sec, fmt):
-    real_time = datetime.datetime.now()
-    return real_time.timetuple()
-
-
-logging.Formatter.converter = time_zone
-_logger = logging.getLogger(__name__)
-
-Color = {
-    'RED': '\033[31m',
-    'HEADER': '\033[35m',  # deep purple
-    'PURPLE': '\033[95m',  # purple
-    'OKBLUE': '\033[94m',
-    'OKGREEN': '\033[92m',
-    'WARNING': '\033[93m',
-    'FAIL': '\033[91m',
-    'ENDC': '\033[0m'
-}
-
-
-def coloring(message, color="OKGREEN"):
-    assert color in Color.keys()
-    if os.environ.get('PADDLECLAS_COLORING', False):
-        return Color[color] + str(message) + Color["ENDC"]
-    else:
-        return message
+import sys
+import logging
+import datetime
+
+import paddle.distributed as dist
+
+_logger = None
+
+
+def init_logger(name='root', log_file=None, log_level=logging.INFO):
+    """Initialize and get a logger by name.
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified a FileHandler will also be added.
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected, and other processes will set the level to
+            "Error" thus be silent most of the time.
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    global _logger
+    assert _logger is None, "logger should not be initialized twice or more."
+    _logger = logging.getLogger(name)
+
+    formatter = logging.Formatter(
+        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S")
+
+    stream_handler = logging.StreamHandler(stream=sys.stdout)
+    stream_handler.setFormatter(formatter)
+    _logger.addHandler(stream_handler)
+    if log_file is not None and dist.get_rank() == 0:
+        log_file_folder = os.path.split(log_file)[0]
+        os.makedirs(log_file_folder, exist_ok=True)
+        file_handler = logging.FileHandler(log_file, 'a')
+        file_handler.setFormatter(formatter)
+        _logger.addHandler(file_handler)
+    if dist.get_rank() == 0:
+        _logger.setLevel(log_level)
+    else:
+        _logger.setLevel(logging.ERROR)


-def anti_fleet(log):
+def log_at_trainer0(log):
     """
     logs will print multi-times when calling Fleet API.
     Only display single log and ignore the others.
     """

     def wrapper(fmt, *args):
-        if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
+        if dist.get_rank() == 0:
             log(fmt, *args)

     return wrapper


-@anti_fleet
+@log_at_trainer0
 def info(fmt, *args):
     _logger.info(fmt, *args)


-@anti_fleet
+@log_at_trainer0
+def debug(fmt, *args):
+    _logger.debug(fmt, *args)
+
+
+@log_at_trainer0
 def warning(fmt, *args):
-    _logger.warning(coloring(fmt, "RED"), *args)
+    _logger.warning(fmt, *args)


-@anti_fleet
+@log_at_trainer0
 def error(fmt, *args):
-    _logger.error(coloring(fmt, "FAIL"), *args)
+    _logger.error(fmt, *args)


 def scaler(name, value, step, writer):
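Taken together, the rewritten module is initialize-once, rank-aware logging: a StreamHandler is always attached, a FileHandler is added on rank 0 when log_file is given, and non-zero ranks are silenced at ERROR level. A rough usage sketch (the log file path is an example):

import logging
from ppcls.utils import logger
from ppcls.utils.logger import init_logger

# Call exactly once per process; init_logger asserts on re-initialization.
init_logger(name='root', log_file="output/example/train.log",
            log_level=logging.INFO)

logger.info("printed on rank 0 and written to the log file")
logger.debug("hidden unless init_logger was given logging.DEBUG")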
@@ -108,13 +124,12 @@ def advertise():
     website = "https://github.com/PaddlePaddle/PaddleClas"
     AD_LEN = 6 + len(max([copyright, ad, website], key=len))

-    info(
-        coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
-            "=" * (AD_LEN + 4),
-            "=={}==".format(copyright.center(AD_LEN)),
-            "=" * (AD_LEN + 4),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(ad.center(AD_LEN)),
-            "=={}==".format(' ' * AD_LEN),
-            "=={}==".format(website.center(AD_LEN)),
-            "=" * (AD_LEN + 4), ), "RED"))
+    info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
+        "=" * (AD_LEN + 4),
+        "=={}==".format(copyright.center(AD_LEN)),
+        "=" * (AD_LEN + 4),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(ad.center(AD_LEN)),
+        "=={}==".format(' ' * AD_LEN),
+        "=={}==".format(website.center(AD_LEN)),
+        "=" * (AD_LEN + 4), ))
@@ -115,19 +115,6 @@ def init_model(config, net, optimizer=None):
             pretrained_model), "HEADER"))


-def _save_student_model(net, model_prefix):
-    """
-    save student model if the net is the network contains student
-    """
-    student_model_prefix = model_prefix + "_student.pdparams"
-    if hasattr(net, "_layers"):
-        net = net._layers
-    if hasattr(net, "student"):
-        paddle.save(net.student.state_dict(), student_model_prefix)
-        logger.info("Already save student model in {}".format(
-            student_model_prefix))
-
-
 def save_model(net,
                optimizer,
                metric_info,
@@ -141,11 +128,9 @@ def save_model(net,
         return
     model_path = os.path.join(model_path, model_name)
     _mkdir_if_not_exist(model_path)
-    model_prefix = os.path.join(model_path, prefix)
-
-    _save_student_model(net, model_prefix)
+    model_path = os.path.join(model_path, prefix)

-    paddle.save(net.state_dict(), model_prefix + ".pdparams")
-    paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
-    paddle.save(metric_info, model_prefix + ".pdstates")
+    paddle.save(net.state_dict(), model_path + ".pdparams")
+    paddle.save(optimizer.state_dict(), model_path + ".pdopt")
+    paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="eval")
     trainer.eval()
@@ -25,7 +25,8 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="infer")
     trainer.infer()
@@ -25,6 +25,7 @@ from ppcls.engine.trainer import Trainer

 if __name__ == "__main__":
     args = config.parse_args()
-    config = config.get_config(args.config, overrides=args.override, show=True)
+    config = config.get_config(
+        args.config, overrides=args.override, show=False)
     trainer = Trainer(config, mode="train")
     trainer.train()
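Since all three entry points now defer config printing and logger setup to the Trainer, a run is unchanged apart from show=False. The Python equivalent of the train entry script after this commit, as a sketch (the config path is an example):

from ppcls.utils import config
from ppcls.engine.trainer import Trainer

# get_config no longer prints (show=False); Trainer.__init__ sets up
# logging and prints the config exactly once.
cfg = config.get_config("ppcls/configs/example.yaml", overrides=None, show=False)
trainer = Trainer(cfg, mode="train")
trainer.train()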