# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tqdm
import json
import numpy as np
import os
from datetime import datetime
import logging
from collections import defaultdict

import paddle.fluid as F
from pgl.utils.logger import log
from pgl.utils.log_writer import LogWriter


def multi_device(reader, dev_count):
    """Yield single batches when one device is used; otherwise yield lists
    of `dev_count` batches (a trailing incomplete group is dropped)."""
    if dev_count == 1:
        for batch in reader:
            yield batch
    else:
        batches = []
        for batch in reader:
            batches.append(batch)
            if len(batches) == dev_count:
                yield batches
                batches = []


def evaluate(exe, loader, prog, model, evaluator):
    """Run inference over `loader` and score the stacked predictions against
    the ground-truth labels stored in `loader.dataset` using `evaluator`."""
    total_labels = []
    for i in range(len(loader.dataset)):
        g, l = loader.dataset[i]
        total_labels.append(l)
    total_labels = np.vstack(total_labels)

    pred_output = []
    for feed_dict in loader:
        ret = exe.run(prog, feed=feed_dict, fetch_list=model.pred)
        pred_output.append(ret[0])
    pred_output = np.vstack(pred_output)

    result = evaluator.eval({"y_true": total_labels, "y_pred": pred_output})
    return result


def _create_if_not_exist(path):
    """Create the parent directory of `path` if it does not exist yet."""
    basedir = os.path.dirname(path)
    if not os.path.exists(basedir):
        os.makedirs(basedir)


def train_and_evaluate(exe,
                       train_exe,
                       valid_exe,
                       train_ds,
                       valid_ds,
                       test_ds,
                       train_prog,
                       valid_prog,
                       args,
                       model,
                       evaluator,
                       dev_count=1):
    """Train `model` for `args.epoch` epochs, periodically evaluating on the
    validation and test sets and logging scalars with `LogWriter`."""
    global_step = 0
    timestamp = datetime.now().strftime("%Hh%Mm%Ss")
    log_path = os.path.join(args.log_dir, "log_%s" % timestamp)
    _create_if_not_exist(log_path)

    writer = LogWriter(log_path)

    best_valid_score = 0.0
    # Initialize so the summary message below is defined even if the first
    # epoch does not improve on the initial score.
    best_test_score = 0.0
    for e in range(args.epoch):
        for feed_dict in multi_device(train_ds, dev_count):
            if dev_count > 1:
                ret = train_exe.run(feed=feed_dict,
                                    fetch_list=model.metrics.vars)
                ret = [[np.mean(v)] for v in ret]
            else:
                ret = train_exe.run(train_prog,
                                    feed=feed_dict,
                                    fetch_list=model.metrics.vars)

            ret = model.metrics.parse(ret)
            if global_step % args.train_log_step == 0:
                writer.add_scalar("batch_loss", ret['loss'], global_step)
                log.info("epoch: %d | step: %d | loss: %.4f" %
                         (e, global_step, ret['loss']))

            global_step += 1
            if global_step % args.eval_step == 0:
                valid_ret = evaluate(exe, valid_ds, valid_prog, model,
                                     evaluator)
                message = "valid: "
                for key, value in valid_ret.items():
                    message += "%s %.4f | " % (key, value)
                    writer.add_scalar("eval_%s" % key, value, global_step)
                log.info(message)

                # testing
                test_ret = evaluate(exe, test_ds, valid_prog, model,
                                    evaluator)
                message = "test: "
                for key, value in test_ret.items():
                    message += "%s %.4f | " % (key, value)
                    writer.add_scalar("test_%s" % key, value, global_step)
                log.info(message)

        # evaluate after one epoch
        valid_ret = evaluate(exe, valid_ds, valid_prog, model, evaluator)
        message = "epoch %s valid: " % e
        for key, value in valid_ret.items():
            message += "%s %.4f | " % (key, value)
            writer.add_scalar("eval_%s" % key, value, global_step)
        log.info(message)

        # testing
        test_ret = evaluate(exe, test_ds, valid_prog, model, evaluator)
        message = "epoch %s test: " % e
        for key, value in test_ret.items():
            message += "%s %.4f | " % (key, value)
            writer.add_scalar("test_%s" % key, value, global_step)
        log.info(message)

        # track the best validation score and its corresponding test score
        message = "epoch %s best %s result | " % (e, args.eval_metrics)
        if valid_ret[args.eval_metrics] > best_valid_score:
            best_valid_score = valid_ret[args.eval_metrics]
            best_test_score = test_ret[args.eval_metrics]
        message += "valid %.4f | test %.4f" % (best_valid_score,
                                               best_test_score)
        log.info(message)

        # if global_step % args.save_step == 0:
        #     F.io.save_persistables(
        #         exe, os.path.join(args.save_dir, "%s" % global_step),
        #         train_prog)

    writer.close()