util.py 3.6 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
C
chenxuyi 已提交
14
"""global utils"""
C
chenxuyi 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import os
import six
import re
import json
import argparse
import itertools
import logging
from functools import reduce

from propeller.types import RunConfig
from propeller.types import HParams

log = logging.getLogger(__name__)


def ArgumentParser(name):
    """Build the argparse parser shared by propeller entry points.

    Args:
        name: human-readable description of the program/model; shown as the
            parser description in ``--help`` output.

    Returns:
        argparse.ArgumentParser with two options:
          --run_config  config text or a path to a file holding it
                        (default: '').
          --hparam      may be repeated; each occurrence takes zero or more
                        strings and is appended as one list, so the parsed
                        value is a list of lists seeded with [''].
    """
    parser = argparse.ArgumentParser(prog='propeller model', description=name)
    parser.add_argument('--run_config', type=str, default='')
    # With action='append' every --hparam occurrence appends its value list
    # after this default, so the '' seed is always present; '' means "no
    # value" to the config loader and is skipped downstream.
    parser.add_argument(
        '--hparam', type=str, nargs='*', action='append', default=[['']])
    return parser


def _get_dict_from_environ_or_json_or_file(args, env_name):
    """Resolve a config value from an argument, a file, or an env variable.

    Args:
        args: '' means "not given" and yields None. None means "read the
            environment variable ``env_name``". Any other value is used as
            the config text itself, unless it names an existing path, in
            which case that file's contents are used instead.
        env_name: environment variable consulted when ``args`` is None.
            May be None, in which case no variable is looked up.

    Returns:
        None when no value is available; for string values, the object
        decoded from the text (tried as JSON first, then as a Python
        expression); any other value is returned unchanged.

    Raises:
        ValueError: the text is neither valid JSON nor an evaluable Python
            expression.
    """
    if args == '':
        return None
    if args is None:
        if env_name is None:
            # On Python 3 os.environ.get(None) raises TypeError; with no
            # variable name there is simply no value.
            return None
        s = os.environ.get(env_name)
        if s is None:
            return None
    else:
        s = args
        if os.path.exists(s):
            # Close the file deterministically instead of leaking the handle.
            with open(s) as f:
                s = f.read()
    if not isinstance(s, six.string_types):
        return s
    try:
        return json.loads(s)
    except ValueError:
        pass
    # Fall back to Python expression syntax (e.g. single-quoted dicts).
    # NOTE(security): eval() executes arbitrary code taken from the command
    # line, an environment variable or a file -- only use with trusted input.
    try:
        return eval(s)
    except Exception as e:  # eval may raise NameError, TypeError, ... too
        raise ValueError('json parse error: %s \n>Got json: %s' %
                         (repr(e), s))


def parse_file(filename):
    """Load a config object from a file path or from inline config text.

    Args:
        filename: path to a file containing JSON/Python-literal text, or the
            text itself (a string that is not an existing path is parsed as
            text).

    Returns:
        The decoded object.

    Raises:
        ValueError: ``filename`` is None or '', or the text decodes to None.
            Parse errors from the underlying loader propagate.
    """
    if filename is None or filename == '':
        # Neither carries a value; passing None on would make the loader look
        # up os.environ with a None key instead of failing cleanly.
        raise ValueError('file(%s) not found' % filename)
    d = _get_dict_from_environ_or_json_or_file(filename, None)
    if d is None:
        raise ValueError('file(%s) not found' % filename)
    return d


def parse_runconfig(args=None):
    """Build a RunConfig from the command line or the environment.

    Args:
        args: parsed namespace with a ``run_config`` attribute holding config
            text or a path to a file with it; or None (the default) to read
            the config from the ``PROPELLER_RUNCONFIG`` environment variable.

    Returns:
        ``RunConfig(**d)`` where ``d`` is the decoded config dict.

    Raises:
        ValueError: no run config was supplied.
    """
    # args=None used to crash with AttributeError; fall back to the env var
    # the same way parse_hparam does.
    run_config = args.run_config if args is not None else None
    d = _get_dict_from_environ_or_json_or_file(run_config,
                                               'PROPELLER_RUNCONFIG')
    if d is None:
        raise ValueError('run_config not found')
    return RunConfig(**d)


def parse_hparam(args=None):
    """Collect hyper-parameters from the command line or the environment.

    Args:
        args: parsed namespace whose ``hparam`` attribute is a list of lists
            of strings (the shape produced by ``ArgumentParser`` in this
            module); each string is config text or a path to a file with it.
            None (the default) reads ``PROPELLER_HPARAMS`` from the
            environment instead.

    Returns:
        An empty ``HParams()`` when nothing was supplied; otherwise every
        supplied ``HParams`` combined left to right via ``HParams.join``.
    """
    if args is not None:
        # chain() also copes with an empty outer list, where
        # reduce(list.__add__, ...) without an initial value raised TypeError.
        hparam_strs = list(itertools.chain.from_iterable(args.hparam))
    else:
        hparam_strs = [None]

    hparams = [
        _get_dict_from_environ_or_json_or_file(hp, 'PROPELLER_HPARAMS')
        for hp in hparam_strs
    ]
    # '' entries (the argparse seed) decode to None and are dropped here.
    hparams = [HParams(**h) for h in hparams if h is not None]
    if not hparams:
        return HParams()
    return reduce(lambda x, y: x.join(y), hparams)


def flatten(s):
    """Flatten a list/tuple of sized iterables into one list plus a schema.

    Args:
        s: list or tuple whose elements are sized iterables.

    Returns:
        (flat, schema): ``flat`` holds the items of every element of ``s``
        in order; ``schema[i]`` is ``len(s[i])``. ``unflatten(flat, schema)``
        cuts ``flat`` back into one piece per element.
    """
    # NOTE: this check is skipped entirely when Python runs with -O.
    assert is_struture(s)
    lengths = [len(part) for part in s]
    flat = [leaf for part in s for leaf in part]
    return flat, lengths


def unflatten(structure, schema):
    """Cut a flat sequence back into consecutive pieces (inverse of flatten).

    Args:
        structure: flat sliceable sequence, e.g. the list from ``flatten``.
        schema: iterable of piece lengths, in order.

    Returns:
        list whose i-th item is the slice of ``structure`` of length
        ``schema[i]`` (shorter if ``structure`` runs out); items past the
        sum of ``schema`` are dropped.
    """
    bounds = [0]
    for size in schema:
        bounds.append(bounds[-1] + size)
    return [structure[lo:hi] for lo, hi in zip(bounds, bounds[1:])]


def is_struture(s):
    """Return True when ``s`` is a list or a tuple, otherwise False.

    These are the only container types ``flatten`` accepts. The misspelled
    name is public API; keep it (add an alias if a correct spelling is wanted).
    """
    return isinstance(s, (list, tuple))


def map_structure(func, s):
    """Apply ``func`` to every leaf of a nested list/tuple/dict structure.

    Similar to ``tf.nest.map_structure`` for a single structure, except that
    both lists and tuples are always rebuilt as lists.

    Args:
        func: callable applied to each leaf.
        s: nested combination of lists, tuples and dicts; any other object
            (strings included) is a leaf.

    Returns:
        A structure of the same nesting with each leaf replaced by
        ``func(leaf)``; dicts keep their keys.
    """
    is_sequence = isinstance(s, (list, tuple))
    if is_sequence:
        return [map_structure(func, item) for item in s]
    if not isinstance(s, dict):
        return func(s)
    return dict((key, map_structure(func, val)) for key, val in s.items())