base_cls.py 3.5 KB
Newer Older
W
webyfdt 已提交
1 2 3 4 5 6 7 8 9
# -*- coding: utf-8 -*-
"""
分类任务的网络基类
"""
import collections
import logging
import numpy as np
import paddle

K
Kennycao123 已提交
10 11 12 13
from erniekit.common.register import RegisterSet
from erniekit.common.rule import InstanceName
from erniekit.metrics import metrics
from erniekit.model.model import BaseModel
W
webyfdt 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105


@RegisterSet.models.register
class BaseClassification(BaseModel):
    """Base network class for classification tasks.

    Supplies a default Adam optimizer setup and accuracy/precision
    reporting. ``structure`` and ``forward`` raise ``NotImplementedError``
    here and must be provided by subclasses.
    """

    def __init__(self, model_params):
        """Initialise the model by delegating to ``BaseModel.__init__``.

        :param model_params: model configuration, passed through unchanged
        """
        BaseModel.__init__(self, model_params)

    def structure(self):
        """Organise the network structure.

        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def forward(self, fields_dict, phase):
        """Run the forward computation.

        :param fields_dict: input fields for the forward pass
        :param phase: the phase the model is being called in
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def set_optimizer(self):
        """Create the optimizer and store it on the instance.

        The learning rate is read from the ``learning_rate`` entry of the
        ``optimization`` section of ``self.model_params``; ``2e-5`` is used
        when the section is missing/empty or has no such entry.

        :return: the ``paddle.optimizer.Adam`` instance, also kept as
            ``self.optimizer``
        """
        optimization_cfg = self.model_params.get('optimization', None)
        self.lr = optimization_cfg.get('learning_rate', 2e-5) if optimization_cfg else 2e-5
        self.optimizer = paddle.optimizer.Adam(
            learning_rate=self.lr,
            parameters=self.parameters(),
        )
        return self.optimizer

    def get_metrics(self, forward_return_dict, meta_info, phase):
        """Compute accuracy and precision, and log them for the given phase.

        :param forward_return_dict: results produced by the forward pass; the
            prediction and label entries are always read, the loss entry only
            in the training phase
        :param meta_info: common meta information (e.g. step, used time,
            gpu id); only the step and time-cost entries are read here
        :param phase: the calling phase; logging happens for the training,
            evaluate and test phases
        :return: ``OrderedDict`` with the keys ``"acc"`` and ``"precision"``
        """

        def as_numpy(value):
            # A list is converted element by element, anything else directly.
            if isinstance(value, list):
                return [element.numpy() for element in value]
            return value.numpy()

        predictions = forward_return_dict[InstanceName.PREDICT_RESULT]
        label = forward_return_dict[InstanceName.LABEL]
        # paddle_acc = forward_return_dict["acc"]
        if self.is_dygraph:
            # In dygraph mode the values expose ``.numpy()``; convert them
            # before handing them to the metric objects.
            predictions = as_numpy(predictions)
            label = as_numpy(label)

        acc = metrics.Acc().eval([predictions, label])
        precision = metrics.Precision().eval([predictions, label])

        if phase == InstanceName.TRAINING:
            step = meta_info[InstanceName.STEP]
            time_cost = meta_info[InstanceName.TIME_COST]
            loss = forward_return_dict[InstanceName.LOSS]
            # A paddle tensor is turned into numpy before averaging.
            loss_values = loss.numpy() if isinstance(loss, paddle.Tensor) else loss
            mean_loss = np.mean(loss_values)

            train_message = (
                "phase = {0} loss = {1} acc = {2} precision = {3} step = {4} time_cost = {5}"
            ).format(phase, mean_loss, acc, precision, step, round(time_cost, 4))
            logging.info(train_message)
        if phase == InstanceName.EVALUATE or phase == InstanceName.TEST:
            time_cost = meta_info[InstanceName.TIME_COST]
            step = meta_info[InstanceName.STEP]
            eval_message = (
                "phase = {0} acc = {1} precision = {2} time_cost = {3} step = {4}"
            ).format(phase, acc, precision, round(time_cost, 4), step)
            logging.info(eval_message)

        return collections.OrderedDict([("acc", acc), ("precision", precision)])

    def fields_process(self, fields_dict, phase):
        """Hook for optional post-processing of the serialized ids in ``fields_dict``.

        This base implementation does nothing and returns ``None``.
        NOTE(review): the hook is documented as returning the processed
        fields, so subclasses are presumably expected to override it --
        confirm against callers.

        :param fields_dict: the serialized fields
        :param phase: the calling phase
        :return: ``None``
        """
        return None