types.py 3.2 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
C
chenxuyi 已提交
14 15
"""Basic types"""

C
chenxuyi 已提交
16 17 18 19 20 21 22 23 24 25
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import os
import json
from collections import namedtuple


class RunMode(object):
    """Modes in which ``model_fn`` may be called.

    The modes are plain integer constants, so they compare and serialize as
    ordinary ints.
    """
    TRAIN = 1
    PREDICT = 2
    EVAL = 3


class HParams(object):
    """Hyper-parameter container.

    Every hyper-parameter is stored as an instance attribute, so it can be
    read as an attribute (``hp.lr``) or with item access (``hp['lr']``).
    """

    def __init__(self, **kwargs):
        # Each keyword argument becomes an instance attribute.
        self.__dict__.update(kwargs)

    def __contains__(self, key):
        return key in self.__dict__

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        Raises:
            ValueError: if ``key`` is not present. This is kept as ValueError,
                not KeyError, for compatibility with existing callers.
        """
        if key not in self.__dict__:
            # Wrap ``key`` in a 1-tuple so tuple-valued keys format correctly
            # instead of raising TypeError from the ``%`` operator.
            raise ValueError('key(%s) not in HParams.' % (key, ))
        return self.__dict__[key]

    def __repr__(self):
        return repr(self.to_dict())

    def __setitem__(self, key, val):
        self.__dict__[key] = val

    @classmethod
    def from_json(cls, json_str):
        """Build an instance from a JSON string whose top level is an object.

        Raises:
            ValueError: if the decoded JSON value is not a dict.
        """
        d = json.loads(json_str)
        if not isinstance(d, dict):
            raise ValueError('json object must be dict.')
        # Use ``cls`` so that subclasses construct instances of themselves.
        return cls.from_dict(d)

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` if it is absent."""
        return self.__dict__.get(key, default)

    @classmethod
    def from_dict(cls, d):
        """Build an instance whose attributes are the items of ``d``.

        Raises:
            ValueError: if ``d`` is not a dict (or dict subclass).
        """
        if not isinstance(d, dict):
            raise ValueError('input must be dict.')
        return cls(**d)

    def to_json(self):
        """Serialize all hyper-parameters to a JSON string."""
        return json.dumps(self.__dict__)

    def to_dict(self):
        """Return a shallow copy of the hyper-parameters as a plain dict.

        A copy is returned so callers cannot accidentally mutate this
        instance through the returned mapping.
        """
        return dict(self.__dict__)

    def join(self, other):
        """Merge ``other``'s hyper-parameters into this instance, in place.

        Values from ``other`` override existing ones.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: if ``other`` is not an HParams instance.
        """
        if not isinstance(other, HParams):
            raise ValueError('input must be HParams instance. got %s' %
                             type(other))
        # Pass the mapping positionally: ``update(**mapping)`` would fail on
        # non-string keys, which ``__setitem__`` permits.
        self.__dict__.update(other.__dict__)
        return self


# Record pairing a ``scalar`` entry with a ``histogram`` entry.
# NOTE(review): the value types of both fields are not visible here; confirm
# against producers/consumers.
SummaryRecord = namedtuple('SummaryRecord', ['scalar', 'histogram'])

# Warm-start setting: a predicate function plus the directory to load from.
# NOTE(review): presumably ``predicate_fn`` selects what to restore; confirm.
WarmStartSetting = namedtuple('WarmStartSetting', ['predicate_fn', 'from_dir'])
# Warm-start setting that carries only a source directory.
TextoneWarmStartSetting = namedtuple('TextoneWarmStartSetting', ['from_dir'])

# Run configuration. Every field defaults to None, so callers pass only the
# settings they need.
# NOTE(review): the last field name ('shit') looks like a placeholder; its
# meaning is not visible here. It is kept unchanged for caller compatibility.
RunConfig = namedtuple('RunConfig', [
    'model_dir', 'run_steps', 'max_steps', 'save_steps', 'eval_steps',
    'eval_max_steps', 'skip_steps', 'log_steps', 'max_ckpt', 'shit'
])
RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)

# Pair of programs: a training program and its startup program.
ProgramPair = namedtuple('ProgramPair', ['train_program', 'startup_program'])

# Input and output specification used for inference.
InferenceSpec = namedtuple('InferenceSpec', ['inputs', 'outputs'])

# Bundle of model outputs (loss, predictions, metrics, hooks, ...).
# NOTE(review): presumably this is what ``model_fn`` returns for a given
# RunMode; confirm with callers. Every field defaults to None.
ModelSpec = namedtuple('ModelSpec', [
    'loss',
    'predictions',
    'metrics',
    'mode',
    'inference_spec',
    'train_hooks',
    'eval_hooks',
])
ModelSpec.__new__.__defaults__ = (None, ) * len(ModelSpec._fields)


class StopException(Exception):
    """Exception type adding no behaviour beyond ``Exception``.

    NOTE(review): the name suggests it is raised to stop a run loop early;
    confirm against the code that raises and catches it.
    """