basic_helper.py 3.5 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5
# coding=utf-8
import os
import json
import yaml
from config_helper import PDConfig
X
xixiaoyao 已提交
6
from paddle import fluid
X
xixiaoyao 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117

def get_basename(f):
    """Return path *f* with its final extension removed (directory part is kept)."""
    root, _ext = os.path.splitext(f)
    return root


def get_suffix(f):
    """Return the final extension of path *f* including its leading dot, or '' if none."""
    _root, ext = os.path.splitext(f)
    return ext


def parse_yaml(f, asdict=True, support_cmd_line=False):
    """Load a YAML config file.

    Args:
        f: path to the YAML file; it must exist.
        asdict: if True, return the plain Python object produced by the YAML
            loader (or ``PDConfig.asdict()`` in command-line mode); if False,
            return the built ``PDConfig`` object. ``asdict=False`` is only
            supported together with ``support_cmd_line=True``.
        support_cmd_line: if True, build the config through ``PDConfig`` with
            ``fuse_args=True`` (presumably merging command-line arguments into
            the file's values -- see ``PDConfig``).

    Returns:
        The parsed config, or the ``PDConfig`` instance (see ``asdict``).

    Raises:
        AssertionError: if ``f`` does not exist.
        NotImplementedError: if both ``asdict`` and ``support_cmd_line`` are False.
    """
    import io  # io.open accepts ``encoding`` on both Python 2 and Python 3

    assert os.path.exists(f), "file {} not found.".format(f)
    if support_cmd_line:
        args = PDConfig(yaml_file=f, fuse_args=True)
        args.build()
        return args.asdict() if asdict else args
    if not asdict:
        raise NotImplementedError()
    # Decode explicitly as UTF-8 rather than with the platform's locale
    # encoding (e.g. GBK on Chinese Windows), which fails on non-ASCII configs.
    with io.open(f, "r", encoding="utf-8") as fin:
        # SafeLoader only builds standard YAML types, never arbitrary objects.
        return yaml.load(fin, Loader=yaml.SafeLoader)


def parse_json(f, asdict=True, support_cmd_line=False):
    """Load a JSON config file.

    Args:
        f: path to the JSON file; it must exist.
        asdict: if True, return the object produced by ``json.load`` (or
            ``PDConfig.asdict()`` in command-line mode); if False, return the
            built ``PDConfig`` object. ``asdict=False`` is only supported
            together with ``support_cmd_line=True``.
        support_cmd_line: if True, build the config through ``PDConfig`` with
            ``fuse_args=True`` (presumably merging command-line arguments into
            the file's values -- see ``PDConfig``).

    Returns:
        The parsed config, or the ``PDConfig`` instance (see ``asdict``).

    Raises:
        AssertionError: if ``f`` does not exist.
        NotImplementedError: if both ``asdict`` and ``support_cmd_line`` are False.
    """
    import io  # io.open accepts ``encoding`` on both Python 2 and Python 3

    assert os.path.exists(f), "file {} not found.".format(f)
    if support_cmd_line:
        # support_cmd_line is True here, so fuse_args=True is the same value.
        args = PDConfig(json_file=f, fuse_args=True)
        args.build()
        return args.asdict() if asdict else args
    if not asdict:
        raise NotImplementedError()
    # JSON text is UTF-8; decode explicitly instead of relying on the platform's
    # locale encoding (e.g. GBK on Chinese Windows), which fails on non-ASCII.
    with io.open(f, "r", encoding="utf-8") as fin:
        return json.load(fin)
            

def parse_list(string, astype=str):
    """Split a comma-separated string into a list, converting each item with *astype*.

    Without any comma, the whole string (unstripped) is the single item.
    With commas, commas and whitespace both act as separators and empty
    pieces are dropped, e.g. ``"a, b,,c" -> ['a', 'b', 'c']``.
    """
    assert isinstance(string, str), "{} is not a string.".format(string)
    if ',' in string:
        tokens = string.replace(',', ' ').split()
    else:
        tokens = [string]
    return list(map(astype, tokens))


def try_float(s):
    """Return ``float(s)`` if *s* can be converted, otherwise return *s* unchanged."""
    try:
        return float(s)
    except (TypeError, ValueError, OverflowError):
        # Not float-convertible (non-numeric string, None, container, or an
        # int too large for a float): pass the value through untouched.
        # Other exceptions (KeyboardInterrupt, bugs in __float__) propagate.
        return s


# TODO: add a None mechanism so that hidden size, batch size and seqlen may be set to None
def check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"):
    """Check that every entry of *in_attr* exists in *out_attr* with an equal value.

    Values are compared with ``!=``; the messages describe them as shape/dtype,
    so each value is presumably a shape/dtype descriptor (confirm with callers).

    Args:
        in_attr: mapping of name -> attribute that must be present in ``out_attr``.
        out_attr: mapping of name -> attribute to compare against.
        strict: if True, a value mismatch raises; otherwise it is logged as a warning.
        in_name: label for ``in_attr`` used in messages.
        out_name: label for ``out_attr`` used in messages.

    Raises:
        AssertionError: if a name from ``in_attr`` is missing in ``out_attr``.
        ValueError: on a value mismatch when ``strict`` is True.
    """
    # The module does not import logging at top level; without this the
    # non-strict warning path below raises NameError.
    import logging

    for name, attr in in_attr.items():
        assert name in out_attr, in_name+': '+name+' not found in '+out_name
        if attr != out_attr[name]:
            if strict:
                raise ValueError(name+': shape or dtype not consistent!')
            logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name]))


def encode_inputs(inputs, scope_name, sep='/', cand_set=None):
    """Return a new dict whose keys are namespaced as ``scope_name + sep + key``.

    When *cand_set* is given, only candidates are emitted: the bare key is kept
    if it is in *cand_set*, the namespaced key is kept if it is in *cand_set*,
    and a single input may therefore produce both entries.
    """
    encoded = {}
    for key, value in inputs.items():
        scoped_key = scope_name + sep + key
        if cand_set is None:
            encoded[scoped_key] = value
            continue
        if key in cand_set:
            encoded[key] = value
        if scoped_key in cand_set:
            encoded[scoped_key] = value
    return encoded


def decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True):
    """Strip the ``scope_name + sep`` namespace from the keys of *inputs*.

    This reverses the prefixing done by ``encode_inputs`` with the same
    ``scope_name`` and ``sep``.

    Args:
        inputs: mapping of (possibly namespaced) name -> value.
        scope_name: namespace whose entries are selected and un-prefixed.
        sep: separator between namespace and name; must match the one used
            when encoding.
        keep_unk_keys: if True, names that contain no ``sep`` at all are
            passed through unchanged.

    Returns:
        A new dict of the selected entries.
    """
    # Previously the prefix was hard-coded as '/', so a non-default sep never
    # matched (and disagreed with the unk-key check below).
    prefix = scope_name + sep
    outputs = {}
    for name, value in inputs.items():
        # Unscoped vars (e.g. backbone vars) are also available to tasks.
        if keep_unk_keys and sep not in name:
            outputs[name] = value
        # Vars belonging to this instance's scope.
        if name.startswith(prefix):
            outputs[name[len(prefix):]] = value
    return outputs


def build_executor(on_gpu):
    """Create a ``fluid.Executor`` bound to a single device.

    Args:
        on_gpu: truthy to run on CUDA device 0, falsy to run on the CPU.

    Returns:
        A ``fluid.Executor`` for the selected place (no device count is returned).
    """
    target = fluid.CUDAPlace(0) if on_gpu else fluid.CPUPlace()
    return fluid.Executor(target)


def fit_attr(conf, fit_attr, strict=False):
    """Cast selected entries of *conf* in place.

    Args:
        conf: mutable mapping of config values; it is updated in place.
        fit_attr: mapping from key to a callable applied to ``conf[key]``.
            (The parameter shares the function's name; it is kept for
            keyword-argument compatibility.)
        strict: if True, a key of ``fit_attr`` missing from ``conf`` raises;
            otherwise such keys are skipped.

    Returns:
        The same ``conf`` object.

    Raises:
        Exception: in strict mode, at the first key missing from ``conf``
            (keys processed before it have already been cast).
    """
    for key, caster in fit_attr.items():
        if key in conf:
            conf[key] = caster(conf[key])
        elif strict:
            raise Exception('Argument {} is required to create a controller.'.format(key))
    return conf