run_trainer.py 4.8 KB
Newer Older
W
webyfdt 已提交
1 2 3 4
# -*- coding: utf-8 -*
"""import"""
import os
import sys
K
Kennycao123 已提交
5
sys.path.append("../../../")
K
Kennycao123 已提交
6 7 8
from erniekit.common.register import RegisterSet
from erniekit.common import register
from erniekit.data.data_set import DataSet
W
webyfdt 已提交
9
import logging
K
Kennycao123 已提交
10 11 12
from erniekit.utils import args
from erniekit.utils import params
from erniekit.utils import log
W
webyfdt 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
import paddle

logging.getLogger().setLevel(logging.INFO)


def dataset_reader_from_params(params_dict):
    """Create a ``DataSet`` from its config section and build it.

    :param params_dict: the ``dataset_reader`` section of the job config,
        passed unchanged to the ``DataSet`` constructor.
    :return: the ``DataSet`` instance after ``build()`` has been called on it.
    """
    reader = DataSet(params_dict)
    reader.build()
    return reader


def model_from_params(params_dict, dataset_reader):
    """Instantiate the registered model described by ``params_dict``.

    If the ``optimization`` section contains a ``warmup_steps`` key, the
    training schedule is derived from the train reader's config. The result
    is merged into ``params_dict["optimization"]`` in place, adding the keys
    ``warmup_steps``, ``max_train_steps`` and ``num_train_examples``. This
    happens before the model class is constructed.

    :param params_dict: the ``model`` section of the job config. Its ``type``
        entry names a class registered in ``RegisterSet.models``.
    :param dataset_reader: a built ``DataSet`` exposing ``train_reader``.
    :return: tuple ``(model, num_train_examples)``.
    """
    opt_params = params_dict.get("optimization", None)
    train_dataset = dataset_reader.train_reader.dataset
    num_train_examples = train_dataset.get_num_examples()

    # Derive the warmup schedule only when the config explicitly asks for it.
    if opt_params and "warmup_steps" in opt_params:
        trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
        batch_size_train = train_dataset.config.batch_size
        epoch_train = train_dataset.config.epoch
        max_train_steps = epoch_train * num_train_examples // batch_size_train // trainers_num

        # Task distillation step 2 (TD2) must also count the steps of step 1 (TD1).
        task_distill_params = params_dict.get("task_distill_step2", None)
        if task_distill_params and "td1_epoch" in task_distill_params:
            # Number of TD1 epochs; must be set in the TD2 config file.
            td1_epoch = task_distill_params["td1_epoch"]
            # By default TD1 is assumed to use TD2's batch size and training set.
            td1_batch_size = task_distill_params.get("td1_batch_size", batch_size_train)
            max_train_steps += td1_epoch * num_train_examples // td1_batch_size // trainers_num

        warmup_steps = opt_params.get("warmup_steps", 0)
        # warmup_steps == 0 means "derive it from warmup_proportion of all steps".
        if warmup_steps == 0:
            warmup_proportion = opt_params.get("warmup_proportion", 0.1)
            warmup_steps = int(max_train_steps * warmup_proportion)

        logging.info("Device count: %d", trainers_num)
        logging.info("Num train examples: %d", num_train_examples)
        logging.info("Max train steps: %d", max_train_steps)
        logging.info("Num warmup steps: %d", warmup_steps)

        # opt_params is params_dict["optimization"] itself, so this merges the
        # computed schedule into the caller's config in place.
        opt_params.update({
            "warmup_steps": warmup_steps,
            "max_train_steps": max_train_steps,
            "num_train_examples": num_train_examples,
        })

    model_name = params_dict.get("type")
    model_class = RegisterSet.models[model_name]
    model = model_class(params_dict)
    return model, num_train_examples


def build_trainer(params_dict, dataset_reader, model, num_train_examples=0):
    """Look up and construct the trainer named in the trainer config.

    :param params_dict: the ``trainer`` section of the job config. Its
        ``type`` defaults to ``"CustomTrainer"``. It is mutated in place to
        carry ``num_train_examples``.
    :param dataset_reader: the built ``DataSet`` handed to the trainer.
    :param model: the model instance to train.
    :param num_train_examples: number of training examples (default 0).
    :return: the constructed trainer instance.
    """
    trainer_cls = RegisterSet.trainer.__getitem__(params_dict.get("type", "CustomTrainer"))
    # Stored only after a successful lookup, matching the original ordering.
    params_dict["num_train_examples"] = num_train_examples
    return trainer_cls(params=params_dict, data_set_reader=dataset_reader, model=model)


def run_trainer(param_dict):
    """Build the dataset reader, model and trainer from the config, then train.

    :param param_dict: full job config with ``dataset_reader``, ``model`` and
        ``trainer`` sections. The ``model`` section gains a
        ``num_train_examples`` entry as a side effect.
    :return: None
    """
    logging.info("run trainer.... pid = " + str(os.getpid()))

    reader = dataset_reader_from_params(param_dict.get("dataset_reader"))

    model_section = param_dict.get("model")
    model, example_count = model_from_params(model_section, reader)
    model_section["num_train_examples"] = example_count

    trainer = build_trainer(param_dict.get("trainer"), reader, model, example_count)
    trainer.do_train()

    logging.info("end of run train and eval .....")


if __name__ == "__main__":
    # Use a distinct name so the imported ``args`` module is not shadowed.
    cli_args = args.build_common_arguments()
    log.init_log("./log/test", level=logging.DEBUG)
    param_dict = params.from_file(cli_args.param_path)
    _params = params.replace_none(param_dict)

    # Import the modules providing the model/trainer/reader classes. These
    # presumably register themselves with RegisterSet, which run_trainer
    # queries by name.
    register.import_modules()
    register.import_new_module("model", "bow_matching_pairwise")
    register.import_new_module("model", "ernie_matching_fc_pointwise")
    register.import_new_module("model", "ernie_matching_siamese_pairwise")
    register.import_new_module("model", "ernie_matching_siamese_pointwise")
    register.import_new_module("trainer", "custom_trainer")
    register.import_new_module("trainer", "custom_dynamic_trainer")
    register.import_new_module("data_set_reader", "ernie_classification_dataset_reader")

    # Read the device from the same normalized config that run_trainer uses;
    # it defaults to CPU.
    trainer_params = _params.get("trainer")
    paddle.set_device(trainer_params.get("PADDLE_PLACE_TYPE", "cpu"))
    run_trainer(_params)

    # os._exit skips interpreter cleanup (atexit handlers, stdio flushing),
    # so flush log handlers and std streams explicitly first.
    logging.shutdown()
    sys.stdout.flush()
    sys.stderr.flush()
    os._exit(0)