#coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import sys
import functools
import inspect
import importlib
import tarfile
import shutil

import six
import paddle
import paddle.fluid as fluid

from paddlehub.common import utils
from paddlehub.common import paddle_helper
from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.common import tmp_dir
from paddlehub.common.downloader import progress
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.base_processor import BaseProcessor
from paddlehub.io.parser import yaml_parser
from paddlehub import version

# PaddleHub module dir name
# Names of the entries that live directly under a module directory.
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# PaddleHub var prefix
# %-style template with a single ``%s`` placeholder.
HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
# File extension (without the dot) of archives written by create_module.
HUB_PACKAGE_SUFFIX = "phm"


def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package the module found in ``directory`` into a gzip tar archive.

    The source tree is copied into a temporary directory, a serialized
    ``ModuleDesc`` carrying the module info is written next to it, check
    info and an ``__init__.py`` are generated, and the result is archived
    as ``<name>-<version>.<HUB_PACKAGE_SUFFIX>`` (path relative to the
    working directory at call time).

    Args:
        directory: Source directory of the module to package.
        name: Module name; also the top-level directory inside the archive.
        author: Stored under ``module_info['author']``.
        email: Stored under ``module_info['author_email']``.
        module_type: Stored under ``module_info['type']``.
        summary: Stored under ``module_info['summary']``.
        version: Stored under ``module_info['version']`` and used in the
            archive file name.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
            utils.from_pyobj_to_module_attr(author,
                                            module_info.map.data['author'])
            utils.from_pyobj_to_module_attr(
                email, module_info.map.data['author_email'])
            utils.from_pyobj_to_module_attr(module_type,
                                            module_info.map.data['type'])
            utils.from_pyobj_to_module_attr(summary,
                                            module_info.map.data['summary'])
            utils.from_pyobj_to_module_attr(version,
                                            module_info.map.data['version'])
            module_desc_path = os.path.join(module_dir, MODULE_DESC_PBNAME)
            with open(module_desc_path, "wb") as desc_file:
                desc_file.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # add __init__ (append mode: creates the file if missing and
            # leaves an existing one untouched)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a") as init_file:
                init_file.write("")

            # Archive members are added through paths relative to base_dir
            # so the tarball does not embed the temporary directory name.
            _cwd = os.getcwd()
            os.chdir(base_dir)
            try:
                module_dir = os.path.join(".", name)
                tar.add(module_dir, recursive=False)
                files = []
                for dirname, _, subfiles in os.walk(module_dir):
                    for filename in subfiles:
                        files.append(os.path.join(dirname, filename))

                total_length = len(files)
                print("Create Module {}-{}".format(name, version))
                for index, path in enumerate(files):
                    done = int(float(index) / total_length * 50)
                    progress("[%-50s] %.2f%%" %
                             ('=' * done, float(index) / total_length * 100))
                    tar.add(path)
                progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
                print("Module package saved as {}".format(save_file))
            finally:
                # Always restore the caller's working directory, even when
                # archiving fails part-way through.
                os.chdir(_cwd)


# Maps "<module>.<enclosing scope name>" to the name of the function that
# was decorated with @runable inside that scope.
_module_runable_func = {}


def runable(func):
    """Register ``func`` as the runnable entry point of its enclosing scope.

    The registry key is ``func.__module__`` joined with the name of the
    frame that applied the decorator (the class name when used inside a
    class body). ``Module.__init__`` later looks the key up to bind
    ``_run_func``.

    Args:
        func: The function being decorated.

    Returns:
        A pass-through wrapper that keeps ``func``'s name and docstring.
    """
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_runable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


class Module(object):
    """Base class and factory for PaddleHub modules.

    Calling ``Module(...)`` directly does not build a bare ``Module``:
    ``__new__`` resolves the module by name or directory and returns an
    instance of the concrete class (a user ``HubModule`` for "v2" code,
    ``ModuleV1`` otherwise). Subclasses are instantiated normally.
    """

    # No longer consulted; kept so that external code referencing
    # ``Module._record`` keeps working.
    _record = {}

    def __new__(cls, name=None, directory=None, module_dir=None, version=None):
        """Create a module instance.

        Args:
            name: Module name to install and load (takes precedence).
            directory: Path of an already installed module.
            module_dir: Deprecated alias of ``directory``; may also be a
                list/tuple whose first item is the path.
            version: Version requested when loading by ``name``.

        Raises:
            ValueError: ``Module`` itself was called without ``name``,
                ``directory`` or ``module_dir``.
        """
        if cls.__name__ == "Module":
            if name:
                module = cls.init_with_name(name=name, version=version)
            elif directory:
                module = cls.init_with_directory(directory=directory)
            elif module_dir:
                logger.warning(
                    "Parameter module_dir is deprecated, please use directory to specify the path"
                )
                if isinstance(module_dir, (list, tuple)):
                    directory = module_dir[0]
                else:
                    directory = module_dir
                module = cls.init_with_directory(directory=directory)
            else:
                # Previously fell through and failed on an unbound local.
                raise ValueError(
                    "One of name, directory or module_dir must be specified")
            CacheUpdater("update_cache", module.name, module.version).start()
        else:
            module = object.__new__(cls)

        return module

    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load the module description stored under ``directory``.

        Does nothing when ``directory`` is empty or when this instance has
        already been initialized: ``__new__`` may hand back an object that
        was fully constructed, after which ``__init__`` is invoked again.
        """
        # Avoid module being initialized multiple times. A per-instance flag
        # is used instead of recording id(self), since ids can be reused
        # once an object has been garbage collected.
        if not directory or getattr(self, "_hub_initialized", False):
            return
        self._hub_initialized = True

        mod = self.__class__.__module__ + "." + self.__class__.__name__
        if mod in _module_runable_func:
            _run_func_name = _module_runable_func[mod]
            self._run_func = getattr(self, _run_func_name)
        else:
            self._run_func = None
        self._code_version = "v2"
        self._directory = directory
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

        self._initialize()

    @classmethod
    def init_with_name(cls, name, version=None):
        """Install the module called ``name`` and load it.

        Installation runs under an exclusive lock on a file named after the
        module inside ``CACHE_HOME``.

        Raises:
            RuntimeError: The module manager reported a failed install.
        """
        with open(os.path.join(CACHE_HOME, name), "a") as fp_lock:
            lock.flock(fp_lock, lock.LOCK_EX)
            try:
                log_msg = "Installing %s module" % name
                if version:
                    log_msg += "-%s" % version
                logger.info(log_msg)
                extra = {"command": "install"}
                result, tips, module_dir = default_module_manager.install_module(
                    module_name=name, module_version=version, extra=extra)
                if not result:
                    logger.error(tips)
                    raise RuntimeError(tips)

                logger.info(tips)
            finally:
                # Release the lock on failure as well as on success.
                lock.flock(fp_lock, lock.LOCK_UN)
        return cls.init_with_directory(directory=module_dir[0])

    @classmethod
    def init_with_directory(cls, directory):
        """Load the module stored in ``directory``.

        For "v2" code the directory is imported as a package (its parent is
        added to ``sys.path``) and ``<package>.module.HubModule`` is
        instantiated; otherwise a ``ModuleV1`` is returned.
        """
        checker = ModuleChecker(directory)
        checker.check()

        module_code_version = checker.module_code_version
        if module_code_version == "v2":
            # Drop trailing separators so the last path component, which is
            # the package name, is never empty.
            separators = os.sep + (os.altsep or "")
            trimmed = directory.rstrip(separators) or directory
            dirname, basename = os.path.split(trimmed)
            if dirname not in sys.path:
                sys.path.append(dirname)
            user_module = importlib.import_module("{}.module".format(basename))
            return user_module.HubModule(directory=directory)
        return ModuleV1(directory=directory)

    @property
    def run_func(self):
        """Bound method registered through ``runable``, or ``None``."""
        return self._run_func

    @property
    def desc(self):
        """Parsed ``ModuleDesc`` message."""
        return self._desc

    @property
    def directory(self):
        """Directory the module was loaded from."""
        return self._directory

    @property
    def author(self):
        return self._author

    @property
    def author_email(self):
        return self._author_email

    @property
    def summary(self):
        return self._summary

    @property
    def type(self):
        return self._type

    @property
    def version(self):
        return self._version

    @property
    def name(self):
        return self._name

    @property
    def code_version(self):
        return self._code_version

    @property
    def is_runable(self):
        """Whether a run function has been registered for this module."""
        return self._run_func is not None

    def _initialize(self):
        """Hook called at the end of ``__init__``; no-op by default."""
        pass


class ModuleHelper(object):
    """Builds the well-known paths located under a module directory."""

    def __init__(self, directory):
        # Root directory of the module; every path below is joined onto it.
        self.directory = directory

    def module_desc_path(self):
        """Return the path of the serialized module description file."""
        return os.path.join(self.directory, MODULE_DESC_PBNAME)

    def model_path(self):
        """Return the path of the model directory."""
        return os.path.join(self.directory, MODEL_DIRNAME)

    def processor_path(self):
        """Return the path of the directory holding processor code."""
        return os.path.join(self.directory, PYTHON_DIR)

    def processor_name(self):
        """Return the processor name constant."""
        return PROCESSOR_NAME

    def assets_path(self):
        """Return the path of the assets directory."""
        return os.path.join(self.directory, ASSETS_DIRNAME)


class ModuleV1(Module):
    """Module backed by a saved inference program ("v1" code version).

    On construction the inference model, the optional processor, the asset
    list, the signatures and the extra info are restored from the module
    directory and its description.
    """

    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load everything stored under ``directory``.

        Does nothing when ``directory`` is empty, or when this instance has
        already been loaded: ``Module.__new__`` returns a fully built
        instance and ``__init__`` is then invoked on it a second time, which
        would otherwise reload the model.
        """
        if not directory or getattr(self, "_v1_loaded", False):
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
        self._code_version = "v1"
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.processor = None
        self.extra_info = {}

        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        # Blank out the op_callstack attribute of every op that has one.
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)
        self._v1_loaded = True

    def _dump_processor(self):
        """Write the processor's source file into the processor directory.

        The file name is derived from an md5 of the source plus the current
        time, and is recorded in the description under ``processor_info``.
        """
        pymodule = inspect.getmodule(self.processor)
        pycode = inspect.getsource(pymodule)
        processor_path = self.helper.processor_path()
        processor_md5 = utils.md5(pycode)
        processor_md5 += str(time.time())
        processor_name = utils.md5(processor_md5)
        output_file = os.path.join(processor_path, processor_name + ".py")
        utils.mkdir(processor_path)
        with open(output_file, "w") as file:
            file.write(pycode)
        utils.from_pyobj_to_module_attr(
            processor_name, self.desc.attr.map.data['processor_info'])

    def _load_processor(self):
        """Import and instantiate the processor, if the module ships one."""
        processor_path = self.helper.processor_path()
        if os.path.exists(processor_path):
            if processor_path not in sys.path:
                sys.path.append(processor_path)
            processor_name = utils.from_module_attr_to_pyobj(
                self.desc.attr.map.data['processor_info'])
            self.processor = __import__(processor_name).Processor(module=self)
        else:
            self.processor = None

    def _dump_assets(self):
        """Copy every file listed in ``self.assets`` into the assets dir."""
        assets_path = self.helper.assets_path()
        utils.mkdir(assets_path)
        for asset in self.assets:
            filename = os.path.basename(asset)
            newfile = os.path.join(assets_path, filename)
            shutil.copyfile(asset, newfile)

    def _load_assets(self):
        """Fill ``self.assets`` with the paths of all entries in the assets dir."""
        assets_path = self.helper.assets_path()
        self.assets = [
            os.path.join(assets_path, entry)
            for entry in os.listdir(assets_path)
        ]

    def _restore_parameter(self, program):
        """Re-create as parameters the variables listed in ``param_attrs``.

        Entries whose prefixed name is absent from the global block of
        ``program`` are skipped.
        """
        global_block = program.global_block()
        param_attrs = self.desc.attr.map.data['param_attrs']
        for key, param_attr in param_attrs.map.data.items():
            param = paddle_helper.from_module_attr_to_param(param_attr)
            param['name'] = self.get_var_name_with_prefix(key)
            if (param['name'] not in global_block.vars):
                continue
            var = global_block.var(param['name'])
            global_block.create_parameter(
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                error_clip=var.error_clip,
                stop_gradient=var.stop_gradient,
                is_data=var.is_data,
                **param)

    def _recover_variable_info(self, program):
        """Restore ``stop_gradient`` on variables recorded in ``var_infos``."""
        var_infos = self.desc.attr.map.data['var_infos']
        for var_info in var_infos.map.data:
            idx = utils.from_module_attr_to_pyobj(
                var_infos.map.data[var_info].map.data['block_id'])
            stop_gradient = utils.from_module_attr_to_pyobj(
                var_infos.map.data[var_info].map.data['stop_gradient'])
            block = program.blocks[idx]
            var_name = self.get_var_name_with_prefix(var_info)
            if var_name in block.vars:
                var = block.vars[var_name]
                var.stop_gradient = stop_gradient

    def get_extra_info(self, key):
        """Return the extra info stored under ``key``, or ``None``."""
        return self.extra_info.get(key, None)

    def _generate_extra_info(self):
        """Expose each extra-info key as an instance attribute ``get_<key>``."""
        for key in self.extra_info:
            self.__dict__["get_%s" % key] = functools.partial(
                self.get_extra_info, key=key)

    def _generate_sign_attr(self):
        """Expose each signature as an instance attribute calling ``__call__``."""
        self._check_signatures()
        for sign in self.signatures:
            self.__dict__[sign] = functools.partial(
                self.__call__, sign_name=sign)

    def _find_asset(self, keyword):
        """Return the first asset path containing ``keyword``, or ``None``."""
        for assets_file in self.assets:
            if keyword in assets_file:
                return assets_file
        return None

    def get_vocab_path(self):
        """Return the first asset path containing ``vocab.txt``, or ``None``."""
        return self._find_asset("vocab.txt")

    def get_word_dict_path(self):
        """Return the first asset path containing ``dict.wordseg.pickle``, or ``None``."""
        return self._find_asset("dict.wordseg.pickle")

    def get_spm_path(self):
        """Return the first asset path containing ``spm_cased_simp_sampled.model``, or ``None``."""
        return self._find_asset("spm_cased_simp_sampled.model")

    def _recover_from_desc(self):
        """Rebuild signatures, module info, extra info and name prefix from ``self.desc``."""
        # recover signature
        for sign, module_var in self.desc.sign2var.items():
            inputs = []
            outputs = []
            feed_names = []
            fetch_names = []
            for var in module_var.feed_desc:
                variable = self.program.global_block().vars[var.var_name]
                inputs.append(variable)
                feed_names.append(var.alias)

            for var in module_var.fetch_desc:
                variable = self.program.global_block().vars[var.var_name]
                outputs.append(variable)
                fetch_names.append(var.alias)

            self.signatures[sign] = create_signature(
                sign,
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names)

        # recover default signature
        default_signature_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['default_signature'])
        self.default_signature = self.signatures[
            default_signature_name].name if default_signature_name else None

        # recover module info
        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

        # recover extra info
        extra_info = self.desc.attr.map.data['extra_info']
        self.extra_info = {}
        for key, value in extra_info.map.data.items():
            self.extra_info[key] = utils.from_module_attr_to_pyobj(value)

        # recover name prefix
        self._name_prefix = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data["name_prefix"])

    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
        """Run the signature ``sign_name`` on ``data`` and return the results.

        Args:
            sign_name: Name of the signature to execute.
            data: Raw input handed to the processor's ``preprocess``.
            use_gpu: Request GPU execution; ignored unless the environment
                variable ``CUDA_VISIBLE_DEVICES`` starts with a digit.
            batch_size: Number of samples per executor run.
            **kwargs: Forwarded to the processor's ``postprocess``.

        Returns:
            The concatenation of the per-batch ``postprocess`` results.

        Raises:
            ValueError: The module has no processor.
        """
        self.check_processor()

        def _get_reader_and_feeder(data_format, data, place):
            def _reader(process_data):
                for item in zip(*process_data):
                    yield item

            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])
            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
            return functools.partial(_reader, process_data=process_data), feeder

        # The context is rebuilt only when the requested signature changes.
        if self.last_call_name != sign_name:
            self.last_call_name = sign_name
            self.cache_feed_dict, self.cache_fetch_dict, self.cache_program = self.context(
                sign_name, for_test=True)
        fetch_dict = self.cache_fetch_dict
        program = self.cache_program

        # fetch_dict holds each output under both its index and its alias,
        # hence the de-duplication.
        # NOTE(review): a set does not keep insertion order, so the order of
        # fetch_list (and of data_out below) is not fixed when a signature
        # has several outputs - confirm that postprocess does not rely on it.
        fetch_list = list(set(fetch_dict.values()))
        with fluid.program_guard(program):
            result = []
            index = 0
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except (KeyError, IndexError, ValueError):
                # Variable unset, empty, or not starting with a digit.
                use_gpu = False

            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()

            exe = fluid.Executor(place=place)
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
            reader = paddle.batch(reader, batch_size=batch_size)
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
                # Slice the preprocessed data matching the current batch.
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result

    def check_processor(self):
        """Raise ``ValueError`` when the module has no processor."""
        if not self.processor:
            raise ValueError("This Module is not callable!")

    @property
    def is_runable(self):
        """Whether the module declares a default signature."""
        return self.default_signature is not None

    def context(self,
                sign_name=None,
                for_test=False,
                trainable=True,
                regularizer=None,
                max_seq_len=128,
                learning_rate=1e-3):
        """Clone the module program and build its feed/fetch dictionaries.

        Args:
            sign_name: Signature to use. When empty, a temporary signature
                merging the inputs and outputs of all signatures is used.
            for_test: Passed to ``Program.clone`` and used as the
                ``is_test`` op attribute. When false, ``trainable``,
                ``learning_rate`` and ``regularizer`` are applied to the
                parameters of the cloned program.
            trainable: Applied through ``set_parameter_trainable``.
            regularizer: Applied through ``set_parameter_regularizer``.
            max_seq_len(int): maximum sequence length, this option is only
                available for BERT/ERNIE module
            learning_rate: Applied through ``set_parameter_learning_rate``.

        Returns:
            ``(feed_dict, fetch_dict, program)``. Both dictionaries map the
            positional index, and the alias when it is non-empty, to the
            corresponding variable of the cloned program.

        Raises:
            KeyError: ``sign_name`` is not a known signature.
            ValueError: ``max_seq_len`` is outside ``[1, 512]`` for a
                BERT/ERNIE module.
        """

        if sign_name:
            if sign_name not in self.signatures:
                raise KeyError(
                    "Module did not have a signature with name %s" % sign_name)
            signature = self.signatures[sign_name]
        else:
            inputs = [
                input for signature in self.signatures.values()
                for input in signature.inputs
            ]
            outputs = [
                output for signature in self.signatures.values()
                for output in signature.outputs
            ]
            feed_names = [
                feed_name for signature in self.signatures.values()
                for feed_name in signature.feed_names
            ]
            fetch_names = [
                fetch_name for signature in self.signatures.values()
                for fetch_name in signature.fetch_names
            ]
            signature = create_signature(
                name="hub_temp_signature",
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names,
                for_predict=False)

        program = self.program.clone(for_test=for_test)
        paddle_helper.remove_feed_fetch_op(program)

        if not for_test:
            paddle_helper.set_parameter_trainable(program, trainable)

            paddle_helper.set_parameter_learning_rate(program, learning_rate)

            paddle_helper.set_parameter_regularizer(program, regularizer)

            self._restore_parameter(program)

        self._recover_variable_info(program)

        paddle_helper.set_op_attr(program, is_test=for_test)
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            feed_dict[index] = program.global_block().var(var.name)
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = program.global_block().var(var.name)

        for index, var in enumerate(signature.outputs):
            fetch_dict[index] = program.global_block().var(var.name)
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = program.global_block().var(var.name)

        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
        if "bert" in self.name or self.name.startswith("ernie"):
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
                        max_seq_len, MAX_SEQ_LENGTH))
            logger.info(
                "Set maximum sequence length of input tensor to {}".format(
                    max_seq_len))
            if self.name.startswith("ernie_v2"):
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask",
                    "task_ids"
                ]
            else:
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask"
                ]
            for tensor_name in feed_list:
                seq_tensor_shape = [-1, max_seq_len, 1]
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                program.global_block().var(
                    feed_dict[tensor_name].name).desc.set_shape(
                        seq_tensor_shape)

        # record num parameters loaded by paddlehub
        num_param_loaded = sum(
            1 for _ in program.global_block().iter_parameters())
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)

        return feed_dict, fetch_dict, program

    def get_name_prefix(self):
        """Return the name prefix recovered from the module description."""
        return self._name_prefix

    def get_var_name_with_prefix(self, var_name):
        """Return ``var_name`` with the module's name prefix prepended."""
        return self.get_name_prefix() + var_name

    def _check_signatures(self):
        """Validate ``self.signatures``.

        Raises:
            ValueError: No signature is present, or a signature variable
                belongs to a program other than ``self.program``.
            TypeError: An entry is not a ``Signature`` instance.
        """
        if not self.signatures:
            raise ValueError("Signatures should not be None")

        for sign in self.signatures.values():
            if not isinstance(sign, Signature):
                raise TypeError(
                    "Item in Signatures shoule be an instance of paddlehub.Signature"
                )

            # Inputs first, then outputs; both must live in self.program.
            for variable in list(sign.inputs) + list(sign.outputs):
                _tmp_program = variable.block.program
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )