module.py 25.2 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22 23

import os
import time
import sys
import functools
W
wuzewu 已提交
24 25 26
import inspect
import importlib
import tarfile
W
wuzewu 已提交
27
import six
W
wuzewu 已提交
28
import shutil
W
wuzewu 已提交
29 30 31 32

import paddle
import paddle.fluid as fluid

W
wuzewu 已提交
33 34
from paddlehub.common import utils
from paddlehub.common import paddle_helper
W
wuzewu 已提交
35
from paddlehub.common.dir import CACHE_HOME
S
shenyuhan 已提交
36
from paddlehub.common.lock import lock
W
wuzewu 已提交
37 38
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
W
wuzewu 已提交
39
from paddlehub.common import tmp_dir
W
wuzewu 已提交
40
from paddlehub.common.downloader import progress
W
wuzewu 已提交
41 42
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
W
wuzewu 已提交
43 44
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
W
wuzewu 已提交
45
from paddlehub.module.base_processor import BaseProcessor
W
wuzewu 已提交
46
from paddlehub.io.parser import yaml_parser
W
wuzewu 已提交
47
from paddlehub import version
W
wuzewu 已提交
48

Z
Zeyu Chen 已提交
49
# Names used in the layout of a PaddleHub module directory
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"

# Template for PaddleHub-prefixed names ("%s" is the substituted part)
HUB_VAR_PREFIX = "@HUB_%s@"

# File extension of a packaged PaddleHub module (used by create_module)
HUB_PACKAGE_SUFFIX = "phm"


def create_module(directory, name, author, email, module_type, summary,
                  version):
    """Package the module in ``directory`` as ``<name>-<version>.phm``.

    ``directory`` is copied into a temporary directory. A serialized module
    description with the given metadata, the checker's check info and an
    ``__init__.py`` are added, and everything is written into a gzip
    compressed tarball in the current working directory. Archive members are
    stored relative to the temporary directory (``./<name>/...``).

    Args:
        directory(str): source directory of the module.
        name(str): module name.
        author(str): author name.
        email(str): author email.
        module_type(str): module type.
        summary(str): module summary.
        version(str): module version.
    """
    save_file = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)

    with tmp_dir() as base_dir:
        # package the module
        with tarfile.open(save_file, "w:gz") as tar:
            module_dir = os.path.join(base_dir, name)
            shutil.copytree(directory, module_dir)

            # record module info and serialize
            desc = module_desc_pb2.ModuleDesc()
            attr = desc.attr
            attr.type = module_desc_pb2.MAP
            module_info = attr.map.data['module_info']
            module_info.type = module_desc_pb2.MAP
            info_fields = [("name", name), ("author", author),
                           ("author_email", email), ("type", module_type),
                           ("summary", summary), ("version", version)]
            for field_name, field_value in info_fields:
                utils.from_pyobj_to_module_attr(
                    field_value, module_info.map.data[field_name])
            module_desc_path = os.path.join(module_dir, MODULE_DESC_PBNAME)
            with open(module_desc_path, "wb") as desc_file:
                desc_file.write(desc.SerializeToString())

            # generate check info
            checker = ModuleChecker(module_dir)
            checker.generate_check_info()

            # make sure an __init__.py exists (created empty if missing)
            module_init = os.path.join(module_dir, "__init__.py")
            with open(module_init, "a"):
                pass

            _cwd = os.getcwd()
            os.chdir(base_dir)
            # Always restore the caller's working directory: base_dir is
            # removed when tmp_dir() exits.
            try:
                rel_module_dir = os.path.join(".", name)
                tar.add(rel_module_dir, recursive=False)
                files = []
                for dirname, _, subfiles in os.walk(rel_module_dir):
                    for subfile in subfiles:
                        files.append(os.path.join(dirname, subfile))

                total_length = len(files)
                print("Create Module {}-{}".format(name, version))
                for index, path in enumerate(files):
                    done = int(float(index) / total_length * 50)
                    progress("[%-50s] %.2f%%" %
                             ('=' * done,
                              float(index) / total_length * 100))
                    tar.add(path)
                progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
                print("Module package saved as {}".format(save_file))
            finally:
                os.chdir(_cwd)
W
wuzewu 已提交
122 123


W
wuzewu 已提交
124
# Maps "<func.__module__>.<enclosing class name>" to the name of the method
# decorated with @runnable in that class; Module.__init__ looks the class up
# here to find the method it exposes as ``run_func``.
_module_runnable_func = {}


def runnable(func):
    """Register ``func`` as the runnable entry method of its enclosing class.

    Must be applied inside a class body. The name of the calling scope (the
    class body, whose code object is named after the class) is combined with
    ``func.__module__`` to form the registry key.

    Returns:
        A wrapper that calls ``func`` and keeps its name and docstring.
    """
    # inspect.currentframe() avoids inspect.stack(), which reads source
    # context lines for every frame on the stack.
    caller = inspect.currentframe().f_back
    mod = func.__module__ + "." + caller.f_code.co_name
    _module_runnable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
137 138
class Module(object):
    """Base class of PaddleHub modules and the entry point for loading them.

    Calling ``Module(...)`` directly acts as a factory: ``__new__`` installs
    the module by ``name`` or loads it from ``directory`` (or the deprecated
    ``module_dir``). It returns an instance of the concrete class: the
    module's own ``HubModule`` for "v2" module code, ``ModuleV1`` otherwise.
    Subclasses are instantiated normally.
    """

    # No longer used for bookkeeping (initialization is tracked per instance
    # in __init__); kept so external references to Module._record still work.
    _record = {}

    def __new__(cls, name=None, directory=None, module_dir=None, version=None):
        if cls.__name__ != "Module":
            return object.__new__(cls)

        if name:
            module = cls.init_with_name(name=name, version=version)
        elif directory:
            module = cls.init_with_directory(directory=directory)
        elif module_dir:
            logger.warning(
                "Parameter module_dir is deprecated, please use directory to specify the path"
            )
            if isinstance(module_dir, (list, tuple)):
                directory = module_dir[0]
                # NOTE(review): version is not used when loading by directory
                version = module_dir[1]
            else:
                directory = module_dir
            module = cls.init_with_directory(directory=directory)
        else:
            # Without any of these there is nothing to load.
            raise ValueError(
                "Module requires one of name, directory or module_dir")

        CacheUpdater("update_cache", module.name, module.version).start()
        return module

    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Read the module description in ``directory`` and call
        ``_initialize``.

        Does nothing when ``directory`` is empty or the instance has already
        been initialized. ``Module.__new__`` returns a constructed instance,
        and Python then calls ``__init__`` on it again.
        """
        # Use a per-instance flag instead of a registry keyed by id(self):
        # ids are reused by new objects once old ones are collected, which
        # could make a fresh module skip initialization entirely.
        if not directory or getattr(self, "_hub_initialized", False):
            return
        self._hub_initialized = True

        mod = self.__class__.__module__ + "." + self.__class__.__name__
        if mod in _module_runnable_func:
            _run_func_name = _module_runnable_func[mod]
            self._run_func = getattr(self, _run_func_name)
        else:
            self._run_func = None
        self._code_version = "v2"
        self._directory = directory
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

        self._initialize()

    @classmethod
    def init_with_name(cls, name, version=None):
        """Install module ``name`` (optionally a given ``version``) and load it.

        The installation runs under an exclusive lock on the file
        ``CACHE_HOME/<name>``. The lock is released and the file closed even
        when the installation fails.

        Raises:
            RuntimeError: if the module manager reports a failed installation.
        """
        with open(os.path.join(CACHE_HOME, name), "a") as fp_lock:
            lock.flock(fp_lock, lock.LOCK_EX)
            try:
                log_msg = "Installing %s module" % name
                if version:
                    log_msg += "-%s" % version
                logger.info(log_msg)
                extra = {"command": "install"}
                result, tips, module_dir = default_module_manager.install_module(
                    module_name=name, module_version=version, extra=extra)
                if not result:
                    logger.error(tips)
                    raise RuntimeError(tips)
                logger.info(tips)
            finally:
                lock.flock(fp_lock, lock.LOCK_UN)
        return cls.init_with_directory(directory=module_dir[0])

    @classmethod
    def init_with_directory(cls, directory):
        """Check the module in ``directory`` and instantiate it.

        "v2" modules are loaded by importing ``module.py`` from ``directory``
        and constructing its ``HubModule``. Any other code version is loaded
        as ``ModuleV1``.
        """
        checker = ModuleChecker(directory)
        checker.check()

        if checker.module_code_version != "v2":
            return ModuleV1(directory=directory)

        sys.path.insert(0, directory)
        # Undo the sys.path change even if importing or constructing fails.
        try:
            # clear module cache so the module.py of this directory is loaded
            sys.modules.pop('module', None)
            _module = importlib.import_module("module")
            return _module.HubModule(directory=directory)
        finally:
            sys.path.pop(0)

    @property
    def run_func(self):
        return self._run_func

    @property
    def desc(self):
        return self._desc

    @property
    def directory(self):
        return self._directory

    @property
    def author(self):
        return self._author

    @property
    def author_email(self):
        return self._author_email

    @property
    def summary(self):
        return self._summary

    @property
    def type(self):
        return self._type

    @property
    def version(self):
        return self._version

    @property
    def name(self):
        return self._name

    @property
    def code_version(self):
        return self._code_version

    @property
    def is_runnable(self):
        """Whether a method of this module was registered with @runnable."""
        return self._run_func is not None

    def _initialize(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass
W
wuzewu 已提交
283 284


285
class ModuleHelper(object):
    """Resolves the well-known paths inside a module directory."""

    def __init__(self, directory):
        self.directory = directory

    def _join(self, entry):
        # Every layout entry lives directly under the module directory.
        return os.path.join(self.directory, entry)

    def module_desc_path(self):
        """Path of the serialized module description."""
        return self._join(MODULE_DESC_PBNAME)

    def model_path(self):
        """Directory the inference model is loaded from."""
        return self._join(MODEL_DIRNAME)

    def processor_path(self):
        """Directory holding the processor's python source."""
        return self._join(PYTHON_DIR)

    def processor_name(self):
        """Name used for the processor."""
        return PROCESSOR_NAME

    def assets_path(self):
        """Directory holding the module's asset files."""
        return self._join(ASSETS_DIRNAME)
W
wuzewu 已提交
303

W
wuzewu 已提交
304

W
wuzewu 已提交
305 306
class ModuleV1(Module):
    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load a v1 module from ``directory``.

        Loads the inference program with a CPU executor. Then it loads the
        processor and assets, and restores signatures, extra info,
        parameters and variable info from the module description. All
        arguments are forwarded to ``Module.__init__``. Nothing is loaded
        without ``directory``.
        """
        if not directory:
            return
        # Module.__new__ hands back an already constructed ModuleV1, and
        # Python then calls __init__ on it a second time with the same
        # arguments; don't load the program and processor twice.
        if getattr(self, "program", None) is not None:
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
        self._code_version = "v1"
        self.program = None
        self.assets = []
        self.signatures = {}
        self.default_signature = None
        self.processor = None
        self.extra_info = {}

        # cache data of the last __call__
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        # Replace each op's "op_callstack" attribute with a single empty
        # entry (presumably to drop creation-time stack traces).
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)
W
wuzewu 已提交
341

W
wuzewu 已提交
342 343 344 345 346
    def _dump_processor(self):
        """Save the source of the processor's python module into the module's
        python directory and record the generated name.

        The file name is an md5 of the source salted with the current time,
        so each dump produces a new name. That name is stored in
        ``desc.attr['processor_info']``.
        """
        source_code = inspect.getsource(inspect.getmodule(self.processor))
        target_dir = self.helper.processor_path()
        salted_digest = utils.md5(source_code) + str(time.time())
        processor_name = utils.md5(salted_digest)
        utils.mkdir(target_dir)
        target_file = os.path.join(target_dir, processor_name + ".py")
        with open(target_file, "w") as out:
            out.write(source_code)
        utils.from_pyobj_to_module_attr(
            processor_name, self.desc.attr.map.data['processor_info'])
W
wuzewu 已提交
356

W
wuzewu 已提交
357 358
    def _load_processor(self):
        """Import the module's processor and instantiate its ``Processor``.

        The python module name is read from ``desc.attr['processor_info']``
        and imported from the module's python directory. ``self.processor``
        is set to None when that directory does not exist.
        """
        processor_path = self.helper.processor_path()
        if not os.path.exists(processor_path):
            self.processor = None
            return
        # Don't grow sys.path with duplicate entries on repeated loads.
        if processor_path not in sys.path:
            sys.path.append(processor_path)
        processor_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['processor_info'])
        self.processor = __import__(processor_name).Processor(module=self)
W
wuzewu 已提交
366

W
wuzewu 已提交
367 368 369 370 371
    def _dump_assets(self):
        """Copy every file in ``self.assets`` into the module's assets
        directory, keeping the base file names."""
        assets_dir = self.helper.assets_path()
        utils.mkdir(assets_dir)
        for src_path in self.assets:
            dst_path = os.path.join(assets_dir, os.path.basename(src_path))
            shutil.copyfile(src_path, dst_path)
W
wuzewu 已提交
373 374 375 376 377 378 379 380

    def _load_assets(self):
        """Set ``self.assets`` to the paths of the entries in the module's
        assets directory.

        A missing assets directory yields an empty list instead of aborting
        the module load with an OSError.
        """
        assets_path = self.helper.assets_path()
        if not os.path.isdir(assets_path):
            self.assets = []
            return
        self.assets = [
            os.path.join(assets_path, entry)
            for entry in os.listdir(assets_path)
        ]

Z
Zeyu Chen 已提交
381
    def _restore_parameter(self, program):
        """Re-create the parameters described in ``desc.attr['param_attrs']``.

        Each entry's name gets the module's name prefix. If a variable with
        that name exists in ``program``'s global block, a parameter is created
        there with the variable's shape/dtype/type/lod_level/error_clip/
        stop_gradient/is_data plus the recorded parameter attributes. Entries
        without a matching variable are skipped.
        """
        block = program.global_block()
        saved_params = self.desc.attr.map.data['param_attrs'].map.data
        for raw_name, saved_attr in saved_params.items():
            param = paddle_helper.from_module_attr_to_param(saved_attr)
            param['name'] = self.get_var_name_with_prefix(raw_name)
            if param['name'] in block.vars:
                var = block.var(param['name'])
                block.create_parameter(
                    shape=var.shape,
                    dtype=var.dtype,
                    type=var.type,
                    lod_level=var.lod_level,
                    error_clip=var.error_clip,
                    stop_gradient=var.stop_gradient,
                    is_data=var.is_data,
                    **param)
W
wuzewu 已提交
399 400

    def _recover_variable_info(self, program):
        """Restore ``stop_gradient`` on the variables in
        ``desc.attr['var_infos']``.

        Each entry names its block by ``block_id``. Variables (looked up with
        the module's name prefix) that are absent from that block are skipped.
        """
        infos = self.desc.attr.map.data['var_infos'].map.data
        for raw_name in infos:
            fields = infos[raw_name].map.data
            block_id = utils.from_module_attr_to_pyobj(fields['block_id'])
            stop_gradient = utils.from_module_attr_to_pyobj(
                fields['stop_gradient'])
            block = program.blocks[block_id]
            var_name = self.get_var_name_with_prefix(raw_name)
            if var_name not in block.vars:
                continue
            block.vars[var_name].stop_gradient = stop_gradient

W
wuzewu 已提交
413 414 415 416 417 418 419 420
    def get_extra_info(self, key):
        """Return the extra info stored under ``key``, or None if absent."""
        return self.extra_info.get(key)

    def _generate_extra_info(self):
        """Add a zero-argument ``get_<key>`` accessor to the instance for each
        key of ``self.extra_info``."""
        instance_attrs = vars(self)
        for info_key in self.extra_info:
            accessor = functools.partial(self.get_extra_info, key=info_key)
            instance_attrs["get_%s" % info_key] = accessor

W
wuzewu 已提交
421 422 423
    def _generate_sign_attr(self):
        """Validate the signatures, then expose each one as an instance
        attribute named after it that calls ``__call__`` with that
        ``sign_name``."""
        self._check_signatures()
        instance_attrs = vars(self)
        for sign_name in self.signatures:
            instance_attrs[sign_name] = functools.partial(
                self.__call__, sign_name=sign_name)
W
wuzewu 已提交
426

427 428 429 430
    def _find_asset(self, marker):
        # First path in self.assets containing ``marker`` (a plain substring
        # test on the whole path), or None when nothing matches.
        return next((path for path in self.assets if marker in path), None)

    def get_vocab_path(self):
        """Return the first asset path containing "vocab.txt", or None."""
        return self._find_asset("vocab.txt")

    def get_word_dict_path(self):
        """Return the first asset path containing "dict.wordseg.pickle", or
        None."""
        return self._find_asset("dict.wordseg.pickle")

    def get_spm_path(self):
        """Return the first asset path containing
        "spm_cased_simp_sampled.model", or None."""
        return self._find_asset("spm_cased_simp_sampled.model")
444

W
wuzewu 已提交
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468
    def _recover_from_desc(self):
        """Restore the module state stored in ``self.desc``.

        Rebuilds ``self.signatures`` from ``desc.sign2var``. Then it sets the
        default signature name, the module info fields, ``self.extra_info``
        and the variable name prefix from ``desc.attr``.
        """

        def _collect(var_descs):
            # Look up each described variable in the program's global block
            # and collect the aliases in the same order.
            variables = []
            aliases = []
            for var_desc in var_descs:
                variables.append(
                    self.program.global_block().vars[var_desc.var_name])
                aliases.append(var_desc.alias)
            return variables, aliases

        # recover signatures
        for sign, module_var in self.desc.sign2var.items():
            inputs, feed_names = _collect(module_var.feed_desc)
            outputs, fetch_names = _collect(module_var.fetch_desc)
            self.signatures[sign] = create_signature(
                sign,
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names)

        attr_data = self.desc.attr.map.data
        to_pyobj = utils.from_module_attr_to_pyobj

        # recover default signature (stored by name)
        default_signature_name = to_pyobj(attr_data['default_signature'])
        if default_signature_name:
            self.default_signature = self.signatures[
                default_signature_name].name
        else:
            self.default_signature = None

        # recover module info
        info = attr_data['module_info'].map.data
        self._name = to_pyobj(info['name'])
        self._author = to_pyobj(info['author'])
        self._author_email = to_pyobj(info['author_email'])
        self._version = to_pyobj(info['version'])
        self._type = to_pyobj(info['type'])
        self._summary = to_pyobj(info['summary'])

        # recover extra info
        extra_info = self.extra_info = {}
        for info_key, info_value in attr_data['extra_info'].map.data.items():
            extra_info[info_key] = to_pyobj(info_value)

        # recover name prefix
        self._name_prefix = to_pyobj(attr_data["name_prefix"])
W
wuzewu 已提交
499

500
    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
        """Run inference for the signature ``sign_name`` on ``data``.

        The processor preprocesses ``data``. The preprocessed samples are fed
        in batches of ``batch_size`` through the program prepared by
        ``context(sign_name, for_test=True)``, which is cached per signature.
        Each batch's outputs are passed to ``processor.postprocess`` with the
        matching slice of the preprocessed data.

        Args:
            sign_name(str): name of the signature to run.
            data: raw input, passed as ``data_dict`` to ``processor.preprocess``.
            use_gpu(bool): run on CUDAPlace(0). It falls back to CPU when
                CUDA_VISIBLE_DEVICES is unset, empty, or does not start with
                a digit.
            batch_size(int): number of samples per executor run.
            **kwargs: forwarded to ``processor.postprocess``.

        Returns:
            list: the concatenated postprocess results of all batches.

        Raises:
            ValueError: if the module has no processor.
        """
        self.check_processor()

        def _get_reader_and_feeder(data_format, data, place):
            def _reader(process_data):
                for item in zip(*process_data):
                    yield item

            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])
            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
            return functools.partial(_reader, process_data=process_data), feeder

        # Reuse the prepared program unless a different signature is requested.
        if self.last_call_name != sign_name:
            self.last_call_name = sign_name
            self.cache_feed_dict, self.cache_fetch_dict, self.cache_program = self.context(
                sign_name, for_test=True)
        feed_dict = self.cache_feed_dict
        fetch_dict = self.cache_fetch_dict
        program = self.cache_program

        # fetch_dict maps both the output index and its alias to the same
        # variable; the set removes those duplicates.
        # NOTE(review): a set does not keep the signature's output order, so
        # postprocess receives the outputs in set iteration order.
        fetch_list = list(set([value for key, value in fetch_dict.items()]))
        with fluid.program_guard(program):
            result = []
            index = 0
            # Only a narrow set of failures means "no usable CUDA device":
            # unset variable, empty value, non-numeric first character.
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except (KeyError, IndexError, ValueError):
                use_gpu = False

            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()

            exe = fluid.Executor(place=place)
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
            reader = paddle.batch(reader, batch_size=batch_size)
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
                # the slice of the preprocessed data belonging to this batch
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result
W
wuzewu 已提交
556

W
wuzewu 已提交
557
    def check_processor(self):
        """Raise ValueError when the module has no processor to run with."""
        if self.processor:
            return
        raise ValueError("This Module is not callable!")
W
wuzewu 已提交
560

W
wuzewu 已提交
561
    @property
    def is_runnable(self):
        """Whether the module description names a default signature."""
        return self.default_signature is not None

W
wuzewu 已提交
565
    def context(self,
                sign_name=None,
                for_test=False,
                trainable=True,
                regularizer=None,
                max_seq_len=128,
                learning_rate=1e-3):
        """
        Build a copy of the module program exposing a signature's variables.

        Args:
            sign_name(str): signature whose inputs/outputs are exposed. When
                empty, the inputs/outputs of all signatures are merged.
            for_test(bool): passed to ``Program.clone`` and as ``is_test`` to
                the op attributes.
            trainable(bool): parameter trainability, applied only when
                ``for_test`` is False.
            regularizer: parameter regularizer, applied only when
                ``for_test`` is False.
            max_seq_len(int): maximum sequence length, this option is only
                available for BERT/ERNIE module; must be in [1, 512].
            learning_rate(float): parameter learning rate, applied only when
                ``for_test`` is False.

        Returns:
            tuple: (feed_dict, fetch_dict, program). Both dicts map a
            variable's position in the signature, and also its alias when the
            alias is non-empty, to the variable in ``program``.

        Raises:
            KeyError: if ``sign_name`` is not a signature of this module.
            ValueError: if ``max_seq_len`` is out of range for a BERT/ERNIE
                module.
        """
        if not sign_name:
            # Merge every signature into one temporary signature.
            all_signs = list(self.signatures.values())
            signature = create_signature(
                name="hub_temp_signature",
                inputs=[var for sign in all_signs for var in sign.inputs],
                outputs=[var for sign in all_signs for var in sign.outputs],
                feed_names=[
                    alias for sign in all_signs for alias in sign.feed_names
                ],
                fetch_names=[
                    alias for sign in all_signs for alias in sign.fetch_names
                ],
                for_predict=False)
        elif sign_name not in self.signatures:
            raise KeyError(
                "Module did not have a signature with name %s" % sign_name)
        else:
            signature = self.signatures[sign_name]

        program = self.program.clone(for_test=for_test)
        paddle_helper.remove_feed_fetch_op(program)
        if not for_test:
            paddle_helper.set_parameter_trainable(program, trainable)
            paddle_helper.set_parameter_learning_rate(program, learning_rate)
            paddle_helper.set_parameter_regularizer(program, regularizer)
            self._restore_parameter(program)
        self._recover_variable_info(program)
        paddle_helper.set_op_attr(program, is_test=for_test)

        global_block = program.global_block()

        def _index_vars(variables, aliases):
            # Key each variable of the cloned program by position and, when
            # given, by alias.
            mapping = {}
            for position, var in enumerate(variables):
                mapping[position] = global_block.var(var.name)
                alias = aliases[position]
                if alias:
                    mapping[alias] = global_block.var(var.name)
            return mapping

        feed_dict = _index_vars(signature.inputs, signature.feed_names)
        fetch_dict = _index_vars(signature.outputs, signature.fetch_names)

        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
        if "bert" in self.name or self.name.startswith("ernie"):
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
                        max_seq_len, MAX_SEQ_LENGTH))
            logger.info(
                "Set maximum sequence length of input tensor to {}".format(
                    max_seq_len))
            feed_list = [
                "input_ids", "position_ids", "segment_ids", "input_mask"
            ]
            if self.name.startswith("ernie_v2"):
                feed_list.append("task_ids")
            for tensor_name in feed_list:
                seq_tensor_shape = [-1, max_seq_len, 1]
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                global_block.var(feed_dict[tensor_name].name).desc.set_shape(
                    seq_tensor_shape)

        # record num parameters loaded by paddlehub
        num_param_loaded = sum(1 for _ in global_block.iter_parameters())
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)

        return feed_dict, fetch_dict, program

673
    def get_name_prefix(self):
        """Return the variable name prefix restored from the module
        description."""
        prefix = self._name_prefix
        return prefix
675 676 677 678

    def get_var_name_with_prefix(self, var_name):
        """Return ``var_name`` with the module's name prefix prepended."""
        prefix = self.get_name_prefix()
        return prefix + var_name

W
wuzewu 已提交
679
    def _check_signatures(self):
        """Validate ``self.signatures``.

        Raises:
            ValueError: if there are no signatures, or a signature input or
                output belongs to a program other than ``self.program``.
            TypeError: if a value is not a ``Signature``.
        """
        if not self.signatures:
            raise ValueError("Signatures should not be None")

        for sign in self.signatures.values():
            if not isinstance(sign, Signature):
                raise TypeError(
                    "Item in Signatures shoule be an instance of paddlehub.Signature"
                )

            # inputs are checked before outputs
            for var_group in (sign.inputs, sign.outputs):
                for var in var_group:
                    if not self.program == var.block.program:
                        raise ValueError(
                            "All input and outputs variables in signature should come from the same Program"
                        )