module.py 10.3 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
Zeyu Chen 已提交
15 16
# coding=utf-8

Z
Zeyu Chen 已提交
17 18 19 20 21
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid
Z
Zeyu Chen 已提交
22
import numpy as np
Z
Zeyu Chen 已提交
23 24
import tempfile
import os
Z
Zeyu Chen 已提交
25

Z
Zeyu Chen 已提交
26
from collections import defaultdict
Z
Zeyu Chen 已提交
27
from paddle_hub.downloader import download_and_uncompress
Z
Zeyu Chen 已提交
28
from paddle_hub import module_desc_pb2
Z
Zeyu Chen 已提交
29

Z
Zeyu Chen 已提交
30
# Names exported by ``from <this module> import *``.
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
Z
Zeyu Chen 已提交
31
# File name of the word dictionary stored inside a module's assets directory.
DICT_NAME = "dict.txt"
# Name of the sub-directory of a module directory that holds its assets.
ASSETS_NAME = "assets"
Z
Zeyu Chen 已提交
33 34 35 36 37 38 39


def mkdir(path):
    """Create ``path`` and any missing parents, like the shell ``mkdir -p``.

    Args:
        path: directory path to create.

    Raises:
        OSError: if the directory cannot be created, or if ``path`` already
            exists but is not a directory.
    """
    try:
        os.makedirs(path)
    except OSError:
        # Create first and inspect the failure afterwards: checking for
        # existence before creating leaves a window in which another process
        # can create the directory and make ``os.makedirs`` fail.
        if not os.path.isdir(path):
            raise
Z
Zeyu Chen 已提交
40

Z
Zeyu Chen 已提交
41

Z
Zeyu Chen 已提交
42
class Module(object):
    """A Paddle inference model loaded together with its module config.

    The model comes either from a remote archive that is downloaded and
    uncompressed (``module_url``) or from a local directory (``module_dir``).
    After construction the instance holds the executor (``exe``), the loaded
    ``inference_program`` with its ``feed_target_names`` / ``fetch_targets``,
    and the parsed ``config`` (a ``ModuleConfig``).
    """

    def __init__(self, module_url=None, module_dir=None):
        """Locate the module, then load its inference program and config.

        Args:
            module_url: remote link of the module; only used when it starts
                with "http".
            module_dir: local module directory; used when ``module_url`` is
                not a remote link.

        Raises:
            ValueError: if both arguments are None, or if ``module_url`` is
                not an http link and no ``module_dir`` was given.
        """
        if module_url is None and module_dir is None:
            raise ValueError("Module:module_url and module_dir are None!")

        self.module_dir = ""
        self.module_name = ""
        # download module
        if module_url is not None and module_url.startswith("http"):
            # if it's remote url link, then download and uncompress it
            self.module_name, self.module_dir = download_and_uncompress(
                module_url)
            #TODO(ZeyuChen): check url link is valid url
        elif module_dir is not None:
            # otherwise it's local path, no need to deal with it
            self.module_dir = module_dir
            # Use the last path component as module name by default;
            # normalising first keeps a trailing slash from yielding "".
            self.module_name = os.path.basename(os.path.normpath(module_dir))
            #TODO(ZeyuChen) add more check about loading module from local path
        else:
            # A non-http url without a local directory used to fall through
            # with an empty module_dir and fail later while loading.
            raise ValueError(
                "Module:module_url {!r} is not an http link and module_dir "
                "is None!".format(module_url))

        # load paddle inference model
        self.exe = fluid.Executor(fluid.CPUPlace())
        [self.inference_program, self.feed_target_names,
         self.fetch_targets] = fluid.io.load_inference_model(
             dirname=self.module_dir, executor=self.exe)

        print("inference_program")
        print(self.inference_program)
        print("feed_target_names")
        print(self.feed_target_names)
        print("fetch_targets")
        print(self.fetch_targets)

        self.config = ModuleConfig(self.module_dir)
        self.config.load()

    def _construct_feed_dict(self, inputs):
        """Keep only the entries of ``inputs`` the program accepts as feed.

        Args:
            inputs: mapping from feed name to value, or None.

        Returns:
            A new dict holding the items of ``inputs`` whose key is listed in
            ``self.feed_target_names``; empty when ``inputs`` is None/empty.
        """
        if not inputs:
            return {}
        return {
            name: inputs[name]
            for name in inputs if name in self.feed_target_names
        }

    def __call__(self, inputs=None, sign_name="default"):
        """Run the inference program and return its first fetch result.

        Args:
            inputs: mapping from feed name to value; keys that are not feed
                targets of the program are ignored.
            sign_name: signature name; accepted for interface compatibility
                but not used by the current implementation.

        Returns:
            The first fetched result converted with ``np.array``.
        """
        feed_dict = self._construct_feed_dict(inputs)
        print("feed_dict", feed_dict)

        ret_numpy = self.config.return_numpy()
        print("ret_numpy", ret_numpy)
        results = self.exe.run(
            self.inference_program,
            feed=feed_dict,
            fetch_list=self.fetch_targets,
            return_numpy=ret_numpy)

        print("module feed_target_names", self.feed_target_names)
        print("module fetch_targets", self.fetch_targets)
        np_result = np.array(results[0])

        return np_result

    def get_vars(self):
        """
        Return variable list of the module program
        """
        return self.inference_program.list_vars()

    def get_feed_var(self, key, signature="default"):
        """
        Get feed variable according to variable key and signature

        Raises:
            Exception: if the config has no feed entry for ``key`` or the
                program has no variable with the configured name.
        """
        # The configured name does not depend on the loop variable, so it is
        # resolved once instead of once per program variable.
        target_name = self.config.feed_var_name(key, signature)
        for var in self.inference_program.list_vars():
            if var.name == target_name:
                return var

        raise Exception("Can't find input var {}".format(key))

    def get_fetch_var(self, key, signature="default"):
        """
        Get fetch variable according to variable key and signature

        Raises:
            Exception: if the config has no fetch entry for ``key`` or the
                program has no variable with the configured name.
        """
        target_name = self.config.fetch_var_name(key, signature)
        for var in self.inference_program.list_vars():
            if var.name == target_name:
                return var

        raise Exception("Can't find output var {}".format(key))

    def get_inference_program(self):
        """Return the loaded inference program."""
        return self.inference_program

    # for text sequence input, transform to lod tensor as paddle graph's input
    def _preprocess_input(self, inputs):
        """Map each input sequence to word ids and pack them in a lod tensor.

        Args:
            inputs: iterable of word sequences.

        Returns:
            The lod tensor built by ``seq2lod_tensor`` from the id sequences.
        """
        # words id mapping and dealing with oov
        seq = [self._word_id_mapping(s) for s in inputs]
        # transform to lod tensor
        return self.seq2lod_tensor(seq)

    def seq2lod_tensor(self, seq_inputs, place=None):
        """Build a lod tensor from a list of sequences.

        Args:
            seq_inputs: list of sequences; the length of each one becomes an
                entry of the single lod level.
            place: place handed to ``fluid.create_lod_tensor``; a fresh
                ``fluid.CPUPlace()`` is used when None, so no place object is
                created at class-definition time or shared between calls.

        Returns:
            The tensor returned by ``fluid.create_lod_tensor``.
        """
        if place is None:
            place = fluid.CPUPlace()
        # one lod level holding the length of every sequence
        lod = [[len(s) for s in seq_inputs]]
        return fluid.create_lod_tensor(seq_inputs, lod, place)

    def _word_id_mapping(self, inputs):
        """Return the dictionary id of every word in ``inputs`` as a list."""
        word_dict = self.config.get_dict()
        return [word_dict[w] for w in inputs]

Z
Zeyu Chen 已提交
175

Z
Zeyu Chen 已提交
176
class ModuleConfig(object):
    """Module description (protobuf) plus the word dictionary of a module.

    The description lives in ``<module_dir>/module_desc.pb`` and the
    dictionary in ``<module_dir>/<ASSETS_NAME>/<DICT_NAME>`` with one
    ``word<whitespace>id`` pair per line.
    """

    def __init__(self, module_dir, module_name=None):
        """Create a config with default values for ``module_dir``.

        Args:
            module_dir: directory the config is loaded from / dumped to.
            module_name: module name; defaults to the last component of
                ``module_dir``.
        """
        # generate model desc protobuf
        self.module_dir = module_dir
        self.desc = module_desc_pb2.ModuleDesc()
        if module_name is None:
            # normalising first keeps a trailing slash from yielding ""
            module_name = os.path.basename(os.path.normpath(module_dir))
        # initialize module config default value
        self.desc.name = module_name
        self.desc.contain_assets = True
        self.desc.return_numpy = False

        # Word -> id dictionary; unknown words map to id 0 through the
        # defaultdict factory. (No ``setdefault(0)`` here: that call stored
        # a ``0 -> None`` entry which ``dump`` then wrote as "0\tNone".)
        self.dict = defaultdict(int)

    def get_dict(self):
        """ Return dictionary in Module"""
        return self.dict

    def load(self):
        """
        Load module config from module directory.

        Parses ``module_desc.pb`` into ``self.desc`` and, when the
        description says the module contains assets, reads the dictionary
        file into ``self.dict``.

        Raises:
            IOError/OSError: if a required file cannot be opened.
            ValueError: if a dictionary line is not a ``word id`` pair with
                an integer id.
        """
        #TODO(ZeyuChen): check module_desc.pb existence
        pb_path = os.path.join(self.module_dir, "module_desc.pb")
        with open(pb_path, "rb") as fi:
            self.desc.ParseFromString(fi.read())

        if self.desc.contain_assets:
            # load assets
            assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
            dict_path = os.path.join(assets_dir, DICT_NAME)

            with open(dict_path) as fi:
                #TODO(ZeyuChen) check whether word id is duplicated and valid
                # Iterate the file directly: reading it with readlines()
                # first exhausts it and leaves nothing to iterate.
                for line in fi:
                    if not line.strip():
                        continue
                    w, w_id = line.split()
                    if w_id == "None":
                        # Entry written by older versions of dump(), whose
                        # dictionary contained a bogus ``0 -> None`` item.
                        continue
                    self.dict[w] = int(w_id)

    def dump(self):
        """ Save Module configure file to disk.

        Writes ``module_desc.pb`` and the dictionary file, creating the
        module and assets directories first when they are missing.
        """
        # Creating the assets directory also creates module_dir, so do it
        # before the protobuf file is opened for writing.
        assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
        mkdir(assets_dir)

        pb_path = os.path.join(self.module_dir, "module_desc.pb")
        with open(pb_path, "wb") as fo:
            fo.write(self.desc.SerializeToString())

        # save assets/dictionary
        with open(os.path.join(assets_dir, DICT_NAME), "w") as fo:
            for w in self.dict:
                w_id = self.dict[w]
                fo.write("{}\t{}\n".format(w, w_id))

    def return_numpy(self):
        """Return numpy or not according to the proto config.
        """
        return self.desc.return_numpy

    def save_dict(self, word_dict, dict_name=DICT_NAME):
        """ Save dictionary for NLP module

        Copies every item of ``word_dict`` into the in-memory dictionary;
        nothing is written to disk until ``dump`` is called.

        Args:
            word_dict: mapping from word to id.
            dict_name: accepted for interface compatibility; not used by the
                current implementation.
        """
        for w in word_dict:
            self.dict[w] = word_dict[w]

    def register_feed_signature(self, feed_desc, sign_name="default"):
        """ Register feed signature to the Module

        Args:
            feed_desc: a dictionary of signature key to input variable name
            sign_name: signature name, use "default" as default signature
        """
        #TODO(ZeyuChen) check feed_desc key is valid and no duplicated
        for k in feed_desc:
            feed = self.desc.sign2var[sign_name].feed_desc.add()
            feed.key = k
            feed.var_name = feed_desc[k]

    def register_fetch_signature(self, fetch_desc, sign_name="default"):
        """ Register fetch signature to the Module

        Args:
            fetch_desc: a dictionary of signature key to output variable name
            sign_name: signature name, use "default" as default signature
        """
        #TODO(ZeyuChen) check fetch_desc key is valid and no duplicated
        for k in fetch_desc:
            fetch = self.desc.sign2var[sign_name].fetch_desc.add()
            fetch.key = k
            fetch.var_name = fetch_desc[k]

    def feed_var_name(self, key, sign_name="default"):
        """get module's feed/input variable name

        Raises:
            Exception: if no feed entry with ``key`` is registered under
                ``sign_name``.
        """
        for desc in self.desc.sign2var[sign_name].feed_desc:
            if desc.key == key:
                return desc.var_name
        raise Exception("feed variable {} not found".format(key))

    def fetch_var_name(self, key, sign_name="default"):
        """get module's fetch/output variable name

        Raises:
            Exception: if no fetch entry with ``key`` is registered under
                ``sign_name``.
        """
        for desc in self.desc.sign2var[sign_name].fetch_desc:
            if desc.key == key:
                return desc.var_name
        raise Exception("fetch variable {} not found".format(key))

Z
Zeyu Chen 已提交
286 287 288

class ModuleUtils(object):
    """Collection of static helpers that operate on a program object."""

    def __init__(self):
        pass

    @staticmethod
    def remove_feed_fetch_op(program):
        """ remove feed and fetch operator and variable for fine-tuning

        Args:
            program: object exposing ``global_block()`` and ``desc.flush()``;
                its global block is modified in place.
        """
        print("remove feed fetch op")
        block = program.global_block()
        feed_fetch_indices = [
            index for index, op in enumerate(block.ops)
            if op.type in ("feed", "fetch")
        ]

        # Remove from the back so the indices collected above stay valid
        # while operators are being deleted.
        for index in reversed(feed_fetch_indices):
            block._remove_op(index)

        block._remove_var("feed")
        block._remove_var("fetch")

        program.desc.flush()