utils.py 6.4 KB
Newer Older
W
wuzewu 已提交
1
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
18

W
wuzewu 已提交
19
import os
Z
Zeyu Chen 已提交
20 21
import time
import multiprocessing
22
import hashlib
W
wuzewu 已提交
23

W
wuzewu 已提交
24 25 26 27 28 29
import paddle
import paddle.fluid as fluid

from paddlehub.module import module_desc_pb2
from paddlehub.common.logger import logger

W
wuzewu 已提交
30 31 32 33 34 35 36

def to_list(input):
    """Wrap a single value in a list; pass sequences through untouched.

    Args:
        input: Any object.

    Returns:
        ``input`` itself (the same object) when it is a ``list`` or a
        ``tuple``; otherwise a new one-element list containing ``input``.
    """
    if isinstance(input, (list, tuple)):
        return input
    return [input]
W
wuzewu 已提交
37 38


W
wuzewu 已提交
39 40 41 42 43
def mkdir(path):
    """Create ``path`` and any missing parent directories.

    Behaves like the shell command ``mkdir -p``: calling it on a path that
    already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If the directory could not be created and ``path`` still
            does not exist afterwards (e.g. permission denied).
    """
    try:
        os.makedirs(path)
    except OSError:
        # Checking for existence *before* creating is racy: another process
        # may create the directory in between. Attempt the creation first and
        # only propagate the error when the path is genuinely missing.
        if not os.path.exists(path):
            raise
44 45


46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
def md5_of_file(file):
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is opened in binary mode and consumed in 4096-byte pieces, so
    the whole content is never held in memory at once.

    Args:
        file: Path of the file to hash.

    Returns:
        The MD5 digest as a hex string.
    """
    digest = hashlib.md5()
    with open(file, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                # An empty read marks end of file.
                break
            digest.update(block)

    return digest.hexdigest()


def md5(text):
    """Return the hexadecimal MD5 digest of ``text``.

    Args:
        text: A ``str`` (encoded as UTF-8 before hashing) or any other
            object, which is handed to ``hashlib.md5`` unchanged.

    Returns:
        The MD5 digest as a hex string.
    """
    payload = text.encode("utf8") if isinstance(text, str) else text
    return hashlib.md5(payload).hexdigest()


63 64 65 66 67 68 69 70 71 72 73 74
def get_keyed_type_of_pyobj(pyobj):
    """Map a Python value to the ``module_desc_pb2`` type constant for it.

    Args:
        pyobj: The Python object to classify.

    Returns:
        ``BOOLEAN``, ``INT``, ``STRING`` or ``FLOAT`` from
        ``module_desc_pb2``; anything that is none of bool/int/str/float
        falls back to ``STRING``.
    """
    # Order matters: bool is a subclass of int, so it has to be tried first.
    type_table = (
        (bool, module_desc_pb2.BOOLEAN),
        (int, module_desc_pb2.INT),
        (str, module_desc_pb2.STRING),
        (float, module_desc_pb2.FLOAT),
    )
    for py_type, keyed_type in type_table:
        if isinstance(pyobj, py_type):
            return keyed_type
    return module_desc_pb2.STRING


W
wuzewu 已提交
75 76
def get_pykey(key, keyed_type):
    """Convert a stringified key back to a Python value of the given type.

    Args:
        key: The key in its string form.
        keyed_type: One of the ``module_desc_pb2`` type constants
            (``BOOLEAN``, ``INT``, ``STRING``, ``FLOAT``).

    Returns:
        ``key`` converted to bool, int, str or float according to
        ``keyed_type``; any other type constant yields ``str(key)``.

    Raises:
        ValueError: If ``key`` cannot be parsed as the requested int/float.
    """
    if keyed_type == module_desc_pb2.BOOLEAN:
        # Only the exact text "True" is treated as true.
        return key == "True"
    elif keyed_type == module_desc_pb2.INT:
        return int(key)
    elif keyed_type == module_desc_pb2.STRING:
        return str(key)
    elif keyed_type == module_desc_pb2.FLOAT:
        return float(key)
    return str(key)


# TODO(wuzewu): solve the problem of circular references
W
wuzewu 已提交
88
def from_pyobj_to_module_attr(pyobj, module_attr, obj_filter=None):
    """Recursively write a Python object into ``module_attr``.

    Scalars are stored in the ``b``/``i``/``s``/``f`` field matching their
    type. Lists and tuples are both stored as ``LIST`` and sets as ``SET``,
    each keyed by the stringified element index. Dicts are stored as ``MAP``
    keyed by ``str(key)``, with the original key type recorded in
    ``key_type``. ``None`` is stored as ``NONE``. Any other object is stored
    as ``OBJECT`` using its class name and the entries of its ``__dict__``.

    Args:
        pyobj: The Python object to serialize.
        module_attr: Destination attribute message; it is mutated in place.
            (Presumably a ``module_desc_pb2`` message -- confirm against the
            proto definition.)
        obj_filter: Optional callable; when ``obj_filter(pyobj)`` is truthy
            the object is skipped and ``module_attr`` is left untouched.

    Note:
        Circular references are not detected; a self-referencing structure
        recurses without bound.
    """
    if obj_filter and obj_filter(pyobj):
        return

    # bool has to be tested before int because bool is a subclass of int.
    if isinstance(pyobj, bool):
        module_attr.type = module_desc_pb2.BOOLEAN
        module_attr.b = pyobj
    elif isinstance(pyobj, int):
        module_attr.type = module_desc_pb2.INT
        module_attr.i = pyobj
    elif isinstance(pyobj, str):
        module_attr.type = module_desc_pb2.STRING
        module_attr.s = pyobj
    elif isinstance(pyobj, float):
        module_attr.type = module_desc_pb2.FLOAT
        module_attr.f = pyobj
    elif isinstance(pyobj, (list, tuple)):
        module_attr.type = module_desc_pb2.LIST
        for index, obj in enumerate(pyobj):
            from_pyobj_to_module_attr(obj, module_attr.list.data[str(index)],
                                      obj_filter)
    elif isinstance(pyobj, set):
        module_attr.type = module_desc_pb2.SET
        for index, obj in enumerate(pyobj):
            from_pyobj_to_module_attr(obj, module_attr.set.data[str(index)],
                                      obj_filter)
    elif isinstance(pyobj, dict):
        module_attr.type = module_desc_pb2.MAP
        for key, value in pyobj.items():
            from_pyobj_to_module_attr(value, module_attr.map.data[str(key)],
                                      obj_filter)
            module_attr.map.key_type[str(key)] = get_keyed_type_of_pyobj(key)
    elif pyobj is None:
        module_attr.type = module_desc_pb2.NONE
    else:
        module_attr.type = module_desc_pb2.OBJECT
        module_attr.name = str(pyobj.__class__.__name__)
        if not hasattr(pyobj, "__dict__"):
            # Nothing to walk into; only the class name is recorded.
            logger.warning(
                "python obj %s has not __dict__ attr" % module_attr.name)
            return
        for key, value in pyobj.__dict__.items():
            from_pyobj_to_module_attr(value, module_attr.object.data[str(key)],
                                      obj_filter)
            module_attr.object.key_type[str(key)] = get_keyed_type_of_pyobj(key)


def _to_hashable(value):
    """Recursively turn lists into tuples so the value can be put in a set."""
    if isinstance(value, list):
        return tuple(_to_hashable(item) for item in value)
    return value


def from_module_attr_to_pyobj(module_attr):
    """Rebuild a Python object from ``module_attr``.

    Args:
        module_attr: The attribute message to decode; its ``type`` field
            selects which payload field is read.

    Returns:
        A bool, int, str or float for the scalar types; a ``list`` for
        ``LIST``; a ``set`` for ``SET``; a ``dict`` for ``MAP`` (keys are
        converted back with ``get_pykey``); ``None`` for ``NONE``.
        ``OBJECT`` and unrecognised types cannot be rebuilt: a warning is
        logged and ``None`` is returned.
    """
    if module_attr.type == module_desc_pb2.BOOLEAN:
        result = module_attr.b
    elif module_attr.type == module_desc_pb2.INT:
        result = module_attr.i
    elif module_attr.type == module_desc_pb2.STRING:
        result = module_attr.s
    elif module_attr.type == module_desc_pb2.FLOAT:
        result = module_attr.f
    elif module_attr.type == module_desc_pb2.LIST:
        result = []
        # Elements are keyed by their stringified position: "0", "1", ...
        for index in range(len(module_attr.list.data)):
            result.append(
                from_module_attr_to_pyobj(module_attr.list.data[str(index)]))
    elif module_attr.type == module_desc_pb2.SET:
        result = set()
        for index in range(len(module_attr.set.data)):
            element = from_module_attr_to_pyobj(
                module_attr.set.data[str(index)])
            # Tuples are stored as LIST and therefore decode to lists, which
            # are unhashable. A list can never have been a set member, so
            # such an element was a tuple originally: convert it back.
            result.add(_to_hashable(element))
    elif module_attr.type == module_desc_pb2.MAP:
        result = {}
        for key, value in module_attr.map.data.items():
            key = get_pykey(key, module_attr.map.key_type[key])
            result[key] = from_module_attr_to_pyobj(value)
    elif module_attr.type == module_desc_pb2.NONE:
        result = None
    elif module_attr.type == module_desc_pb2.OBJECT:
        result = None
        logger.warning("can't tran module attr to python object")
    else:
        result = None
        logger.warning("unknown type of module attr")

    return result
W
wuzewu 已提交
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189


def check_path(path):
    """Placeholder: no path check is implemented; always returns None."""
    return None


def check_url(url):
    """Placeholder: no URL check is implemented; always returns None."""
    return None


def get_file_ext(file_path):
    """Return the extension of ``file_path``, including the leading dot.

    Args:
        file_path: A file name or path.

    Returns:
        The last extension as produced by ``os.path.splitext`` (e.g.
        ``".csv"``), or an empty string when there is none.
    """
    _, extension = os.path.splitext(file_path)
    return extension


def is_csv_file(file_path):
    """Tell whether ``file_path`` has the exact extension ``.csv``.

    Args:
        file_path: A file name or path.

    Returns:
        True when the extension reported by ``get_file_ext`` is ``".csv"``
        (case-sensitive), False otherwise.
    """
    extension = get_file_ext(file_path)
    return extension == ".csv"


def is_yaml_file(file_path):
    """Tell whether ``file_path`` has a YAML extension.

    Both ``.yml`` and ``.yaml`` are accepted (case-sensitive).

    Args:
        file_path: A file name or path.

    Returns:
        True when the file extension is ``".yml"`` or ``".yaml"``, False
        otherwise.
    """
    return os.path.splitext(file_path)[-1] in (".yml", ".yaml")


Z
Zeyu Chen 已提交
190 191 192 193 194 195 196 197 198 199 200
def get_running_device_info(config):
    """Pick the execution place and device count for ``config``.

    Args:
        config: Object exposing a ``use_cuda`` attribute.

    Returns:
        A ``(place, dev_count)`` tuple. With ``use_cuda`` truthy this is
        ``fluid.CUDAPlace(0)`` and the CUDA device count reported by
        ``fluid.core``; otherwise ``fluid.CPUPlace()`` and the value of the
        ``CPU_NUM`` environment variable, defaulting to
        ``multiprocessing.cpu_count()``.
    """
    if not config.use_cuda:
        cpu_place = fluid.CPUPlace()
        default_cpus = multiprocessing.cpu_count()
        return cpu_place, int(os.environ.get('CPU_NUM', default_cpus))

    gpu_place = fluid.CUDAPlace(0)
    return gpu_place, fluid.core.get_cuda_device_count()


W
wuzewu 已提交
201 202 203 204 205
if __name__ == "__main__":
    # Quick manual check: run both extension predicates on each sample name.
    for sample_name in ("test.yml", "test.csv"):
        for predicate in (is_yaml_file, is_csv_file):
            print(predicate(sample_name))