module.py 25.1 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22 23

import os
import time
import sys
import functools
W
wuzewu 已提交
24 25 26
import inspect
import importlib
import tarfile
W
wuzewu 已提交
27
import six
W
wuzewu 已提交
28
import shutil
W
wuzewu 已提交
29 30 31 32

import paddle
import paddle.fluid as fluid

W
wuzewu 已提交
33 34
from paddlehub.common import utils
from paddlehub.common import paddle_helper
W
wuzewu 已提交
35
from paddlehub.common.dir import CACHE_HOME
S
shenyuhan 已提交
36
from paddlehub.common.lock import lock
W
wuzewu 已提交
37 38
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
W
wuzewu 已提交
39
from paddlehub.common import tmp_dir
W
wuzewu 已提交
40
from paddlehub.common.downloader import progress
W
wuzewu 已提交
41 42
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
W
wuzewu 已提交
43 44
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
W
wuzewu 已提交
45
from paddlehub.module.base_processor import BaseProcessor
W
wuzewu 已提交
46
from paddlehub.io.parser import yaml_parser
W
wuzewu 已提交
47
from paddlehub import version
W
wuzewu 已提交
48

Z
Zeyu Chen 已提交
49
# PaddleHub module dir name
# Sub-directory / file names looked up inside a module directory.
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# PaddleHub var prefix
HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX = "phm"


def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package a module directory into a ``<name>-<version>.phm`` archive.

    The content of ``directory`` is copied into a temporary folder, a
    serialized ``ModuleDesc`` holding the module info is written next to it,
    check info is generated, an ``__init__.py`` is ensured to exist, and the
    result is stored as a gzip-compressed tar file in the current working
    directory.

    Args:
        directory: path of the module source directory to package.
        name: module name; also the top-level directory inside the archive.
        author: stored under ``module_info['author']``.
        email: stored under ``module_info['author_email']``.
        module_type: stored under ``module_info['type']``.
        summary: stored under ``module_info['summary']``.
        version: stored under ``module_info['version']`` and used in the
            archive file name.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            info_fields = (
                ('name', name),
                ('author', author),
                ('author_email', email),
                ('type', module_type),
                ('summary', summary),
                ('version', version),
            )
            for field, value in info_fields:
                utils.from_pyobj_to_module_attr(value,
                                                module_info.map.data[field])
            module_desc_path = os.path.join(module_dir, "module_desc.pb")
            with open(module_desc_path, "wb") as desc_file:
                desc_file.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # add __init__ (append mode creates the file without truncating
            # one shipped by the module author)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a") as init_file:
                init_file.write("")

            # Archive member names are taken relative to base_dir, so the
            # process temporarily switches into it. The finally clause makes
            # sure the caller's working directory is restored even when
            # packaging fails midway.
            _cwd = os.getcwd()
            os.chdir(base_dir)
            try:
                module_dir = module_dir.replace(base_dir, ".")
                tar.add(module_dir, recursive=False)
                # every file is packaged, hidden ones included
                files = []
                for dirname, _, subfiles in os.walk(module_dir):
                    for filename in subfiles:
                        files.append(os.path.join(dirname, filename))

                total_length = len(files)
                print("Create Module {}-{}".format(name, version))
                for index, filepath in enumerate(files):
                    ratio = float(index) / total_length
                    done = int(ratio * 50)
                    progress("[%-50s] %.2f%%" % ('=' * done, ratio * 100))
                    tar.add(filepath)
                progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
                print("Module package saved as {}".format(save_file))
            finally:
                os.chdir(_cwd)
W
wuzewu 已提交
122 123


W
wuzewu 已提交
124
# Maps "<function module>.<name of the scope applying the decorator>" to the
# name of the function registered through ``runnable``.
_module_runnable_func = {}


def runnable(func):
    """Register ``func`` as the runnable entry of the scope that decorates it.

    The registry key is built from ``func.__module__`` and the name of the
    calling frame (the class name when the decorator is applied inside a class
    body), which is the same key ``Module.__init__`` builds from the instance
    class to look the function up.

    Args:
        func: the function to register.

    Returns:
        A wrapper that forwards all arguments to ``func`` and keeps its
        metadata (``__name__``, ``__doc__``, ...).
    """
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_runnable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
137 138
class Module(object):
    """Base class and factory of PaddleHub modules.

    Calling ``Module(...)`` directly does not build a plain ``Module``: the
    call is dispatched to ``init_with_name`` / ``init_with_directory``, which
    return either a user defined ``HubModule`` (code version "v2") or a
    ``ModuleV1``. Subclasses are instantiated normally.
    """

    # ids of the instances whose ``__init__`` already ran, see ``__init__``.
    # NOTE(review): entries are never removed, so an id reused after garbage
    # collection would make a new instance skip initialization - confirm
    # whether that can happen in practice.
    _record = {}

    def __new__(cls, name=None, directory=None, module_dir=None, version=None):
        """Create the instance, acting as a factory for ``Module`` itself.

        Args:
            name: name of the module to install and load.
            directory: path of a local module directory.
            module_dir: deprecated alias of ``directory``; a list/tuple is
                accepted, in which case its first item is the directory.
            version: module version, only used together with ``name``.

        Raises:
            ValueError: if ``Module`` is called directly without ``name``,
                ``directory`` or ``module_dir``.
        """
        if cls.__name__ == "Module":
            if name:
                module = cls.init_with_name(name=name, version=version)
            elif directory:
                module = cls.init_with_directory(directory=directory)
            elif module_dir:
                logger.warning(
                    "Parameter module_dir is deprecated, please use directory to specify the path"
                )
                if isinstance(module_dir, (list, tuple)):
                    directory = module_dir[0]
                else:
                    directory = module_dir
                module = cls.init_with_directory(directory=directory)
            else:
                # Without this branch ``module`` would be unbound below.
                raise ValueError(
                    "One of name, directory or module_dir must be specified "
                    "to create a Module")
            CacheUpdater("update_cache", module.name, module.version).start()
        else:
            module = object.__new__(cls)

        return module

    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load the module description stored in ``directory``.

        Does nothing when ``directory`` is empty or when this instance has
        already been initialized.
        """
        # Avoid module being initialized multiple times: the object returned
        # by ``__new__`` was already built by a subclass constructor, yet
        # Python calls ``__init__`` on it once more.
        if not directory or id(self) in Module._record:
            return
        Module._record[id(self)] = True

        # Same key layout as the one written by the ``runnable`` decorator.
        mod = self.__class__.__module__ + "." + self.__class__.__name__
        if mod in _module_runnable_func:
            _run_func_name = _module_runnable_func[mod]
            self._run_func = getattr(self, _run_func_name)
        else:
            self._run_func = None
        self._code_version = "v2"
        self._directory = directory
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

        self._initialize()

    @classmethod
    def init_with_name(cls, name, version=None):
        """Install the module called ``name`` and load it.

        The installation is guarded by an exclusive lock on a file named
        after the module inside ``CACHE_HOME``.

        Raises:
            RuntimeError: if the module manager reports a failed install.
        """
        # The ``with`` block closes the lock file and the ``finally`` clause
        # releases the lock, also when the installation fails.
        with open(os.path.join(CACHE_HOME, name), "a") as fp_lock:
            lock.flock(fp_lock, lock.LOCK_EX)
            try:
                log_msg = "Installing %s module" % name
                if version:
                    log_msg += "-%s" % version
                logger.info(log_msg)
                extra = {"command": "install"}
                result, tips, module_dir = default_module_manager.install_module(
                    module_name=name, module_version=version, extra=extra)
                if not result:
                    logger.error(tips)
                    raise RuntimeError(tips)

                logger.info(tips)
            finally:
                lock.flock(fp_lock, lock.LOCK_UN)
        return cls.init_with_directory(directory=module_dir[0])

    @classmethod
    def init_with_directory(cls, directory):
        """Load the module stored in ``directory``.

        Returns:
            ``HubModule`` imported from ``<directory name>.module`` when the
            checker reports code version "v2", a ``ModuleV1`` otherwise.
        """
        checker = ModuleChecker(directory)
        checker.check()

        module_code_version = checker.module_code_version
        if module_code_version == "v2":
            # normpath drops a trailing separator, which would otherwise
            # produce an empty package name.
            dirname, basename = os.path.split(os.path.normpath(directory))
            if dirname not in sys.path:
                sys.path.append(dirname)
            user_module = importlib.import_module("{}.module".format(basename))
            return user_module.HubModule(directory=directory)
        return ModuleV1(directory=directory)

    @property
    def run_func(self):
        """Bound function registered with ``runnable``, or None."""
        return self._run_func

    @property
    def desc(self):
        """Parsed ``ModuleDesc`` message of the module."""
        return self._desc

    @property
    def directory(self):
        """Directory the module was loaded from."""
        return self._directory

    @property
    def author(self):
        return self._author

    @property
    def author_email(self):
        return self._author_email

    @property
    def summary(self):
        return self._summary

    @property
    def type(self):
        return self._type

    @property
    def version(self):
        return self._version

    @property
    def name(self):
        return self._name

    @property
    def code_version(self):
        return self._code_version

    @property
    def is_runnable(self):
        """True when a function was registered with ``runnable``."""
        return self._run_func is not None

    def _initialize(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass
W
wuzewu 已提交
280 281


282
class ModuleHelper(object):
    """Builds the paths of the standard entries of a module directory."""

    def __init__(self, directory):
        # root directory of the module
        self.directory = directory

    def module_desc_path(self):
        """Path of the serialized module description file."""
        return os.path.join(self.directory, MODULE_DESC_PBNAME)

    def model_path(self):
        """Path of the model directory."""
        return os.path.join(self.directory, MODEL_DIRNAME)

    def processor_path(self):
        """Path of the directory holding the processor python code."""
        return os.path.join(self.directory, PYTHON_DIR)

    def processor_name(self):
        """Default processor name."""
        return PROCESSOR_NAME

    def assets_path(self):
        """Path of the assets directory."""
        return os.path.join(self.directory, ASSETS_DIRNAME)
W
wuzewu 已提交
300

W
wuzewu 已提交
301

W
wuzewu 已提交
302 303
class ModuleV1(Module):
    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load a code version "v1" module from ``directory``.

        Loads the inference program, the processor, the assets and the
        information stored in the module description. Does nothing when
        ``directory`` is empty or when the instance is already initialized.
        """
        # When the instance is built through ``Module(directory=...)``, Python
        # calls ``__init__`` again on the object returned by ``__new__``.
        # ``Module.__init__`` records initialized instances; checking the
        # record here avoids loading the model and all state a second time.
        if not directory or id(self) in Module._record:
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
        self._code_version = "v1"
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.processor = None
        self.extra_info = {}

        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        # blank out the "op_callstack" attribute of every op that has one
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)
W
wuzewu 已提交
338

W
wuzewu 已提交
339 340 341 342 343
    def _dump_processor(self):
        """Write the source of the processor's python module into the module.

        The file is stored in the processor directory under a name derived
        from the md5 of the source and the current time; that name is
        recorded in the ``processor_info`` attribute of the description.
        """
        pymodule = inspect.getmodule(self.processor)
        pycode = inspect.getsource(pymodule)
        processor_path = self.helper.processor_path()
        # the timestamp makes the generated file name differ between dumps
        name_seed = utils.md5(pycode) + str(time.time())
        processor_name = utils.md5(name_seed)
        output_file = os.path.join(processor_path, processor_name + ".py")
        utils.mkdir(processor_path)
        with open(output_file, "w") as file:
            file.write(pycode)
        utils.from_pyobj_to_module_attr(
            processor_name, self.desc.attr.map.data['processor_info'])
W
wuzewu 已提交
353

W
wuzewu 已提交
354 355
    def _load_processor(self):
        """Import the processor recorded in the description, if any.

        When the processor directory exists it is put on ``sys.path``, the
        python module named by the ``processor_info`` attribute is imported
        and its ``Processor`` class is instantiated with this module.
        Otherwise ``self.processor`` is set to None.
        """
        processor_path = self.helper.processor_path()
        if os.path.exists(processor_path):
            # do not grow sys.path each time a module is loaded
            if processor_path not in sys.path:
                sys.path.append(processor_path)
            processor_name = utils.from_module_attr_to_pyobj(
                self.desc.attr.map.data['processor_info'])
            self.processor = __import__(processor_name).Processor(module=self)
        else:
            self.processor = None
W
wuzewu 已提交
363

W
wuzewu 已提交
364 365 366 367 368
    def _dump_assets(self):
        """Copy every file listed in ``self.assets`` into the assets directory.

        Files keep their base name; the assets directory is created through
        ``utils.mkdir`` first.
        """
        assets_path = self.helper.assets_path()
        utils.mkdir(assets_path)
        for asset in self.assets:
            filename = os.path.basename(asset)
            newfile = os.path.join(assets_path, filename)
            shutil.copyfile(asset, newfile)
W
wuzewu 已提交
370 371 372 373 374 375 376 377

    def _load_assets(self):
        assets_path = self.helper.assets_path()
        self.assets = []
        for file in os.listdir(assets_path):
            filepath = os.path.join(self.helper.assets_path(), file)
            self.assets.append(filepath)

Z
Zeyu Chen 已提交
378
    def _restore_parameter(self, program):
        """Re-create the parameters described in ``param_attrs`` in ``program``.

        For each recorded parameter whose prefixed name exists in the global
        block, a parameter is created from the properties of the existing
        variable plus the recorded attributes.

        Args:
            program: the program whose global block is updated.
        """
        global_block = program.global_block()
        param_attrs = self.desc.attr.map.data['param_attrs']
        for key, param_attr in param_attrs.map.data.items():
            param = paddle_helper.from_module_attr_to_param(param_attr)
            param['name'] = self.get_var_name_with_prefix(key)
            # parameters that are not part of this program are skipped
            if param['name'] not in global_block.vars:
                continue
            var = global_block.var(param['name'])
            global_block.create_parameter(
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                error_clip=var.error_clip,
                stop_gradient=var.stop_gradient,
                is_data=var.is_data,
                **param)
W
wuzewu 已提交
396 397

    def _recover_variable_info(self, program):
        """Restore the recorded ``stop_gradient`` flag of the variables.

        Each entry of ``var_infos`` names a variable (without prefix) and
        stores the index of its block and its ``stop_gradient`` value.
        Variables missing from that block are ignored.

        Args:
            program: the program whose variables are updated.
        """
        var_infos = self.desc.attr.map.data['var_infos']
        for var_info, info_attr in var_infos.map.data.items():
            idx = utils.from_module_attr_to_pyobj(
                info_attr.map.data['block_id'])
            stop_gradient = utils.from_module_attr_to_pyobj(
                info_attr.map.data['stop_gradient'])
            block = program.blocks[idx]
            var_name = self.get_var_name_with_prefix(var_info)
            if var_name in block.vars:
                var = block.vars[var_name]
                var.stop_gradient = stop_gradient

W
wuzewu 已提交
410 411 412 413 414 415 416 417
    def get_extra_info(self, key):
        return self.extra_info.get(key, None)

    def _generate_extra_info(self):
        for key in self.extra_info:
            self.__dict__["get_%s" % key] = functools.partial(
                self.get_extra_info, key=key)

W
wuzewu 已提交
418 419 420
    def _generate_sign_attr(self):
        self._check_signatures()
        for sign in self.signatures:
W
wuzewu 已提交
421 422
            self.__dict__[sign] = functools.partial(
                self.__call__, sign_name=sign)
W
wuzewu 已提交
423

424 425 426 427
    def get_vocab_path(self):
        for assets_file in self.assets:
            if "vocab.txt" in assets_file:
                return assets_file
K
kinghuin 已提交
428 429 430 431 432 433 434 435 436 437 438 439 440
        return None

    def get_word_dict_path(self):
        for assets_file in self.assets:
            if "dict.wordseg.pickle" in assets_file:
                return assets_file
        return None

    def get_spm_path(self):
        for assets_file in self.assets:
            if "spm_cased_simp_sampled.model" in assets_file:
                return assets_file
        return None
441

W
wuzewu 已提交
442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465
    def _recover_from_desc(self):
        """Rebuild the module state stored in the module description.

        Recovers, in this order: the signatures (from ``sign2var``), the
        default signature, the module info fields, the extra info dict and
        the variable name prefix.
        """
        # recover signature
        global_vars = self.program.global_block().vars
        for sign, module_var in self.desc.sign2var.items():
            inputs = []
            outputs = []
            feed_names = []
            fetch_names = []
            for var in module_var.feed_desc:
                inputs.append(global_vars[var.var_name])
                feed_names.append(var.alias)

            for var in module_var.fetch_desc:
                outputs.append(global_vars[var.var_name])
                fetch_names.append(var.alias)

            self.signatures[sign] = create_signature(
                sign,
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names)

        # recover default signature
        default_signature_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['default_signature'])
        if default_signature_name:
            self.default_signature = self.signatures[
                default_signature_name].name
        else:
            self.default_signature = None

        # recover module info; each field is stored on the instance as
        # ``_<field>``
        module_info = self.desc.attr.map.data['module_info']
        for field in ('name', 'author', 'author_email', 'version', 'type',
                      'summary'):
            setattr(
                self, "_" + field,
                utils.from_module_attr_to_pyobj(module_info.map.data[field]))

        # recover extra info
        extra_info = self.desc.attr.map.data['extra_info']
        self.extra_info = {}
        for key, value in extra_info.map.data.items():
            self.extra_info[key] = utils.from_module_attr_to_pyobj(value)

        # recover name prefix
        self._name_prefix = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data["name_prefix"])
W
wuzewu 已提交
496

497
    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
        """Run the signature ``sign_name`` on ``data`` through the processor.

        Args:
            sign_name: name of the signature to run.
            data: passed to ``processor.preprocess`` as ``data_dict``.
            use_gpu: run on ``CUDAPlace(0)``; ignored (CPU is used) unless
                the CUDA_VISIBLE_DEVICES environment variable is set and
                starts with a digit.
            batch_size: number of samples fed per executor run.
            **kwargs: forwarded to ``processor.postprocess``.

        Returns:
            The concatenation of the ``processor.postprocess`` results of
            every batch.

        Raises:
            ValueError: if the module has no processor.
        """
        self.check_processor()

        def _get_reader_and_feeder(data_format, data, place):
            def _reader(process_data):
                # yield one sample at a time, taking one item per feed
                for item in zip(*process_data):
                    yield item

            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])
            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
            return functools.partial(_reader, process_data=process_data), feeder

        # the context is rebuilt only when the signature changes between calls
        if self.last_call_name != sign_name:
            self.last_call_name = sign_name
            self.cache_feed_dict, self.cache_fetch_dict, self.cache_program = self.context(
                sign_name, for_test=True)
        fetch_dict = self.cache_fetch_dict
        program = self.cache_program

        # fetch_dict maps both an index and an alias to each variable; the
        # set removes these duplicates.
        # NOTE(review): a set does not keep the signature order of the
        # outputs - confirm that processors do not rely on it.
        fetch_list = list(set([value for key, value in fetch_dict.items()]))
        with fluid.program_guard(program):
            result = []
            index = 0
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except (KeyError, IndexError, ValueError):
                # variable unset, empty, or not starting with a device index
                use_gpu = False

            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()

            exe = fluid.Executor(place=place)
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
            reader = paddle.batch(reader, batch_size=batch_size)
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
                # slice of the preprocessed data matching this batch
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result
W
wuzewu 已提交
553

W
wuzewu 已提交
554
    def check_processor(self):
W
wuzewu 已提交
555 556
        if not self.processor:
            raise ValueError("This Module is not callable!")
W
wuzewu 已提交
557

W
wuzewu 已提交
558
    @property
W
wuzewu 已提交
559
    def is_runnable(self):
W
wuzewu 已提交
560 561
        return self.default_signature != None

W
wuzewu 已提交
562
    def context(self,
                sign_name=None,
                for_test=False,
                trainable=True,
                regularizer=None,
                max_seq_len=128,
                learning_rate=1e-3):
        """
        Clone the module program and expose its feed and fetch variables.

        Args:
            sign_name(str): signature whose inputs/outputs are exposed; when
            empty, the inputs and outputs of all signatures are merged
            for_test(bool): clone the program for test; when False the
            trainable flag, learning rate and regularizer are applied and the
            parameters are restored
            trainable(bool): passed to set_parameter_trainable
            regularizer: passed to set_parameter_regularizer
            max_seq_len(int): maximum sequence length, this option is only
            available for BERT/ERNIE module
            learning_rate(float): passed to set_parameter_learning_rate

        Returns:
            tuple: (feed_dict, fetch_dict, program); both dicts map the
            position of a variable, and its alias when it has one, to the
            variable of the cloned program

        Raises:
            KeyError: if sign_name is not a signature of this module
            ValueError: if max_seq_len is not in [1, 512] for a BERT/ERNIE
            module
        """

        if sign_name:
            if sign_name not in self.signatures:
                raise KeyError(
                    "Module did not have a signature with name %s" % sign_name)
            signature = self.signatures[sign_name]
        else:
            all_signs = list(self.signatures.values())
            inputs = [var for sign in all_signs for var in sign.inputs]
            outputs = [var for sign in all_signs for var in sign.outputs]
            feed_names = [
                feed_name for sign in all_signs
                for feed_name in sign.feed_names
            ]
            fetch_names = [
                fetch_name for sign in all_signs
                for fetch_name in sign.fetch_names
            ]
            signature = create_signature(
                name="hub_temp_signature",
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names,
                for_predict=False)

        program = self.program.clone(for_test=for_test)
        paddle_helper.remove_feed_fetch_op(program)

        if not for_test:
            paddle_helper.set_parameter_trainable(program, trainable)

            paddle_helper.set_parameter_learning_rate(program, learning_rate)

            paddle_helper.set_parameter_regularizer(program, regularizer)

            self._restore_parameter(program)

        self._recover_variable_info(program)

        paddle_helper.set_op_attr(program, is_test=for_test)
        global_block = program.global_block()
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            feed_dict[index] = global_block.var(var.name)
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = global_block.var(var.name)

        for index, var in enumerate(signature.outputs):
            fetch_dict[index] = global_block.var(var.name)
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = global_block.var(var.name)

        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
        if "bert" in self.name or self.name.startswith("ernie"):
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
                        max_seq_len, MAX_SEQ_LENGTH))
            logger.info(
                "Set maximum sequence length of input tensor to {}".format(
                    max_seq_len))
            if self.name.startswith("ernie_v2"):
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask",
                    "task_ids"
                ]
            else:
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask"
                ]
            seq_tensor_shape = [-1, max_seq_len, 1]
            for tensor_name in feed_list:
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                global_block.var(feed_dict[tensor_name].name).desc.set_shape(
                    seq_tensor_shape)

        # record num parameters loaded by paddlehub
        num_param_loaded = sum(1 for _ in global_block.iter_parameters())
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)

        return feed_dict, fetch_dict, program

670
    def get_name_prefix(self):
W
wuzewu 已提交
671
        return self._name_prefix
672 673 674 675

    def get_var_name_with_prefix(self, var_name):
        return self.get_name_prefix() + var_name

W
wuzewu 已提交
676
    def _check_signatures(self):
W
wuzewu 已提交
677 678
        if not self.signatures:
            raise ValueError("Signatures should not be None")
W
wuzewu 已提交
679 680

        for key, sign in self.signatures.items():
W
wuzewu 已提交
681 682 683 684
            if not isinstance(sign, Signature):
                raise TypeError(
                    "Item in Signatures shoule be an instance of paddlehub.Signature"
                )
W
wuzewu 已提交
685 686 687

            for input in sign.inputs:
                _tmp_program = input.block.program
W
wuzewu 已提交
688 689 690 691
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )
W
wuzewu 已提交
692 693 694

            for output in sign.outputs:
                _tmp_program = output.block.program
W
wuzewu 已提交
695 696 697 698
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )