utils.py 5.4 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
20 21
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
W
wuzewu 已提交
22 23
import paddle
import paddle.fluid as fluid
W
wuzewu 已提交
24
import os
W
wuzewu 已提交
25 26 27 28 29 30 31 32


def to_list(input):
    """Wrap ``input`` in a list unless it already is a list or tuple.

    Lists and tuples (including subclasses) are returned as the very same
    object, not copied. Any other value, ``None`` included, becomes a
    one-element list.
    """
    if isinstance(input, (list, tuple)):
        return input
    return [input]
W
wuzewu 已提交
33 34


W
wuzewu 已提交
35 36 37 38 39
def mkdir(path):
    """Create ``path`` and any missing parent directories, like ``mkdir -p``.

    If ``path`` already exists (as a directory or anything else) this is a
    no-op. It is also safe when another thread or process creates the same
    directory concurrently.

    Args:
        path: filesystem path of the directory to create.

    Raises:
        OSError: if creation fails and ``path`` is not a directory afterwards
            (e.g. permission denied, or a parent component is a regular file).
    """
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
    except OSError:
        # Someone else may have created the directory between the exists()
        # check above and makedirs(); in that case the goal is already met.
        if not os.path.isdir(path):
            raise
40 41 42 43 44 45 46 47 48 49 50 51 52 53


def get_keyed_type_of_pyobj(pyobj):
    """Return the ``module_desc_pb2`` key-type enum that describes ``pyobj``.

    This records the Python type of a dict key or attribute name, so that
    ``get_pykey`` can convert the stringified key back later.

    Types are tested in order: bool, int, str, float. ``bool`` must be tested
    before ``int`` because ``bool`` is a subclass of ``int``. Any other type
    is reported as ``STRING``.
    """
    type_table = (
        (bool, module_desc_pb2.BOOLEAN),
        (int, module_desc_pb2.INT),
        (str, module_desc_pb2.STRING),
        (float, module_desc_pb2.FLOAT),
    )
    for py_type, keyed_type in type_table:
        if isinstance(pyobj, py_type):
            return keyed_type
    return module_desc_pb2.STRING


W
wuzewu 已提交
54 55
def get_pykey(key, keyed_type):
    """Convert a stored string key back to a Python value of its recorded type.

    This reverses the key encoding of ``from_pyobj_to_flexible_data``, which
    stores ``str(key)`` together with ``get_keyed_type_of_pyobj(key)``.

    Args:
        key: the key as stored in the protobuf map.
        keyed_type: ``module_desc_pb2`` enum value recorded for that key.

    Returns:
        ``key == "True"`` for BOOLEAN, ``int(key)`` for INT, ``float(key)``
        for FLOAT, and ``str(key)`` for STRING or any unrecognized type.
    """
    if keyed_type == module_desc_pb2.BOOLEAN:
        return key == "True"

    if keyed_type == module_desc_pb2.INT:
        converter = int
    elif keyed_type == module_desc_pb2.STRING:
        converter = str
    elif keyed_type == module_desc_pb2.FLOAT:
        converter = float
    else:
        converter = str
    return converter(key)


# TODO(wuzewu): handle circular references; a self-referencing container or object currently recurses until Python raises RecursionError.
def from_pyobj_to_flexible_data(pyobj, flexible_data, obj_filter=None):
    """Recursively serialize a Python object into ``flexible_data`` in place.

    Mapping from Python type to stored representation:

    * bool, int, str, float are stored as scalars in ``b``, ``i``, ``s``, ``f``.
    * ``None`` is stored as type NONE.
    * list and tuple become LIST. The list/tuple distinction is not kept.
    * set becomes SET, with elements numbered in iteration order.
    * dict becomes MAP, keyed by ``str(key)``. The original key type is
      recorded in ``map.keyType``.
    * Anything else becomes OBJECT. Its class name goes in ``name`` and its
      ``__dict__`` entries are serialized like a dict.

    LIST and SET elements are keyed by their position as a string ("0",
    "1", ...).

    Args:
        pyobj: the value to serialize.
        flexible_data: protobuf message to populate. It is presumably a
            ``module_desc_pb2`` FlexibleData message (TODO confirm). It must
            expose ``type``, the scalar fields, and the ``list``, ``set``,
            ``map`` and ``object`` sub-messages used below.
        obj_filter: optional predicate, applied at every nesting level. When
            it returns a truthy value for an object, nothing is written into
            that object's target message.

    There is no cycle detection, so self-referencing structures recurse
    without bound.
    """
    if obj_filter and obj_filter(pyobj):
        logger.info("filter python object")
        return

    # Scalars. bool has to be tested before int because bool subclasses int.
    if isinstance(pyobj, bool):
        flexible_data.type = module_desc_pb2.BOOLEAN
        flexible_data.b = pyobj
        return
    if isinstance(pyobj, int):
        flexible_data.type = module_desc_pb2.INT
        flexible_data.i = pyobj
        return
    if isinstance(pyobj, str):
        flexible_data.type = module_desc_pb2.STRING
        flexible_data.s = pyobj
        return
    if isinstance(pyobj, float):
        flexible_data.type = module_desc_pb2.FLOAT
        flexible_data.f = pyobj
        return

    # Sequences: each element lives under its stringified position.
    if isinstance(pyobj, (list, tuple)):
        flexible_data.type = module_desc_pb2.LIST
        for position, element in enumerate(pyobj):
            from_pyobj_to_flexible_data(
                element, flexible_data.list.data[str(position)], obj_filter)
        return
    if isinstance(pyobj, set):
        flexible_data.type = module_desc_pb2.SET
        # Elements are numbered in the order a list() snapshot of the set yields them.
        for position, element in enumerate(list(pyobj)):
            from_pyobj_to_flexible_data(
                element, flexible_data.set.data[str(position)], obj_filter)
        return

    # Mappings: keys are stringified; their Python type is kept in keyType.
    if isinstance(pyobj, dict):
        flexible_data.type = module_desc_pb2.MAP
        for key, value in pyobj.items():
            slot = str(key)
            from_pyobj_to_flexible_data(value, flexible_data.map.data[slot],
                                        obj_filter)
            flexible_data.map.keyType[slot] = get_keyed_type_of_pyobj(key)
        return

    if pyobj is None:
        flexible_data.type = module_desc_pb2.NONE
        return

    # Fallback: a generic object, serialized through its instance attributes.
    flexible_data.type = module_desc_pb2.OBJECT
    flexible_data.name = str(pyobj.__class__.__name__)
    if not hasattr(pyobj, "__dict__"):
        logger.warning(
            "python obj %s has not __dict__ attr" % flexible_data.name)
        return
    for attr_name, attr_value in pyobj.__dict__.items():
        slot = str(attr_name)
        from_pyobj_to_flexible_data(
            attr_value, flexible_data.object.data[slot], obj_filter)
        flexible_data.object.keyType[slot] = get_keyed_type_of_pyobj(
            attr_name)


def from_flexible_data_to_pyobj(flexible_data):
    """Rebuild a Python value from a message written by ``from_pyobj_to_flexible_data``.

    Scalars are returned directly from ``b``, ``i``, ``s`` or ``f``. The
    container types are rebuilt as follows:

    * LIST always comes back as a ``list``, read by position "0" to "len-1".
    * SET comes back as a ``set``, read the same way.
    * MAP comes back as a ``dict``. Each key is converted with ``get_pykey``
      using its recorded ``keyType``.
    * NONE yields ``None``.

    OBJECT entries cannot be reconstructed: a warning is logged and ``None``
    is returned. An unrecognized ``type`` is treated the same way.
    """
    kind = flexible_data.type

    if kind == module_desc_pb2.BOOLEAN:
        return flexible_data.b
    if kind == module_desc_pb2.INT:
        return flexible_data.i
    if kind == module_desc_pb2.STRING:
        return flexible_data.s
    if kind == module_desc_pb2.FLOAT:
        return flexible_data.f

    if kind == module_desc_pb2.LIST:
        elements = flexible_data.list.data
        return [
            from_flexible_data_to_pyobj(elements[str(position)])
            for position in range(len(elements))
        ]
    if kind == module_desc_pb2.SET:
        elements = flexible_data.set.data
        return {
            from_flexible_data_to_pyobj(elements[str(position)])
            for position in range(len(elements))
        }
    if kind == module_desc_pb2.MAP:
        restored = {}
        for raw_key, value in flexible_data.map.data.items():
            # Convert the key first, then the value (same order as before).
            py_key = get_pykey(raw_key, flexible_data.map.keyType[raw_key])
            restored[py_key] = from_flexible_data_to_pyobj(value)
        return restored

    if kind == module_desc_pb2.NONE:
        return None
    if kind == module_desc_pb2.OBJECT:
        logger.warning("can't tran flexible_data to python object")
        return None

    logger.warning("unknown type of flexible_data")
    return None