# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Trainer class.
"""

import json
import logging
import os
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from tqdm import tqdm

from plato.args import str2bool
from plato.data.data_loader import DataLoader
from plato.metrics.metrics_tracker import MetricsTracker
from plato.metrics.metrics import bleu
from plato.metrics.metrics import distinct
import plato.modules.parallel as parallel


def get_logger(log_path, name="default"):
    """
    Configure and return a logger that writes bare messages to stdout and to a file.

    Calling this again with the same ``name`` reconfigures the logger instead of
    stacking extra handlers on it.

    @param log_path : path of the log file; it is truncated (opened with mode "w")
    @type : str

    @param name : logger name passed to ``logging.getLogger``
    @type : str

    @return : the configured ``logging.Logger``
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(logging.DEBUG)

    # logging.getLogger caches loggers by name. Without this, a second call would
    # attach a second pair of handlers (every message printed twice) and leave the
    # previous log file open. Closing a StreamHandler does not close sys.stdout.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter("%(message)s")

    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    fh = logging.FileHandler(log_path, mode="w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    return logger


def evaluate_generation_result(results):
    """
    Compute corpus-level generation metrics over a list of inference results.

    Each result must provide a space-tokenized reference string under "tgt" and
    either a single prediction string under "preds", or a list of candidate
    predictions under "preds" together with "scores" (the candidate at
    ``argmax(scores)`` is evaluated).

    @param results : non-empty list of per-sample result dicts (an empty list
                     fails when averaging the prediction length)
    @type : list of dict

    @return : a MetricsTracker updated once (num_samples=1) with BLEU-1/2,
              intra/inter distinct-1/2 and average prediction length
    """
    tgt = [result["tgt"].split(" ") for result in results]
    pred = [result["preds"][np.argmax(result["scores"])]
            if isinstance(result["preds"], list)
            else result["preds"]
            for result in results]
    pred = [p.split(" ") for p in pred]
    metrics = {}
    metrics_tracker = MetricsTracker()

    bleu1, bleu2 = bleu(pred, tgt)
    metrics.update({"bleu_1": bleu1, "bleu_2": bleu2})

    intra_dist1, intra_dist2, inter_dist1, inter_dist2 = distinct(pred)
    metrics.update({"intra_dist_1": intra_dist1,
                    "intra_dist_2": intra_dist2,
                    "inter_dist_1": inter_dist1,
                    "inter_dist_2": inter_dist2})

    avg_len = sum(map(len, pred)) / len(pred)
    metrics.update({"len": avg_len})

    metrics_tracker.update(metrics, num_samples=1)
    return metrics_tracker


def save(model, model_path):
    """
    Save model parameters and optimizer state to ``model_path``.

    A ``parallel.DataParallel`` wrapper is unwrapped first so the underlying
    layer's state is saved. ``model.optimizer`` must exist.

    @param model : the (possibly DataParallel-wrapped) model
    @param model_path : path prefix for the saved files
    @type : str
    """
    if isinstance(model, parallel.DataParallel):
        model = model._layers
    if hasattr(fluid, "save_dygraph"):
        # >= 1.6.0 compatible.
        # NOTE(review): both calls share model_path; this relies on save_dygraph
        # choosing different file suffixes for parameter vs optimizer state dicts
        # (per the Paddle 1.6 docs) — confirm for the Paddle version in use.
        fluid.save_dygraph(model.state_dict(), model_path)
        fluid.save_dygraph(model.optimizer.state_dict(), model_path)
    else:
        dygraph.save_persistables(model.state_dict(), model_path,
                                  optimizers=model.optimizer)
    return


class Trainer(object):
    """
    Drives training, periodic validation/checkpointing and inference for a model.
    """

    @classmethod
    def add_cmdline_argument(cls, parser):
        """ Add the cmdline arguments of trainer. """
        group = parser.add_argument_group("Trainer")
        group.add_argument("--use_data_distributed", type=str2bool, default=False,
                           help="Whether to use data distributed for parallel training.")
        # A leading "-" means lower is better; a leading "+" (or no sign) means
        # higher is better. See Trainer.__init__.
        group.add_argument("--valid_metric_name", type=str, default="-loss",
                           help="The validation metric determining which checkpoint is the best.")
        group.add_argument("--num_epochs", type=int, default=10,
                           help="Total number of training epochs to perform.")
        group.add_argument("--save_dir", type=str, required=True,
                           help="The output directory where the model will be saved.")
        group.add_argument("--batch_size", type=int, default=8,
                           help="Total batch size for training/evaluation/inference.")
        group.add_argument("--log_steps", type=int, default=100,
                           help="The number of training steps to output current metrics "
                           "on past training dataset.")
        group.add_argument("--valid_steps", type=int, default=2000,
                           help="The number of training steps to perform a evaluation "
                           "on validation datasets.")
        group.add_argument("--save_checkpoint", type=str2bool, default=True,
                           help="Whether to save one checkpoints for each training epoch.")
        group.add_argument("--save_summary", type=str2bool, default=False,
                           help="Whether to save metrics summary for visualDL module.")
        DataLoader.add_cmdline_argument(group)
        return group

    def __init__(self, model, to_tensor, hparams, logger=None):
        """
        @param model : the model; called as model(batch, is_training=...) and
                       model.infer(inputs=batch)
        @param to_tensor : function converting one batch field to a tensor
        @param hparams : parsed arguments (see add_cmdline_argument)
        @param logger : optional logger; defaults to <save_dir>/trainer.log + stdout
        """
        # Use data distributed
        if hparams.use_data_distributed:
            strategy = parallel.prepare_context()
            if strategy is not None:
                parallel_model = parallel.DataParallel(model, strategy)
                model.before_backward_fn = parallel_model.scale_loss
                model.after_backward_fn = parallel_model.apply_collective_grads
                model = parallel_model

        self.model = model
        self.to_tensor = to_tensor

        # "-name": lower is better; "+name": higher is better. Only a sign
        # character is stripped, so an unsigned name keeps all of its letters
        # (and, as with "+", is treated as higher-is-better).
        valid_metric_name = hparams.valid_metric_name
        self.is_decreased_valid_metric = valid_metric_name.startswith("-")
        if valid_metric_name[:1] in ("-", "+"):
            valid_metric_name = valid_metric_name[1:]
        self.valid_metric_name = valid_metric_name

        self.num_epochs = hparams.num_epochs
        self.save_dir = hparams.save_dir
        self.log_steps = hparams.log_steps
        self.valid_steps = hparams.valid_steps
        self.save_checkpoint = hparams.save_checkpoint
        self.save_summary = hparams.save_summary

        # exist_ok avoids a check-then-create race when several distributed
        # processes share the same save_dir.
        os.makedirs(self.save_dir, exist_ok=True)

        self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")

        if self.save_summary:
            from visualdl import LogWriter
            self.summary_logger = LogWriter(os.path.join(self.save_dir, "summary"), sync_cycle=10000)
            # metric name -> VisualDL scalar handle, created lazily per mode
            self.train_summary = {}
            self.valid_summary = {}

        # Running training metrics, cleared at the start of every epoch.
        self.batch_metrics_tracker = MetricsTracker()
        self.token_metrics_tracker = MetricsTracker()

        self.best_valid_metric = float("inf" if self.is_decreased_valid_metric else "-inf")
        self.epoch = 0
        self.batch_num = 0

    def train_epoch(self, train_iter, valid_iter, infer_iter=None, infer_parse_dict=None):
        """
        Train an epoch.

        Logs running metrics every ``log_steps`` batches, evaluates every
        ``valid_steps`` batches and once more at the end of the epoch, then runs
        inference if both ``infer_iter`` and ``infer_parse_dict`` are given.

        @param train_iter
        @type : DataLoader

        @param valid_iter
        @type : DataLoader

        @param infer_iter
        @type : DataLoader

        @param infer_parse_dict
        @type : dict of function
        """
        self.epoch += 1
        num_batches = len(train_iter)
        self.batch_metrics_tracker.clear()
        self.token_metrics_tracker.clear()
        times = []
        for batch_id, (batch, batch_size) in enumerate(train_iter, 1):
            batch = self._batch_to_tensor(batch)
            batch["epoch"] = self.epoch
            batch["num_steps"] = self.batch_num

            # Do a training iteration
            start_time = time.time()
            metrics = self.model(batch, is_training=True)
            token_num = metrics.pop("token_num", None)
            elapsed = time.time() - start_time
            times.append(elapsed)

            # Metrics with "token" in their name are averaged per token, the rest per sample.
            batch_metrics = {k: v for k, v in metrics.items() if "token" not in k}
            token_metrics = {k: v for k, v in metrics.items() if "token" in k}
            self.batch_metrics_tracker.update(batch_metrics, batch_size)
            self.token_metrics_tracker.update(token_metrics, token_num)
            self.batch_num += 1

            if self.log_steps and batch_id % self.log_steps == 0:
                batch_metrics_message = self.batch_metrics_tracker.value()
                token_metrics_message = self.token_metrics_tracker.value()
                message_prefix = f"[Train][{self.epoch}][{batch_id}/{num_batches}]"
                # batch_id is a multiple of log_steps here, so at least log_steps timings exist.
                avg_time = f"AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}"
                message = " ".join([message_prefix, batch_metrics_message, token_metrics_message, avg_time])
                self.logger.info(message)

            if self.save_summary:
                self._add_summary_records("train", self.train_summary,
                                          (self.batch_metrics_tracker, self.token_metrics_tracker))

            if self.valid_steps and valid_iter is not None and \
                    batch_id % self.valid_steps == 0:
                self.evaluate(valid_iter)

        if valid_iter is not None:
            self.evaluate(valid_iter)

        if infer_iter is not None and infer_parse_dict is not None:
            self.infer(infer_iter, infer_parse_dict)

        return

    def infer(self, data_iter, parse_dict, num_batches=None):
        """
        Inference interface.

        Writes one JSON record per sample to <save_dir>/infer_<epoch>.result.json
        and logs generation metrics computed by evaluate_generation_result.

        @param : data_iter
        @type : DataLoader

        @param : parse_dict : per-output-key parse functions; keys not present
                              fall back to calling ``.tolist()`` on the value
        @type : dict of function

        @param : num_batches : the number of batch to infer (None = all)
        @type : int/None
        """
        self.logger.info("Generation starts ...")
        infer_save_file = os.path.join(self.save_dir, f"infer_{self.epoch}.result.json")

        def to_list(value):
            """ Parse list. """
            return value.tolist()

        # Inference
        infer_results = []
        batch_cnt = 0
        begin_time = time.time()
        for batch, _ in tqdm(data_iter, total=num_batches):
            batch = self._batch_to_tensor(batch)

            result = self.model.infer(inputs=batch)

            # Parse each non-None output field into a per-sample sequence.
            batch_result = {}
            for k in result:
                parse_fn = parse_dict[k] if k in parse_dict else to_list
                if result[k] is not None:
                    batch_result[k] = parse_fn(result[k])

            # Transpose {field: [per-sample values]} into one dict per sample.
            for vs in zip(*batch_result.values()):
                infer_results.append(dict(zip(batch_result.keys(), vs)))

            batch_cnt += 1
            if batch_cnt == num_batches:
                break

        with open(infer_save_file, "w") as fp:
            json.dump(infer_results, fp, indent=2)
        self.logger.info(f"Saved inference results to {infer_save_file}")

        infer_metrics_tracker = evaluate_generation_result(infer_results)
        metrics_message = infer_metrics_tracker.summary()
        message_prefix = f"[Infer][{self.epoch}]"
        time_cost = f"TIME-{time.time() - begin_time:.3f}"
        message = " ".join([message_prefix, metrics_message, time_cost])
        self.logger.info(message)
        return

    def evaluate(self, data_iter, need_save=True):
        """
        Evaluation interface

        Runs the model with is_training=False over ``data_iter`` and logs the
        averaged metrics. When ``need_save`` holds, it also updates the best
        model, the per-epoch checkpoint and (if enabled) the VisualDL "valid"
        summaries.

        @param : data_iter
        @type : DataLoader

        @param : need_save
        @type : bool
        """
        if isinstance(self.model, parallel.DataParallel):
            # Under data-parallel training only local rank 0 saves.
            need_save = need_save and parallel.Env().local_rank == 0

        # Evaluation
        begin_time = time.time()
        batch_metrics_tracker = MetricsTracker()
        token_metrics_tracker = MetricsTracker()
        for batch, batch_size in data_iter:
            batch = self._batch_to_tensor(batch)
            metrics = self.model(batch, is_training=False)
            token_num = int(metrics.pop("token_num"))
            batch_metrics = {k: v for k, v in metrics.items() if "token" not in k}
            token_metrics = {k: v for k, v in metrics.items() if "token" in k}
            batch_metrics_tracker.update(batch_metrics, batch_size)
            token_metrics_tracker.update(token_metrics, token_num)
        batch_metrics_message = batch_metrics_tracker.summary()
        token_metrics_message = token_metrics_tracker.summary()
        message_prefix = f"[Valid][{self.epoch}]"
        time_cost = f"TIME-{time.time() - begin_time:.3f}"
        message = " ".join([message_prefix, batch_metrics_message, token_metrics_message, time_cost])
        self.logger.info(message)

        if need_save:
            # Check valid metric
            cur_valid_metric = batch_metrics_tracker.get(self.valid_metric_name)
            if self.is_decreased_valid_metric:
                is_best = cur_valid_metric < self.best_valid_metric
            else:
                is_best = cur_valid_metric > self.best_valid_metric
            if is_best:
                # Save current best model
                self.best_valid_metric = cur_valid_metric
                best_model_path = os.path.join(self.save_dir, "best.model")
                save(self.model, best_model_path)
                self.logger.info(
                    f"Saved best model to '{best_model_path}' with new best valid metric "
                    f"{self.valid_metric_name.upper()}-{self.best_valid_metric:.3f}")

            # Save checkpoint
            if self.save_checkpoint:
                model_file = os.path.join(self.save_dir, f"epoch_{self.epoch}.model")
                save(self.model, model_file)

            if self.save_summary:
                # Record the validation trackers built above; self.batch_metrics_tracker /
                # self.token_metrics_tracker hold the running *training* metrics.
                self._add_summary_records("valid", self.valid_summary,
                                          (batch_metrics_tracker, token_metrics_tracker))
        return

    def _batch_to_tensor(self, batch):
        """ Return a batch of the same mapping type with every value passed through self.to_tensor. """
        return type(batch)((k, self.to_tensor(v)) for k, v in batch.items())

    def _add_summary_records(self, mode, summaries, trackers):
        """
        Record the current value of every metric in ``trackers`` to VisualDL.

        @param mode : VisualDL mode tag, e.g. "train" or "valid"
        @type : str

        @param summaries : cache of metric name -> VisualDL scalar; missing
                           scalars are created on first use
        @type : dict

        @param trackers : MetricsTracker objects whose items() are recorded at
                          step self.batch_num
        @type : iterable
        """
        with self.summary_logger.mode(mode):
            for tracker in trackers:
                for k, v in tracker.items():
                    if k not in summaries:
                        summaries[k] = self.summary_logger.scalar(k)
                    summaries[k].add_record(self.batch_num, v)