util.py 3.3 KB
Newer Older
C
chenxuyi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import os
import six
import re
import json
import argparse
import itertools
import logging
from functools import reduce

from propeller.types import RunConfig
from propeller.types import HParams

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


def ArgumentParser(name):
    """Build the standard propeller command-line parser.

    Args:
        name: program name shown in usage/help output. When empty or None,
            falls back to ``'propeller model'``.

    Returns:
        An ``argparse.ArgumentParser`` with two options:
          ``--run_config``  a single string, ``''`` when not given.
          ``--hparam``      zero or more strings per occurrence; every
                            occurrence is appended, so the parsed value is
                            always a list of lists seeded with ``[['']]``.
    """
    # `name` used to be ignored (prog was hard-coded); honor it now while
    # keeping the old program name as the default.
    prog = name if name else 'propeller model'
    parser = argparse.ArgumentParser(prog=prog)
    parser.add_argument('--run_config', type=str, default='')
    parser.add_argument(
        '--hparam', type=str, nargs='*', action='append', default=[['']])
    return parser


def _get_dict_from_environ_or_json_or_file(args, env_name):
    """Resolve a config value given inline, as a file path, or via env var.

    Args:
        args: ``''`` means "not provided" and returns None immediately.
            ``None`` means the text is read from the environment variable
            ``env_name`` (None if it is unset). Any other string is used as
            a file path if such a path exists, otherwise as the text itself.
        env_name: environment variable consulted only when ``args`` is None.

    Returns:
        The parsed object, or None when nothing was provided. Text is parsed
        as JSON first, then as a Python expression.

    Raises:
        ValueError: if the text is neither valid JSON nor a valid Python
            expression.
    """
    if args == '':
        return None
    if args is None:
        text = os.environ.get(env_name)
    else:
        text = args
        if os.path.exists(text):
            # Close the file deterministically instead of leaking the handle.
            with open(text) as f:
                text = f.read()
    if not isinstance(text, six.string_types):
        return text  # None: environment variable unset

    try:
        # Genuine JSON first, so `true` / `false` / `null` parse correctly.
        return json.loads(text)
    except ValueError:
        pass

    try:
        # SECURITY(review): eval executes arbitrary code taken from CLI
        # arguments, environment variables and files. Acceptable only for
        # trusted local configs; consider ast.literal_eval if this input can
        # ever come from an untrusted source.
        return eval(text)
    except (SyntaxError, NameError) as e:
        # NameError covers bare identifiers, e.g. a non-existent path such
        # as `conf.json` that ended up being evaluated as an expression.
        raise ValueError('json parse error: %s \n>Got json: %s' %
                         (repr(e), text))


def parse_file(filename):
    """Parse the config at ``filename`` and return the resulting object.

    ``filename`` goes through ``_get_dict_from_environ_or_json_or_file``
    without an environment-variable fallback. A path that does not exist is
    treated as inline literal text.

    Raises:
        ValueError: if the helper yields None (e.g. for an empty filename).
    """
    parsed = _get_dict_from_environ_or_json_or_file(filename, None)
    if parsed is not None:
        return parsed
    raise ValueError('file(%s) not found' % filename)


def parse_runconfig(args=None):
    """Build a ``RunConfig`` from ``args.run_config`` or $PROPELLER_RUNCONFIG.

    Args:
        args: parsed namespace exposing a ``run_config`` attribute. When
            None, the ``PROPELLER_RUNCONFIG`` environment variable is read
            instead. An empty ``run_config`` (``''``) counts as "not given";
            the environment is not consulted in that case.

    Returns:
        ``RunConfig(**d)`` where ``d`` is the parsed config mapping.

    Raises:
        ValueError: if no run config could be found.
    """
    # The default args=None used to crash on `args.run_config`; passing None
    # to the helper makes it fall back to the environment variable, matching
    # how parse_hparam treats a missing namespace.
    run_config = args.run_config if args is not None else None
    d = _get_dict_from_environ_or_json_or_file(run_config,
                                               'PROPELLER_RUNCONFIG')
    if d is None:
        raise ValueError('run_config not found')
    return RunConfig(**d)


def parse_hparam(args=None):
    """Collect hyper-parameters from ``args.hparam`` or $PROPELLER_HPARAMS.

    Args:
        args: parsed namespace whose ``hparam`` attribute is a list of lists
            of strings (as produced by ``nargs='*'`` + ``action='append'``).
            Each string is resolved as inline text or a file path. When
            ``args`` is None, the ``PROPELLER_HPARAMS`` environment variable
            is the single source.

    Returns:
        One ``HParams`` obtained by combining every non-empty source left to
        right with ``HParams.join``.

    Raises:
        ValueError: if no source produced any hyper-parameters.
    """
    if args is not None:
        # Flatten the per-occurrence lists into one flat list of strings.
        hparam_strs = list(itertools.chain.from_iterable(args.hparam))
    else:
        # A lone None makes the helper read the environment variable.
        hparam_strs = [None]

    parsed = [
        _get_dict_from_environ_or_json_or_file(hp, 'PROPELLER_HPARAMS')
        for hp in hparam_strs
    ]
    hparams = [HParams(**h) for h in parsed if h is not None]
    # Guard before reduce(): reducing an empty list would raise TypeError
    # instead of this descriptive error.
    if not hparams:
        raise ValueError('hparam not found')
    return reduce(lambda x, y: x.join(y), hparams)


def flatten(s):
    """Flatten one level of a list/tuple of sized iterables.

    Returns:
        ``(flat, schema)``: ``flat`` is a single list holding every inner
        element in order. ``schema`` holds ``len()`` of each inner item,
        which ``unflatten`` can use to rebuild the original grouping.

    Raises:
        AssertionError: if ``s`` is not a list or tuple (skipped under -O).
    """
    assert isinstance(s, (list, tuple))
    lengths = list(map(len, s))
    flat = list(itertools.chain.from_iterable(s))
    return flat, lengths


def unflatten(structure, schema):
    """Split ``structure`` into consecutive slices sized by ``schema``.

    Each piece is ``structure[a:b]``, so it has the slice type of
    ``structure`` (a list yields lists, a tuple yields tuples). Elements past
    ``sum(schema)`` are dropped.
    """
    pieces = []
    offset = 0
    for length in schema:
        end = offset + length
        pieces.append(structure[offset:end])
        offset = end
    return pieces


def is_struture(s):
    """Return True if ``s`` is a list or a tuple (subclasses included)."""
    return isinstance(s, (list, tuple))


def map_structure(func, s):
    """Apply ``func`` to every leaf of a nested list/tuple/dict structure.

    Lists and tuples both come back as lists; dicts come back as dicts with
    the same keys. Any other value is a leaf and is passed to ``func``.
    """
    if isinstance(s, dict):
        return {key: map_structure(func, value)
                for key, value in six.iteritems(s)}
    if isinstance(s, (list, tuple)):
        return [map_structure(func, item) for item in s]
    return func(s)