# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid
import numpy as np
import tempfile
import os
import pickle

from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2

__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
DICT_NAME = "dict.txt"
ASSETS_NAME = "assets"


def mkdir(path):
    """ the same as the shell command mkdir -p "
    """
    if not os.path.exists(path):
        os.makedirs(path)


class Module(object):
    """
    A module represents a
    """

    def __init__(self, module_url=None, module_dir=None):
        if module_url is None and module_dir is None:
            raise Exception(
                "Module: either module_url or module_dir must be specified!")

        self.module_dir = ""
        self.module_name = ""
        # download the module if a remote url is given
        if module_url is not None and module_url.startswith("http"):
            # if it's a remote url, download and uncompress it
            self.module_name, self.module_dir = download_and_uncompress(
                module_url)
            #TODO(ZeyuChen): check url link is valid url
        elif module_dir is not None:
            # otherwise it's a local path and can be used directly
            self.module_dir = module_dir
            # use the path name as module name by default
            self.module_name = module_dir.split("/")[-1]
            #TODO(ZeyuChen) add more check about loading module from local path

        # load paddle inference model
        place = fluid.CPUPlace()
        model_dir = os.path.join(self.module_dir, "model")
        print("model_dir", model_dir)
        self.exe = fluid.Executor(place)
        [self.inference_program, self.feed_target_names,
         self.fetch_targets] = fluid.io.load_inference_model(
             dirname=model_dir, executor=self.exe)

        # remove feed fetch operator and variable
        ModuleUtils.remove_feed_fetch_op(self.inference_program)

        print("inference_program")
        print(self.inference_program)
        print("feed_target_names")
        print(self.feed_target_names)
        print("fetch_targets")
        print(self.fetch_targets)

        self.config = ModuleConfig(self.module_dir)
        self.config.load()
        self._process_parameter()

    def _process_parameter(self):
        """Recreate the parameters recorded in param.pkl as Parameter
        variables of the loaded inference program.
        """
        global_block = self.inference_program.global_block()
        filepath = os.path.join(self.module_dir, "param.pkl")
        with open(filepath, "rb") as file:
            param_arr = pickle.load(file)
        for param in param_arr:
            if param['name'] not in global_block.vars:
                continue
            var = global_block.var(param['name'])
            global_block.create_parameter(
                **param,
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                error_clip=var.error_clip,
                stop_gradient=var.stop_gradient,
                is_data=var.is_data)
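        # Note (an assumption, not from the original source): param.pkl is
        # expected to hold a list of dicts, each with at least a "name" key
        # plus any extra parameter attributes (e.g. trainable), which are
        # splatted into create_parameter() above. A hypothetical writer side
        # could look like:
        #
        #     param_arr = [{"name": "fc_0.w_0", "trainable": True}]
        #     with open("my_module/param.pkl", "wb") as fo:  # hypothetical path
        #         pickle.dump(param_arr, fo)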

    def _construct_feed_dict(self, inputs):
        """ Construct feed dict according to user's inputs and module config.
        """
        feed_dict = {}
        for k in inputs:
            if k in self.feed_target_names:
                feed_dict[k] = inputs[k]

        return feed_dict

    def __call__(self, inputs=None, sign_name="default"):
        """ Call default signature and return results
        """
        # word_ids_lod_tensor = self._preprocess_input(inputs)
        feed_dict = self._construct_feed_dict(inputs)
        print("feed_dict", feed_dict)

        ret_numpy = self.config.return_numpy()
        print("ret_numpy", ret_numpy)
        results = self.exe.run(
            self.inference_program,
            #feed={self.feed_target_names[0]: word_ids_lod_tensor},
            feed=feed_dict,
            fetch_list=self.fetch_targets,
            return_numpy=ret_numpy)

        print("module fetch_target_names", self.feed_target_names)
        print("module fetch_targets", self.fetch_targets)
        np_result = np.array(results[0])

        return np_result
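        # Illustrative usage sketch (not part of the original source; the
        # module directory and input name below are assumptions):
        #
        #     module = Module(module_dir="./senta_module")
        #     # keys of `inputs` must match the program's feed_target_names
        #     result = module({"words": word_ids_lod_tensor})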

    def get_vars(self):
        """
        Return variable list of the module program
        """
        return self.inference_program.list_vars()

    def get_feed_var(self, key, signature="default"):
        """
        Get feed variable according to variable key and signature
        """
        for var in self.inference_program.list_vars():
            if var.name == self.config.feed_var_name(key, signature):
                return var

        raise Exception("Can't find input var {}".format(key))

    def get_feed_var_by_index(self, index, signature="default"):
        feed_vars = self.get_feed_vars(signature)
        assert index < len(
            feed_vars), "feed var index {} out of range, len {}".format(
                index, len(feed_vars))
        return feed_vars[index]

    def get_fetch_var_by_index(self, index, signature="default"):
        fetch_vars = self.get_fetch_vars(signature)
        assert index < len(
            fetch_vars), "fetch var index {} out of range, len {}".format(
                index, len(fetch_vars))
        return fetch_vars[index]

    def get_feed_vars(self, signature="default"):
        """
        Get all feed variables of the given signature
        """
        feed_vars = []
        for feed_var in self.config.feed_var_names(signature):
            find_var = False
            for var in self.inference_program.list_vars():
                if var.name == feed_var.var_name:
                    feed_vars.append(var)
                    find_var = True
            if not find_var:
                raise Exception(
                    "Can't find feed var {}".format(feed_var.var_name))

        return feed_vars

    def get_fetch_vars(self, signature="default"):
        """
        Get all fetch variables of the given signature
        """
        fetch_vars = []
        #TODO(ZeyuChen): use brute force to find variables, simple and easy to
        #understand
        for fetch_var in self.config.fetch_var_names(signature):
            find_var = False
            for var in self.inference_program.list_vars():
                if var.name == fetch_var.var_name:
                    fetch_vars.append(var)
                    find_var = True
            if not find_var:
                raise Exception(
                    "Can't find fetch var {}".format(fetch_var.var_name))

        return fetch_vars

    def get_fetch_var(self, key, signature="default"):
        """
        Get fetch variable according to variable key and signature
        """
        for var in self.inference_program.list_vars():
            if var.name == self.config.fetch_var_name(key, signature):
                return var

        raise Exception("Can't find output var {}".format(key))

    def get_inference_program(self):
        return self.inference_program

    # for text sequence input, transform to lod tensor as paddle graph's input
    def _preprocess_input(self, inputs):
        # words id mapping and dealing with oov
        # transform to lod tensor
        seq = []
        for s in inputs:
            seq.append(self._word_id_mapping(s))

        lod_tensor = self.seq2lod_tensor(seq)

        return lod_tensor

    def seq2lod_tensor(self, seq_inputs, place=fluid.CPUPlace()):
        """Convert a batch of sequences to a lod tensor on the given place."""
        lod = [[]]
        for s in seq_inputs:
            # generate length-based lod
            lod[0].append(len(s))

        # print("seq", seq_inputs)
        # print("lod", lod)

        lod_tensor = fluid.create_lod_tensor(seq_inputs, lod, place)

        return lod_tensor
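        # Illustrative sketch (an assumption, not from the original source):
        # for two word-id sequences [[1, 2, 3], [4, 5]] the length-based lod
        # built above is [[3, 2]], so the call is roughly equivalent to:
        #
        #     fluid.create_lod_tensor([[1, 2, 3], [4, 5]], [[3, 2]],
        #                             fluid.CPUPlace())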

    def _word_id_mapping(self, inputs):
        word_dict = self.config.get_dict()
        return list(map(lambda x: word_dict[x], inputs))


class ModuleConfig(object):
    """Configuration of a Module: its module_desc protobuf (signatures and
    flags) plus assets such as the word dictionary.
    """
    def __init__(self, module_dir, module_name=None):
        # generate model desc protobuf
        self.module_dir = module_dir
        self.desc = module_desc_pb2.ModuleDesc()
        if module_name is None:
            module_name = module_dir.split("/")[-1]
        # initialize module config default value
        self.desc.name = module_name
        self.desc.contain_assets = True
        self.desc.return_numpy = False

        # init dict; defaultdict(int) maps unknown (OOV) words to id 0
        self.dict = defaultdict(int)

    def get_dict(self):
        """Return the word dictionary of the Module."""
        return self.dict

    def load(self):
        """
        Load module config from module directory.
        """
        #TODO(ZeyuChen): check module_desc.pb existence
        pb_path = os.path.join(self.module_dir, "module_desc.pb")
        with open(pb_path, "rb") as fi:
            self.desc.ParseFromString(fi.read())

        # print("self.desc.sign2var",
        #       self.desc.sign2var["default"].feed_desc[0].var_name)

        if self.desc.contain_assets:
            # load assets
            assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
            dict_path = os.path.join(assets_dir, DICT_NAME)

            with open(dict_path) as fi:
                #TODO(ZeyuChen) check whether word id is duplicated and valid
                for line in fi:
                    w, w_id = line.split()
                    self.dict[w] = int(w_id)

    def dump(self):
        """ Save Module configure file to disk.
        """
        pb_path = os.path.join(self.module_dir, "module_desc.pb")
        with open(pb_path, "wb") as fo:
            fo.write(self.desc.SerializeToString())

        # save assets/dictionary
        assets_dir = os.path.join(self.module_dir, ASSETS_NAME)
        mkdir(assets_dir)
        with open(os.path.join(assets_dir, DICT_NAME), "w") as fo:
            for w in self.dict:
                w_id = self.dict[w]
                fo.write("{}\t{}\n".format(w, w_id))
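        # Illustrative round-trip sketch (the path and words are assumptions,
        # not from the original source):
        #
        #     config = ModuleConfig("./my_module")
        #     config.save_dict({"apple": 1, "banana": 2})
        #     config.dump()    # writes module_desc.pb and assets/dict.txt
        #
        #     restored = ModuleConfig("./my_module")
        #     restored.load()  # reads them back into restored.dict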

    def return_numpy(self):
        """Return numpy or not according to the proto config.
        """
        return self.desc.return_numpy

    def save_dict(self, word_dict, dict_name=DICT_NAME):
        """ Save dictionary for NLP module
        """
        for w in word_dict:
            self.dict[w] = word_dict[w]

    def register_feed_signature(self, feed_desc, sign_name="default"):
        """ Register feed signature to the Module

        Args:
            fetch_desc: a dictionary of signature to input variable
            sign_name: signature name, use "default" as default signature
        """
        #TODO(ZeyuChen) check feed_desc keys are valid and not duplicated
        for k in feed_desc:
            feed = self.desc.sign2var[sign_name].feed_desc.add()
            feed.key = k
            feed.var_name = feed_desc[k]

    def register_fetch_signature(self, fetch_desc, sign_name="default"):
        """ Register fetch signature to the Module

        Args:
            fetch_desc: a dictionary mapping output keys to variable names
            sign_name: signature name, use "default" as default signature
        """
        #TODO(ZeyuChen) check fetch_desc keys are valid and not duplicated
        for k in fetch_desc:
            fetch = self.desc.sign2var[sign_name].fetch_desc.add()
            fetch.key = k
            fetch.var_name = fetch_desc[k]
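        # Illustrative sketch (variable names are assumptions, not from the
        # original source): a module author would typically register both
        # signatures before dumping the config:
        #
        #     config.register_feed_signature({"words": "word_ids"})
        #     config.register_fetch_signature({"sentiment": "fc_2.tmp_2"})
        #     config.dump()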

    def feed_var_names(self, sign_name="default"):
        return self.desc.sign2var[sign_name].feed_desc

    def fetch_var_names(self, sign_name="default"):
        return self.desc.sign2var[sign_name].fetch_desc

    def feed_var_name(self, key, sign_name="default"):
        """get module's feed/input variable name
        """
        for desc in self.desc.sign2var[sign_name].feed_desc:
            if desc.key == key:
                return desc.var_name
        raise Exception("feed variable {} not found".format(key))

    def fetch_var_name(self, key, sign_name="default"):
        """get module's fetch/output variable name
        """
        for desc in self.desc.sign2var[sign_name].fetch_desc:
            if desc.key == key:
                return desc.var_name
        raise Exception("fetch variable {} not found".format(key))


class ModuleUtils(object):
    """Utility helpers for processing a Module's program."""

    def __init__(self):
        pass

    @staticmethod
    def remove_feed_fetch_op(program):
        """ remove feed and fetch operator and variable for fine-tuning
        """
        print("remove feed fetch op")
        block = program.global_block()
        need_to_remove_op_index = []
        for i, op in enumerate(block.ops):
            if op.type == "feed" or op.type == "fetch":
                need_to_remove_op_index.append(i)

        # remove ops from the back so that remaining indices stay valid
        for index in need_to_remove_op_index[::-1]:
            block._remove_op(index)

        block._remove_var("feed")
        block._remove_var("fetch")

        program.desc.flush()
        # print("********************************")
        # print(program)
        # print("********************************")