module.py 24.6 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22 23

import os
import time
import sys
import functools
W
wuzewu 已提交
24 25 26
import inspect
import importlib
import tarfile
W
wuzewu 已提交
27
import six
W
wuzewu 已提交
28
import shutil
W
wuzewu 已提交
29 30 31 32

import paddle
import paddle.fluid as fluid

W
wuzewu 已提交
33 34
from paddlehub.common import utils
from paddlehub.common import paddle_helper
W
wuzewu 已提交
35
from paddlehub.common.dir import CACHE_HOME
S
shenyuhan 已提交
36
from paddlehub.common.lock import lock
W
wuzewu 已提交
37 38
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
W
wuzewu 已提交
39
from paddlehub.common import tmp_dir
W
wuzewu 已提交
40
from paddlehub.common.downloader import progress
W
wuzewu 已提交
41 42
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
W
wuzewu 已提交
43 44
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
W
wuzewu 已提交
45
from paddlehub.module.base_processor import BaseProcessor
W
wuzewu 已提交
46
from paddlehub.io.parser import yaml_parser
W
wuzewu 已提交
47
from paddlehub import version
W
wuzewu 已提交
48

Z
Zeyu Chen 已提交
49
# PaddleHub module dir name
W
wuzewu 已提交
50 51 52 53 54
# Sub-directory of a module directory that holds asset files
# (see ModuleHelper.assets_path).
ASSETS_DIRNAME = "assets"
# Sub-directory that holds the saved inference model
# (see ModuleHelper.model_path).
MODEL_DIRNAME = "model"
# File name of the serialized module description; when this file exists in a
# directory, Module.init_with_directory loads the directory as a ModuleV1.
MODULE_DESC_PBNAME = "module_desc.pb"
# Sub-directory that holds the dumped processor source file
# (see ModuleHelper.processor_path).
PYTHON_DIR = "python"
# Value returned by ModuleHelper.processor_name().
PROCESSOR_NAME = "processor"
Z
Zeyu Chen 已提交
55
# PaddleHub var prefix
W
wuzewu 已提交
56
# Template for prefixed variable names; not referenced by the other code
# visible in this file.
# NOTE(review): presumably "%s" is filled with a module-specific name --
# confirm against callers.
HUB_VAR_PREFIX = "@HUB_%s@"
W
wuzewu 已提交
57

W
wuzewu 已提交
58
# Maps "<module name>.<class name>" to the name of the method registered as
# the runnable entry point of that class.
_module_runnable_func = {}


def runnable(func):
    """Register ``func`` as the runnable entry point of the enclosing class.

    Intended to decorate a method inside a class body: the registry key is
    ``func.__module__`` joined with the name of the calling frame, which is
    the class name while the class body is being executed.
    ``Module.__init__`` later looks the entry up under
    ``<module name>.<class name>`` and binds it as ``run_func``.

    Args:
        func: The function to register.

    Returns:
        A wrapper that forwards every argument to ``func`` and keeps its
        ``__name__`` / ``__doc__`` metadata.
    """
    # Field 3 of a stack record is the name of the code object that is
    # running in the caller's frame.
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_runnable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


走神的阿圆's avatar
走神的阿圆 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83
# Maps "<module name>.<class name>" to the name of the method registered as
# the serving entry point of that class.
_module_serving_func = {}


def serving(func):
    """Register ``func`` as the serving entry point of the enclosing class.

    Intended to decorate a method inside a class body: the registry key is
    ``func.__module__`` joined with the name of the calling frame, which is
    the class name while the class body is being executed.
    ``Module.__init__`` later reads the registered method name into
    ``serving_func_name``.

    Args:
        func: The function to register.

    Returns:
        A wrapper that forwards every argument to ``func`` and keeps its
        ``__name__`` / ``__doc__`` metadata.
    """
    # Field 3 of a stack record is the name of the code object that is
    # running in the caller's frame.
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_serving_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
def moduleinfo(name, version, author, author_email, summary, type):
    """Build a class decorator that attaches metadata to a ``Module`` subclass.

    The values are stored on the class as ``_name``, ``_version``,
    ``_author``, ``_author_email``, ``_summary`` and ``_type``; the
    read-only properties of ``Module`` return these class attributes.

    Args:
        name: Module name.
        version: Module version.
        author: Author of the module.
        author_email: Contact e-mail of the author.
        summary: Short description of the module.
        type: Module type/category.

    Returns:
        A decorator that annotates the class and returns the same class.

    Raises:
        RuntimeError: If the decorated class is not a subclass of ``Module``.
    """

    def _wrapper(cls):
        if not issubclass(cls, Module):
            raise RuntimeError(
                "moduleinfo can only decorate subclasses of Module, got %r" %
                (cls, ))
        cls._name = name
        cls._version = version
        cls._author = author
        cls._author_email = author_email
        cls._summary = summary
        cls._type = type
        return cls

    return _wrapper


W
wuzewu 已提交
99
class Module(object):
    """Base class and factory for PaddleHub modules.

    Calling ``Module(...)`` directly does not create a plain ``Module``:
    ``__new__`` resolves ``name`` / ``directory`` to a concrete module object
    (a ``ModuleV1`` or a user-defined subclass found in ``<directory>/module.py``)
    and returns that object instead.

    The metadata properties (``name``, ``version``, ``author`` ...) read class
    attributes such as ``_name``; the ``moduleinfo`` decorator sets them.
    """

    def __new__(cls,
                name=None,
                directory=None,
                module_dir=None,
                version=None,
                **kwargs):
        """Resolve the arguments to a concrete module object.

        Args:
            name: Module name; the module is installed through the default
                module manager and loaded from its install directory.
            directory: Path of a local module directory.
            module_dir: Deprecated alias of ``directory``; may also be a
                list/tuple whose first element is the path.
            version: Version requested when installing by ``name``.
            **kwargs: Forwarded to the concrete module constructor.

        Returns:
            The loaded module object.

        Raises:
            ValueError: If ``Module`` itself is called without ``name``,
                ``directory`` or ``module_dir``.
        """
        if cls.__name__ == "Module":
            if name:
                module = cls.init_with_name(
                    name=name, version=version, **kwargs)
            elif directory:
                module = cls.init_with_directory(directory=directory, **kwargs)
            elif module_dir:
                logger.warning(
                    "Parameter module_dir is deprecated, please use directory to specify the path"
                )
                # Only the path is needed; any further elements of a
                # list/tuple value are ignored.
                if isinstance(module_dir, (list, tuple)):
                    directory = module_dir[0]
                else:
                    directory = module_dir
                module = cls.init_with_directory(directory=directory, **kwargs)
            else:
                # Previously this fell through and failed with an
                # unbound-local error on `module`.
                raise ValueError(
                    "Module requires one of the arguments: name, directory or module_dir"
                )
            CacheUpdater("update_cache", module.name, module.version).start()
        else:
            if not name and not directory:
                # A subclass created without arguments is loaded from the
                # directory that contains its own source file.
                directory = os.path.dirname(
                    os.path.abspath(sys.modules[cls.__module__].__file__))
                module = Module.init_with_directory(
                    directory=directory, **kwargs)
            else:
                module = object.__new__(cls)

        return module

    def __init__(self,
                 name=None,
                 directory=None,
                 module_dir=None,
                 version=None,
                 **kwargs):
        """Bind the registered entry points and run ``_initialize``.

        ``__new__`` may return an instance that was already constructed, in
        which case Python calls ``__init__`` on it a second time; the
        ``_is_initialize`` flag turns that second call into a no-op.

        Args:
            name: Unused here; accepted to mirror ``__new__``.
            directory: Module directory, exposed through ``directory``.
            module_dir: Unused here; accepted to mirror ``__new__``.
            version: Unused here; accepted to mirror ``__new__``.
            **kwargs: Forwarded to ``_initialize``.
        """
        # Avoid module being initialized multiple times
        if self.__dict__.get("_is_initialize"):
            return

        # Key format matches the one used by the runnable/serving decorators.
        mod = self.__class__.__module__ + "." + self.__class__.__name__
        if mod in _module_runnable_func:
            self._run_func = getattr(self, _module_runnable_func[mod])
        else:
            self._run_func = None
        self._serving_func_name = _module_serving_func.get(mod, None)
        # Set before _initialize so the hook can already read it.
        self._code_version = "v2"
        self._directory = directory
        self._initialize(**kwargs)
        self._is_initialize = True
        # Assigned again after the hook; redundant unless _initialize
        # changed the value.
        self._code_version = "v2"

    @classmethod
    def init_with_name(cls, name, version=None, **kwargs):
        """Install the module called ``name`` and load it from its directory.

        The installation is guarded by an exclusive lock on a file named
        after the module inside ``CACHE_HOME``.

        Args:
            name: Module name to install.
            version: Optional module version.
            **kwargs: Forwarded to ``init_with_directory``.

        Returns:
            The loaded module object.

        Raises:
            RuntimeError: If the module manager reports a failed install.
        """
        fp_lock = open(os.path.join(CACHE_HOME, name), "a")
        try:
            lock.flock(fp_lock, lock.LOCK_EX)
            try:
                log_msg = "Installing %s module" % name
                if version:
                    log_msg += "-%s" % version
                logger.info(log_msg)
                extra = {"command": "install"}
                result, tips, module_dir = default_module_manager.install_module(
                    module_name=name, module_version=version, extra=extra)
                if not result:
                    logger.error(tips)
                    raise RuntimeError(tips)

                logger.info(tips)
            finally:
                # Release the lock on the failure path as well.
                lock.flock(fp_lock, lock.LOCK_UN)
        finally:
            fp_lock.close()
        return cls.init_with_directory(directory=module_dir[0], **kwargs)

    @classmethod
    def init_with_directory(cls, directory, **kwargs):
        """Load a module from a local directory.

        A directory that contains the module description file is loaded as a
        ``ModuleV1``. Otherwise ``<directory>/module.py`` is imported as
        ``<basename>.module`` and the first ``Module`` subclass found among
        its members (in the name order of ``inspect.getmembers``) is
        instantiated.

        Args:
            directory: Path of the module directory.
            **kwargs: Forwarded to the module constructor.

        Returns:
            The loaded module object.

        Raises:
            RuntimeError: If the imported ``module.py`` exposes no subclass
                of ``Module``.
        """
        desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
        if os.path.exists(desc_file):
            checker = ModuleChecker(directory)
            checker.check()
            return ModuleV1(directory=directory, **kwargs)

        if directory.endswith(os.sep):
            directory = directory[:-1]
        dirname, basename = os.path.split(directory)
        # The parent directory must be importable while the package is
        # imported and the module object is constructed.
        sys.path.insert(0, dirname)
        try:
            _module = importlib.import_module("{}.module".format(basename))
            for _, _item in inspect.getmembers(_module, inspect.isclass):
                # Skip Module itself (it is visible when module.py imports
                # it by name): instantiating it would re-enter this method
                # and recurse without end.
                if issubclass(_item, Module) and _item is not Module:
                    return _item(directory=directory, **kwargs)
        finally:
            # Undo the sys.path change even when the import or the
            # constructor raises.
            if dirname in sys.path:
                sys.path.remove(dirname)
        raise RuntimeError(
            "No subclass of Module found in %s" % os.path.join(
                directory, "module.py"))

    @property
    def run_func(self):
        """Bound method registered with ``runnable``, or None."""
        return self._run_func

    @property
    def directory(self):
        """Directory passed to the constructor."""
        return self._directory

    @property
    def author(self):
        """Class-level ``_author`` value."""
        return self.__class__._author

    @property
    def author_email(self):
        """Class-level ``_author_email`` value."""
        return self.__class__._author_email

    @property
    def summary(self):
        """Class-level ``_summary`` value."""
        return self.__class__._summary

    @property
    def type(self):
        """Class-level ``_type`` value."""
        return self.__class__._type

    @property
    def version(self):
        """Class-level ``_version`` value."""
        return self.__class__._version

    @property
    def name(self):
        """Class-level ``_name`` value."""
        return self.__class__._name

    @property
    def code_version(self):
        """Code generation of the module ("v2" for this class)."""
        return self._code_version

    @property
    def is_runnable(self):
        """True when a method was registered with ``runnable``."""
        return self._run_func is not None

    @property
    def serving_func_name(self):
        """Name of the method registered with ``serving``, or None."""
        return self._serving_func_name

    def _initialize(self):
        """Hook called by ``__init__``; subclasses override it."""
        pass
W
wuzewu 已提交
245 246


247
class ModuleHelper(object):
    """Computes the well-known paths inside a module directory."""

    def __init__(self, directory):
        # Root directory of the module; every path below is relative to it.
        self.directory = directory

    def _child(self, leaf):
        """Join ``leaf`` onto the module directory."""
        return os.path.join(self.directory, leaf)

    def module_desc_path(self):
        """Path of the serialized module description file."""
        return self._child(MODULE_DESC_PBNAME)

    def model_path(self):
        """Path of the model sub-directory."""
        return self._child(MODEL_DIRNAME)

    def processor_path(self):
        """Path of the sub-directory that holds the processor source."""
        return self._child(PYTHON_DIR)

    def processor_name(self):
        """Return the ``PROCESSOR_NAME`` constant."""
        return PROCESSOR_NAME

    def assets_path(self):
        """Path of the assets sub-directory."""
        return self._child(ASSETS_DIRNAME)
W
wuzewu 已提交
265

W
wuzewu 已提交
266

W
wuzewu 已提交
267 268
class ModuleV1(Module):
    def __init__(self, name=None, directory=None, module_dir=None,
B
BinLong 已提交
269
                 version=None):
W
wuzewu 已提交
270 271 272
        if not directory:
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
W
wuzewu 已提交
273 274 275 276
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
W
wuzewu 已提交
277
        self.default_signature = None
W
wuzewu 已提交
278
        self.processor = None
W
wuzewu 已提交
279
        self.extra_info = {}
W
wuzewu 已提交
280
        self._code_version = "v1"
W
wuzewu 已提交
281

W
wuzewu 已提交
282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
        # parse desc
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

W
wuzewu 已提交
302 303 304 305 306 307
        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

W
wuzewu 已提交
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322
        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)
W
wuzewu 已提交
323

W
wuzewu 已提交
324
    @property
走神的阿圆's avatar
走神的阿圆 已提交
325 326 327 328
    def serving_func_name(self):
        serving_func_name = self.desc.attr.map.data['default_signature'].s
        return serving_func_name if serving_func_name != "" else None

W
wuzewu 已提交
329
    @property
W
wuzewu 已提交
330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356
    def desc(self):
        """Module description message parsed from the description file."""
        return self._desc

    @property
    def author(self):
        """Author stored on the instance (read from the module description)."""
        return self._author

    @property
    def author_email(self):
        """Author e-mail stored on the instance (read from the module description)."""
        return self._author_email

    @property
    def summary(self):
        """Summary stored on the instance (read from the module description)."""
        return self._summary

    @property
    def type(self):
        """Module type stored on the instance (read from the module description)."""
        return self._type

    @property
    def version(self):
        """Version stored on the instance (read from the module description)."""
        return self._version

    @property
    def name(self):
        """Module name stored on the instance (read from the module description)."""
        return self._name

W
wuzewu 已提交
357 358 359 360 361
    def _dump_processor(self):
        """Write the processor's source module into the processor directory.

        The output file name is an md5 digest derived from the source digest
        plus the current time; that name is also recorded under
        ``processor_info`` in the module description.
        """
        import inspect
        source_code = inspect.getsource(inspect.getmodule(self.processor))
        target_dir = self.helper.processor_path()
        # Salt the source digest with the current time before hashing again.
        salted_digest = utils.md5(source_code) + str(time.time())
        processor_name = utils.md5(salted_digest)
        target_file = os.path.join(target_dir, processor_name + ".py")
        utils.mkdir(target_dir)
        with open(target_file, "w") as out:
            out.write(source_code)
        utils.from_pyobj_to_module_attr(
            processor_name, self.desc.attr.map.data['processor_info'])
W
wuzewu 已提交
371

W
wuzewu 已提交
372 373
    def _load_processor(self):
        processor_path = self.helper.processor_path()
W
wuzewu 已提交
374 375
        if os.path.exists(processor_path):
            sys.path.append(processor_path)
W
wuzewu 已提交
376 377
            processor_name = utils.from_module_attr_to_pyobj(
                self.desc.attr.map.data['processor_info'])
W
wuzewu 已提交
378 379 380
            self.processor = __import__(processor_name).Processor(module=self)
        else:
            self.processor = None
W
wuzewu 已提交
381

W
wuzewu 已提交
382 383 384 385 386
    def _dump_assets(self):
        """Copy every file listed in ``self.assets`` into the assets directory."""
        utils.mkdir(self.helper.assets_path())
        for source in self.assets:
            target = os.path.join(self.helper.assets_path(),
                                  os.path.basename(source))
            shutil.copyfile(source, target)
W
wuzewu 已提交
388 389 390 391 392 393 394 395

    def _load_assets(self):
        assets_path = self.helper.assets_path()
        self.assets = []
        for file in os.listdir(assets_path):
            filepath = os.path.join(self.helper.assets_path(), file)
            self.assets.append(filepath)

Z
Zeyu Chen 已提交
396
    def _restore_parameter(self, program):
        """Re-create the parameters recorded in ``param_attrs`` on ``program``.

        For every recorded parameter whose prefixed name exists in the global
        block, a parameter is created from the existing variable's
        properties combined with the recorded attributes.

        Args:
            program: The program whose global block receives the parameters.
        """
        global_block = program.global_block()
        recorded = self.desc.attr.map.data['param_attrs']
        for short_name, param_attr in recorded.map.data.items():
            param = paddle_helper.from_module_attr_to_param(param_attr)
            param['name'] = self.get_var_name_with_prefix(short_name)
            if param['name'] in global_block.vars:
                var = global_block.var(param['name'])
                global_block.create_parameter(
                    shape=var.shape,
                    dtype=var.dtype,
                    type=var.type,
                    lod_level=var.lod_level,
                    error_clip=var.error_clip,
                    stop_gradient=var.stop_gradient,
                    is_data=var.is_data,
                    **param)
W
wuzewu 已提交
414 415

    def _recover_variable_info(self, program):
        """Apply the recorded ``stop_gradient`` flags to variables of ``program``.

        Each entry of ``var_infos`` names a variable (without prefix), the
        index of its block and its ``stop_gradient`` value; variables that
        are missing from the block are skipped.

        Args:
            program: The program whose variables are updated.
        """
        var_infos = self.desc.attr.map.data['var_infos']
        for short_name in var_infos.map.data:
            block_id = utils.from_module_attr_to_pyobj(
                var_infos.map.data[short_name].map.data['block_id'])
            flag = utils.from_module_attr_to_pyobj(
                var_infos.map.data[short_name].map.data['stop_gradient'])
            target_block = program.blocks[block_id]
            full_name = self.get_var_name_with_prefix(short_name)
            if full_name not in target_block.vars:
                continue
            target_block.vars[full_name].stop_gradient = flag

W
wuzewu 已提交
428 429 430 431 432 433 434 435
    def get_extra_info(self, key):
        return self.extra_info.get(key, None)

    def _generate_extra_info(self):
        for key in self.extra_info:
            self.__dict__["get_%s" % key] = functools.partial(
                self.get_extra_info, key=key)

W
wuzewu 已提交
436 437 438
    def _generate_sign_attr(self):
        self._check_signatures()
        for sign in self.signatures:
W
wuzewu 已提交
439 440
            self.__dict__[sign] = functools.partial(
                self.__call__, sign_name=sign)
W
wuzewu 已提交
441

442 443 444 445
    def get_vocab_path(self):
        for assets_file in self.assets:
            if "vocab.txt" in assets_file:
                return assets_file
K
kinghuin 已提交
446 447 448 449 450 451 452 453 454 455 456 457 458
        return None

    def get_word_dict_path(self):
        for assets_file in self.assets:
            if "dict.wordseg.pickle" in assets_file:
                return assets_file
        return None

    def get_spm_path(self):
        for assets_file in self.assets:
            if "spm_cased_simp_sampled.model" in assets_file:
                return assets_file
        return None
459

W
wuzewu 已提交
460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
    def _recover_from_desc(self):
        """Rebuild signatures, module info, extra info and name prefix from the description."""
        # recover signature
        global_vars = self.program.global_block().vars
        for sign, module_var in self.desc.sign2var.items():
            feed_descs = list(module_var.feed_desc)
            fetch_descs = list(module_var.fetch_desc)
            self.signatures[sign] = create_signature(
                sign,
                inputs=[global_vars[item.var_name] for item in feed_descs],
                outputs=[global_vars[item.var_name] for item in fetch_descs],
                feed_names=[item.alias for item in feed_descs],
                fetch_names=[item.alias for item in fetch_descs])

        # recover default signature
        default_signature_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['default_signature'])
        if default_signature_name:
            self.default_signature = self.signatures[
                default_signature_name].name
        else:
            self.default_signature = None

        # recover module info; each field is stored as "_<field>"
        module_info = self.desc.attr.map.data['module_info']
        for field in ('name', 'author', 'author_email', 'version', 'type',
                      'summary'):
            setattr(
                self, '_' + field,
                utils.from_module_attr_to_pyobj(module_info.map.data[field]))

        # recover extra info
        extra_info = self.desc.attr.map.data['extra_info']
        self.extra_info = {}
        for key, value in extra_info.map.data.items():
            self.extra_info[key] = utils.from_module_attr_to_pyobj(value)

        # recover name prefix
        self._name_prefix = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data["name_prefix"])
W
wuzewu 已提交
514

515
    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
W
wuzewu 已提交
516 517
        self.check_processor()

W
wuzewu 已提交
518
        def _get_reader_and_feeder(data_format, data, place):
W
wuzewu 已提交
519
            def _reader(process_data):
W
wuzewu 已提交
520 521 522 523 524 525 526 527 528
                for item in zip(*process_data):
                    yield item

            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])
            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
W
wuzewu 已提交
529
            return functools.partial(_reader, process_data=process_data), feeder
W
wuzewu 已提交
530

W
wuzewu 已提交
531 532 533 534 535 536 537 538
        if self.last_call_name != sign_name:
            self.last_call_name = sign_name
            self.cache_feed_dict, self.cache_fetch_dict, self.cache_program = self.context(
                sign_name, for_test=True)
        feed_dict = self.cache_feed_dict
        fetch_dict = self.cache_fetch_dict
        program = self.cache_program

W
wuzewu 已提交
539 540
        fetch_list = list(set([value for key, value in fetch_dict.items()]))
        with fluid.program_guard(program):
W
wuzewu 已提交
541 542
            result = []
            index = 0
543 544 545 546 547 548 549
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except:
                use_gpu = False

            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
550

W
wuzewu 已提交
551
            exe = fluid.Executor(place=place)
W
wuzewu 已提交
552 553 554 555
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
556
            reader = paddle.batch(reader, batch_size=batch_size)
W
wuzewu 已提交
557 558 559 560 561
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
W
wuzewu 已提交
562 563 564 565 566 567 568 569 570
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result
W
wuzewu 已提交
571

W
wuzewu 已提交
572
    def check_processor(self):
W
wuzewu 已提交
573 574
        if not self.processor:
            raise ValueError("This Module is not callable!")
W
wuzewu 已提交
575

W
wuzewu 已提交
576
    @property
W
wuzewu 已提交
577
    def is_runnable(self):
W
wuzewu 已提交
578 579
        return self.default_signature != None

走神的阿圆's avatar
走神的阿圆 已提交
580 581 582 583
    @property
    def code_version(self):
        """Code generation of the module; ``__init__`` sets it to "v1"."""
        return self._code_version

W
wuzewu 已提交
584
    def context(self,
585
                sign_name=None,
W
wuzewu 已提交
586
                for_test=False,
Z
Zeyu Chen 已提交
587
                trainable=True,
W
wuzewu 已提交
588
                regularizer=None,
589
                max_seq_len=128,
W
wuzewu 已提交
590
                learning_rate=1e-3):
591 592 593 594 595
        """
        Args:
            max_seq_len(int): maximum sequence length, this option is only
            available for BERT/ERNIE module
        """
W
wuzewu 已提交
596

597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625
        if sign_name:
            if sign_name not in self.signatures:
                raise KeyError(
                    "Module did not have a signature with name %s" % sign_name)
            signature = self.signatures[sign_name]
        else:
            inputs = [
                input for signature in self.signatures.values()
                for input in signature.inputs
            ]
            outputs = [
                output for signature in self.signatures.values()
                for output in signature.outputs
            ]
            feed_names = [
                feed_name for signature in self.signatures.values()
                for feed_name in signature.feed_names
            ]
            fetch_names = [
                fetch_name for signature in self.signatures.values()
                for fetch_name in signature.fetch_names
            ]
            signature = create_signature(
                name="hub_temp_signature",
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names,
                for_predict=False)
W
wuzewu 已提交
626

W
wuzewu 已提交
627
        program = self.program.clone(for_test=for_test)
W
wuzewu 已提交
628
        paddle_helper.remove_feed_fetch_op(program)
W
wuzewu 已提交
629 630

        if not for_test:
W
wuzewu 已提交
631
            paddle_helper.set_parameter_trainable(program, trainable)
W
wuzewu 已提交
632

W
wuzewu 已提交
633
            paddle_helper.set_parameter_learning_rate(program, learning_rate)
W
wuzewu 已提交
634

W
wuzewu 已提交
635
            paddle_helper.set_parameter_regularizer(program, regularizer)
W
wuzewu 已提交
636

Z
Zeyu Chen 已提交
637
            self._restore_parameter(program)
W
wuzewu 已提交
638

W
wuzewu 已提交
639 640
        self._recover_variable_info(program)

W
wuzewu 已提交
641
        paddle_helper.set_op_attr(program, is_test=for_test)
W
wuzewu 已提交
642 643 644 645 646 647 648 649 650 651 652 653 654 655
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            feed_dict[index] = program.global_block().var(var.name)
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = program.global_block().var(var.name)

        for index, var in enumerate(signature.outputs):
            fetch_dict[index] = program.global_block().var(var.name)
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = program.global_block().var(var.name)

656
        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
K
kinghuin 已提交
657
        if "bert" in self.name or self.name.startswith("ernie"):
658 659 660 661
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
K
kinghuin 已提交
662
                        max_seq_len, MAX_SEQ_LENGTH))
663
            logger.info(
664
                "Set maximum sequence length of input tensor to {}".format(
665
                    max_seq_len))
666 667
            if self.name.startswith("ernie_v2"):
                feed_list = [
668 669
                    "input_ids", "position_ids", "segment_ids", "input_mask",
                    "task_ids"
670 671 672 673 674 675
                ]
            else:
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask"
                ]
            for tensor_name in feed_list:
676 677 678 679 680 681 682
                seq_tensor_shape = [-1, max_seq_len, 1]
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                program.global_block().var(
                    feed_dict[tensor_name].name).desc.set_shape(
                        seq_tensor_shape)

683 684
        # record num parameters loaded by paddlehub
        num_param_loaded = 0
W
wuzewu 已提交
685
        for param in program.global_block().iter_parameters():
686 687 688
            num_param_loaded += 1
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)
W
wuzewu 已提交
689

W
wuzewu 已提交
690 691
        return feed_dict, fetch_dict, program

692
    def get_name_prefix(self):
W
wuzewu 已提交
693
        return self._name_prefix
694 695 696 697

    def get_var_name_with_prefix(self, var_name):
        return self.get_name_prefix() + var_name

W
wuzewu 已提交
698
    def _check_signatures(self):
W
wuzewu 已提交
699 700
        if not self.signatures:
            raise ValueError("Signatures should not be None")
W
wuzewu 已提交
701 702

        for key, sign in self.signatures.items():
W
wuzewu 已提交
703 704 705 706
            if not isinstance(sign, Signature):
                raise TypeError(
                    "Item in Signatures shoule be an instance of paddlehub.Signature"
                )
W
wuzewu 已提交
707 708 709

            for input in sign.inputs:
                _tmp_program = input.block.program
W
wuzewu 已提交
710 711 712 713
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )
W
wuzewu 已提交
714 715 716

            for output in sign.outputs:
                _tmp_program = output.block.program
W
wuzewu 已提交
717 718 719 720
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )