# utils.py
# (Recovered from a repository blame view; last committed by Superjom.)
import logging

# Give the root logger a default handler/format. This is a no-op if the root
# logger already has handlers configured.
# NOTE(review): calling basicConfig() at import time from a utility module
# configures logging for the whole process; consider moving it to the
# application entry point.
logging.basicConfig()

# Shared module logger, registered under the name "paddle", emitting INFO and
# above.
logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


class TaskMode:
    """Tag recording which phase (train, test or infer) a task runs in."""

    # Integer codes for each supported phase.
    TRAIN_MODE = 0
    TEST_MODE = 1
    INFER_MODE = 2

    def __init__(self, mode):
        # The code is stored as given; it is not validated against the
        # constants above.
        self.mode = mode

    def _matches(self, code):
        """Return True when this instance's mode equals *code*."""
        return self.mode == code

    def is_train(self):
        """Return True for the training phase."""
        return self._matches(self.TRAIN_MODE)

    def is_test(self):
        """Return True for the test phase."""
        return self._matches(self.TEST_MODE)

    def is_infer(self):
        """Return True for the inference phase."""
        return self._matches(self.INFER_MODE)

    @staticmethod
    def create_train():
        """Build a TaskMode set to TRAIN_MODE."""
        return TaskMode(mode=TaskMode.TRAIN_MODE)

    @staticmethod
    def create_test():
        """Build a TaskMode set to TEST_MODE."""
        return TaskMode(mode=TaskMode.TEST_MODE)

    @staticmethod
    def create_infer():
        """Build a TaskMode set to INFER_MODE."""
        return TaskMode(mode=TaskMode.INFER_MODE)


class ModelType:
    """Tag recording whether a model does classification or regression."""

    # Integer codes for each supported model kind.
    CLASSIFICATION = 0
    REGRESSION = 1

    def __init__(self, mode):
        # Stored verbatim; no validation is performed.
        self.mode = mode

    def is_classification(self):
        """Return True if this is a classification model."""
        expected = self.CLASSIFICATION
        return self.mode == expected

    def is_regression(self):
        """Return True if this is a regression model."""
        expected = self.REGRESSION
        return self.mode == expected

    @staticmethod
    def create_classification():
        """Build a ModelType set to CLASSIFICATION."""
        code = ModelType.CLASSIFICATION
        return ModelType(mode=code)

    @staticmethod
    def create_regression():
        """Build a ModelType set to REGRESSION."""
        code = ModelType.REGRESSION
        return ModelType(mode=code)


def load_dnn_input_record(sent):
    """Parse a whitespace-separated line of integers.

    Args:
        sent: A string such as ``"3 14 15"``.

    Returns:
        A list of ints, one per whitespace-separated token. An empty or
        all-whitespace string yields ``[]``.

    Raises:
        ValueError: If any token is not a valid integer literal.
    """
    # Materialize a list: on Python 3, ``map`` returns a one-shot lazy iterator
    # that cannot be indexed, measured with len(), or re-iterated, and it would
    # defer parse errors to whoever consumes it.
    return [int(token) for token in sent.split()]


def load_lr_input_record(sent):
    """Parse a sparse ``index:value`` line into (int, float) pairs.

    Args:
        sent: A string of whitespace-separated tokens such as
            ``"3:0.5 10:1"``.

    Returns:
        A list of ``(int, float)`` tuples in input order. An empty or
        all-whitespace string yields ``[]``.

    Raises:
        IndexError: If a token contains no ``':'``.
        ValueError: If the index is not an int or the value is not a float.
    """
    pairs = []
    for token in sent.split():
        fields = token.split(':')
        # Only the first two colon-separated fields are used; any further
        # fields in a token are ignored.
        pairs.append((int(fields[0]), float(fields[1])))
    return pairs