module.py 25.2 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22 23

import os
import time
import sys
import functools
W
wuzewu 已提交
24 25
import inspect
import importlib
W
wuzewu 已提交
26
import shutil
W
wuzewu 已提交
27 28 29 30

import paddle
import paddle.fluid as fluid

W
wuzewu 已提交
31 32
from paddlehub.common import utils
from paddlehub.common import paddle_helper
W
wuzewu 已提交
33
from paddlehub.common.dir import CACHE_HOME
S
shenyuhan 已提交
34
from paddlehub.common.lock import lock
W
wuzewu 已提交
35 36
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
W
wuzewu 已提交
37 38
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
W
wuzewu 已提交
39 40
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
W
wuzewu 已提交
41

Z
Zeyu Chen 已提交
42
# PaddleHub module dir name
W
wuzewu 已提交
43 44 45 46 47
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
Z
Zeyu Chen 已提交
48
# PaddleHub var prefix
W
wuzewu 已提交
49
HUB_VAR_PREFIX = "@HUB_%s@"
W
wuzewu 已提交
50

W
wuzewu 已提交
51
_module_runnable_func = {}
W
wuzewu 已提交
52 53


W
wuzewu 已提交
54
def runnable(func):
W
wuzewu 已提交
55
    mod = func.__module__ + "." + inspect.stack()[1][3]
W
wuzewu 已提交
56
    _module_runnable_func[mod] = func.__name__
W
wuzewu 已提交
57 58 59 60 61 62 63

    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


走神的阿圆's avatar
走神的阿圆 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76
_module_serving_func = {}


def serving(func):
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_serving_func[mod] = func.__name__

    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
def moduleinfo(name, version, author, author_email, summary, type):
    def _wrapper(cls):
        if not issubclass(cls, Module):
            raise RuntimeError
        cls._name = name
        cls._version = version
        cls._author = author
        cls._author_email = author_email
        cls._summary = summary
        cls._type = type
        return cls

    return _wrapper


W
wuzewu 已提交
92
class Module(fluid.dygraph.Layer):
W
wuzewu 已提交
93 94 95 96 97 98
    def __new__(cls,
                name=None,
                directory=None,
                module_dir=None,
                version=None,
                **kwargs):
W
wuzewu 已提交
99 100
        if cls.__name__ == "Module":
            if name:
W
wuzewu 已提交
101 102
                module = cls.init_with_name(
                    name=name, version=version, **kwargs)
W
wuzewu 已提交
103
            elif directory:
W
wuzewu 已提交
104
                module = cls.init_with_directory(directory=directory, **kwargs)
W
wuzewu 已提交
105 106 107 108 109 110 111 112 113 114
            elif module_dir:
                logger.warning(
                    "Parameter module_dir is deprecated, please use directory to specify the path"
                )
                if isinstance(module_dir, list) or isinstance(
                        module_dir, tuple):
                    directory = module_dir[0]
                    version = module_dir[1]
                else:
                    directory = module_dir
W
wuzewu 已提交
115
                module = cls.init_with_directory(directory=directory, **kwargs)
走神的阿圆's avatar
走神的阿圆 已提交
116
            CacheUpdater("update_cache", module.name, module.version).start()
117
        else:
118 119
            if not name and not directory:
                directory = os.path.dirname(
W
wuzewu 已提交
120
                    os.path.abspath(sys.modules[cls.__module__].__file__))
121 122 123
                module = Module.init_with_directory(
                    directory=directory, **kwargs)
            else:
W
wuzewu 已提交
124
                module = fluid.dygraph.Layer.__new__(cls)
125

W
wuzewu 已提交
126 127
        return module

W
wuzewu 已提交
128 129 130 131 132 133
    def __init__(self,
                 name=None,
                 directory=None,
                 module_dir=None,
                 version=None,
                 **kwargs):
134
        # Avoid module being initialized multiple times
135
        if "_is_initialize" in self.__dict__ and self._is_initialize:
W
wuzewu 已提交
136
            return
W
wuzewu 已提交
137

W
wuzewu 已提交
138
        super(Module, self).__init__()
S
Steffy-zxf 已提交
139 140
        _run_func_name = self._get_func_name(self.__class__,
                                             _module_runnable_func)
K
kinghuin 已提交
141 142
        self._run_func = getattr(self,
                                 _run_func_name) if _run_func_name else None
S
Steffy-zxf 已提交
143 144
        self._serving_func_name = self._get_func_name(self.__class__,
                                                      _module_serving_func)
W
wuzewu 已提交
145
        self._directory = directory
W
wuzewu 已提交
146
        self._initialize(**kwargs)
147
        self._is_initialize = True
W
wuzewu 已提交
148
        self._code_version = "v2"
W
wuzewu 已提交
149

S
Steffy-zxf 已提交
150 151 152 153 154 155 156 157 158 159 160
    def _get_func_name(self, current_cls, module_func_dict):
        mod = current_cls.__module__ + "." + current_cls.__name__
        if mod in module_func_dict:
            _func_name = module_func_dict[mod]
            return _func_name
        elif current_cls.__bases__:
            for base_class in current_cls.__bases__:
                return self._get_func_name(base_class, module_func_dict)
        else:
            return None

W
wuzewu 已提交
161
    @classmethod
W
wuzewu 已提交
162
    def init_with_name(cls, name, version=None, **kwargs):
W
wuzewu 已提交
163 164 165 166 167 168 169 170 171 172 173 174
        fp_lock = open(os.path.join(CACHE_HOME, name), "a")
        lock.flock(fp_lock, lock.LOCK_EX)
        log_msg = "Installing %s module" % name
        if version:
            log_msg += "-%s" % version
        logger.info(log_msg)
        extra = {"command": "install"}
        result, tips, module_dir = default_module_manager.install_module(
            module_name=name, module_version=version, extra=extra)
        if not result:
            logger.error(tips)
            raise RuntimeError(tips)
W
wuzewu 已提交
175

W
wuzewu 已提交
176 177
        logger.info(tips)
        lock.flock(fp_lock, lock.LOCK_UN)
W
wuzewu 已提交
178
        return cls.init_with_directory(directory=module_dir[0], **kwargs)
W
wuzewu 已提交
179

W
wuzewu 已提交
180
    @classmethod
W
wuzewu 已提交
181
    def init_with_directory(cls, directory, **kwargs):
W
wuzewu 已提交
182
        desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
W
wuzewu 已提交
183 184 185 186 187
        if os.path.exists(desc_file):
            checker = ModuleChecker(directory)
            checker.check()
            return ModuleV1(directory=directory, **kwargs)

W
wuzewu 已提交
188 189
        if directory.endswith(os.sep):
            directory = directory[:-1]
W
wuzewu 已提交
190 191 192 193 194 195
        basename = os.path.split(directory)[-1]
        dirname = os.path.join(*list(os.path.split(directory)[:-1]))
        sys.path.insert(0, dirname)
        _module = importlib.import_module("{}.module".format(basename))
        for _item, _cls in inspect.getmembers(_module, inspect.isclass):
            _item = _module.__dict__[_item]
W
wuzewu 已提交
196 197 198
            _file = os.path.realpath(sys.modules[_item.__module__].__file__)
            _module_path = os.path.realpath(
                os.path.join(directory, "module.py"))
S
Steffy-zxf 已提交
199
            if issubclass(_item, Module) and _file.startswith(_module_path):
W
wuzewu 已提交
200 201 202 203
                user_module = _item(directory=directory, **kwargs)
                break
        sys.path.pop(0)
        return user_module
W
wuzewu 已提交
204

W
wuzewu 已提交
205 206 207 208
    @property
    def run_func(self):
        return self._run_func

W
wuzewu 已提交
209 210 211 212 213 214
    @property
    def directory(self):
        return self._directory

    @property
    def author(self):
W
wuzewu 已提交
215
        return self.__class__._author
W
wuzewu 已提交
216 217 218

    @property
    def author_email(self):
W
wuzewu 已提交
219
        return self.__class__._author_email
W
wuzewu 已提交
220 221 222

    @property
    def summary(self):
W
wuzewu 已提交
223
        return self.__class__._summary
W
wuzewu 已提交
224 225 226

    @property
    def type(self):
W
wuzewu 已提交
227
        return self.__class__._type
W
wuzewu 已提交
228 229 230

    @property
    def version(self):
W
wuzewu 已提交
231
        return self.__class__._version
W
wuzewu 已提交
232 233 234

    @property
    def name(self):
W
wuzewu 已提交
235
        return self.__class__._name
W
wuzewu 已提交
236

W
wuzewu 已提交
237 238 239 240
    @property
    def code_version(self):
        return self._code_version

W
wuzewu 已提交
241
    @property
W
wuzewu 已提交
242
    def is_runnable(self):
W
wuzewu 已提交
243
        return self._run_func != None
W
wuzewu 已提交
244

走神的阿圆's avatar
走神的阿圆 已提交
245 246 247 248
    @property
    def serving_func_name(self):
        return self._serving_func_name

W
wuzewu 已提交
249 250
    def _initialize(self):
        pass
W
wuzewu 已提交
251

W
wuzewu 已提交
252
    def forward(self, *args, **kwargs):
W
wuzewu 已提交
253 254
        raise RuntimeError('{} does not support dynamic graph mode yet.'.format(
            self.name))
W
wuzewu 已提交
255

W
wuzewu 已提交
256

257
class ModuleHelper(object):
W
wuzewu 已提交
258 259
    def __init__(self, directory):
        self.directory = directory
W
wuzewu 已提交
260 261

    def module_desc_path(self):
W
wuzewu 已提交
262
        return os.path.join(self.directory, MODULE_DESC_PBNAME)
W
wuzewu 已提交
263 264

    def model_path(self):
W
wuzewu 已提交
265
        return os.path.join(self.directory, MODEL_DIRNAME)
W
wuzewu 已提交
266

W
wuzewu 已提交
267
    def processor_path(self):
W
wuzewu 已提交
268
        return os.path.join(self.directory, PYTHON_DIR)
W
wuzewu 已提交
269 270 271 272 273

    def processor_name(self):
        return PROCESSOR_NAME

    def assets_path(self):
W
wuzewu 已提交
274
        return os.path.join(self.directory, ASSETS_DIRNAME)
W
wuzewu 已提交
275

W
wuzewu 已提交
276

W
wuzewu 已提交
277 278
class ModuleV1(Module):
    def __init__(self, name=None, directory=None, module_dir=None,
B
BinLong 已提交
279
                 version=None):
W
wuzewu 已提交
280 281 282
        if not directory:
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
W
wuzewu 已提交
283 284 285 286
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
W
wuzewu 已提交
287
        self.default_signature = None
W
wuzewu 已提交
288
        self.processor = None
W
wuzewu 已提交
289
        self.extra_info = {}
W
wuzewu 已提交
290
        self._code_version = "v1"
W
wuzewu 已提交
291

W
wuzewu 已提交
292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
        # parse desc
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as file:
            self._desc.ParseFromString(file.read())

        module_info = self.desc.attr.map.data['module_info']
        self._name = utils.from_module_attr_to_pyobj(
            module_info.map.data['name'])
        self._author = utils.from_module_attr_to_pyobj(
            module_info.map.data['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            module_info.map.data['author_email'])
        self._version = utils.from_module_attr_to_pyobj(
            module_info.map.data['version'])
        self._type = utils.from_module_attr_to_pyobj(
            module_info.map.data['type'])
        self._summary = utils.from_module_attr_to_pyobj(
            module_info.map.data['summary'])

W
wuzewu 已提交
312 313 314 315 316 317
        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

W
wuzewu 已提交
318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)
W
wuzewu 已提交
333

W
wuzewu 已提交
334
    @property
走神的阿圆's avatar
走神的阿圆 已提交
335 336 337 338
    def serving_func_name(self):
        serving_func_name = self.desc.attr.map.data['default_signature'].s
        return serving_func_name if serving_func_name != "" else None

W
wuzewu 已提交
339
    @property
W
wuzewu 已提交
340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366
    def desc(self):
        return self._desc

    @property
    def author(self):
        return self._author

    @property
    def author_email(self):
        return self._author_email

    @property
    def summary(self):
        return self._summary

    @property
    def type(self):
        return self._type

    @property
    def version(self):
        return self._version

    @property
    def name(self):
        return self._name

W
wuzewu 已提交
367 368 369 370 371
    def _dump_processor(self):
        import inspect
        pymodule = inspect.getmodule(self.processor)
        pycode = inspect.getsource(pymodule)
        processor_path = self.helper.processor_path()
372 373 374
        processor_md5 = utils.md5(pycode)
        processor_md5 += str(time.time())
        processor_name = utils.md5(processor_md5)
W
wuzewu 已提交
375 376 377 378
        output_file = os.path.join(processor_path, processor_name + ".py")
        utils.mkdir(processor_path)
        with open(output_file, "w") as file:
            file.write(pycode)
W
wuzewu 已提交
379 380
        utils.from_pyobj_to_module_attr(
            processor_name, self.desc.attr.map.data['processor_info'])
W
wuzewu 已提交
381

W
wuzewu 已提交
382 383
    def _load_processor(self):
        processor_path = self.helper.processor_path()
W
wuzewu 已提交
384 385
        if os.path.exists(processor_path):
            sys.path.append(processor_path)
W
wuzewu 已提交
386 387
            processor_name = utils.from_module_attr_to_pyobj(
                self.desc.attr.map.data['processor_info'])
W
wuzewu 已提交
388 389 390
            self.processor = __import__(processor_name).Processor(module=self)
        else:
            self.processor = None
W
wuzewu 已提交
391

W
wuzewu 已提交
392 393 394 395 396
    def _dump_assets(self):
        utils.mkdir(self.helper.assets_path())
        for asset in self.assets:
            filename = os.path.basename(asset)
            newfile = os.path.join(self.helper.assets_path(), filename)
W
wuzewu 已提交
397
            shutil.copyfile(asset, newfile)
W
wuzewu 已提交
398 399 400 401 402 403 404 405

    def _load_assets(self):
        assets_path = self.helper.assets_path()
        self.assets = []
        for file in os.listdir(assets_path):
            filepath = os.path.join(self.helper.assets_path(), file)
            self.assets.append(filepath)

Z
Zeyu Chen 已提交
406
    def _restore_parameter(self, program):
W
wuzewu 已提交
407
        global_block = program.global_block()
W
wuzewu 已提交
408
        param_attrs = self.desc.attr.map.data['param_attrs']
W
wuzewu 已提交
409
        for key, param_attr in param_attrs.map.data.items():
W
wuzewu 已提交
410
            param = paddle_helper.from_module_attr_to_param(param_attr)
411
            param['name'] = self.get_var_name_with_prefix(key)
W
wuzewu 已提交
412 413 414 415 416 417 418 419 420 421
            if (param['name'] not in global_block.vars):
                continue
            var = global_block.var(param['name'])
            global_block.create_parameter(
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                error_clip=var.error_clip,
                stop_gradient=var.stop_gradient,
W
wuzewu 已提交
422 423
                is_data=var.is_data,
                **param)
W
wuzewu 已提交
424 425

    def _recover_variable_info(self, program):
W
wuzewu 已提交
426
        var_infos = self.desc.attr.map.data['var_infos']
W
wuzewu 已提交
427
        for var_info in var_infos.map.data:
W
wuzewu 已提交
428
            idx = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
429
                var_infos.map.data[var_info].map.data['block_id'])
W
wuzewu 已提交
430
            stop_gradient = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
431 432
                var_infos.map.data[var_info].map.data['stop_gradient'])
            block = program.blocks[idx]
433
            var_name = self.get_var_name_with_prefix(var_info)
W
wuzewu 已提交
434 435 436 437
            if var_name in block.vars:
                var = block.vars[var_name]
                var.stop_gradient = stop_gradient

W
wuzewu 已提交
438 439 440 441 442 443 444 445
    def get_extra_info(self, key):
        return self.extra_info.get(key, None)

    def _generate_extra_info(self):
        for key in self.extra_info:
            self.__dict__["get_%s" % key] = functools.partial(
                self.get_extra_info, key=key)

W
wuzewu 已提交
446 447 448
    def _generate_sign_attr(self):
        self._check_signatures()
        for sign in self.signatures:
W
wuzewu 已提交
449 450
            self.__dict__[sign] = functools.partial(
                self.__call__, sign_name=sign)
W
wuzewu 已提交
451

452 453 454 455
    def get_vocab_path(self):
        for assets_file in self.assets:
            if "vocab.txt" in assets_file:
                return assets_file
K
kinghuin 已提交
456 457 458 459 460 461 462 463 464 465 466 467 468
        return None

    def get_word_dict_path(self):
        for assets_file in self.assets:
            if "dict.wordseg.pickle" in assets_file:
                return assets_file
        return None

    def get_spm_path(self):
        for assets_file in self.assets:
            if "spm_cased_simp_sampled.model" in assets_file:
                return assets_file
        return None
469

W
wuzewu 已提交
470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
    def _recover_from_desc(self):
        # recover signature
        for sign, module_var in self.desc.sign2var.items():
            inputs = []
            outputs = []
            feed_names = []
            fetch_names = []
            for var in module_var.feed_desc:
                variable = self.program.global_block().vars[var.var_name]
                inputs.append(variable)
                feed_names.append(var.alias)

            for var in module_var.fetch_desc:
                variable = self.program.global_block().vars[var.var_name]
                outputs.append(variable)
                fetch_names.append(var.alias)

            self.signatures[sign] = create_signature(
                sign,
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names)

W
wuzewu 已提交
494
        # recover default signature
W
wuzewu 已提交
495 496
        default_signature_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['default_signature'])
W
wuzewu 已提交
497
        self.default_signature = self.signatures[
W
wuzewu 已提交
498
            default_signature_name].name if default_signature_name else None
W
wuzewu 已提交
499

W
wuzewu 已提交
500
        # recover module info
W
wuzewu 已提交
501
        module_info = self.desc.attr.map.data['module_info']
W
wuzewu 已提交
502
        self._name = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
503
            module_info.map.data['name'])
W
wuzewu 已提交
504
        self._author = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
505
            module_info.map.data['author'])
W
wuzewu 已提交
506
        self._author_email = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
507
            module_info.map.data['author_email'])
W
wuzewu 已提交
508
        self._version = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
509
            module_info.map.data['version'])
W
wuzewu 已提交
510
        self._type = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
511
            module_info.map.data['type'])
W
wuzewu 已提交
512
        self._summary = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
513 514
            module_info.map.data['summary'])

W
wuzewu 已提交
515 516 517 518 519 520
        # recover extra info
        extra_info = self.desc.attr.map.data['extra_info']
        self.extra_info = {}
        for key, value in extra_info.map.data.items():
            self.extra_info[key] = utils.from_module_attr_to_pyobj(value)

W
wuzewu 已提交
521
        # recover name prefix
W
wuzewu 已提交
522
        self._name_prefix = utils.from_module_attr_to_pyobj(
W
wuzewu 已提交
523
            self.desc.attr.map.data["name_prefix"])
W
wuzewu 已提交
524

525
    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
W
wuzewu 已提交
526 527
        self.check_processor()

W
wuzewu 已提交
528
        def _get_reader_and_feeder(data_format, data, place):
W
wuzewu 已提交
529
            def _reader(process_data):
W
wuzewu 已提交
530 531 532 533 534 535 536 537 538
                for item in zip(*process_data):
                    yield item

            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])
            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
W
wuzewu 已提交
539
            return functools.partial(_reader, process_data=process_data), feeder
W
wuzewu 已提交
540

W
wuzewu 已提交
541 542 543 544 545 546 547 548
        if self.last_call_name != sign_name:
            self.last_call_name = sign_name
            self.cache_feed_dict, self.cache_fetch_dict, self.cache_program = self.context(
                sign_name, for_test=True)
        feed_dict = self.cache_feed_dict
        fetch_dict = self.cache_fetch_dict
        program = self.cache_program

W
wuzewu 已提交
549 550
        fetch_list = list(set([value for key, value in fetch_dict.items()]))
        with fluid.program_guard(program):
W
wuzewu 已提交
551 552
            result = []
            index = 0
553 554 555 556 557 558 559
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except:
                use_gpu = False

            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
560

W
wuzewu 已提交
561
            exe = fluid.Executor(place=place)
W
wuzewu 已提交
562 563 564 565
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
566
            reader = paddle.batch(reader, batch_size=batch_size)
W
wuzewu 已提交
567 568 569 570 571
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
W
wuzewu 已提交
572 573 574 575 576 577 578 579 580
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result
W
wuzewu 已提交
581

W
wuzewu 已提交
582
    def check_processor(self):
W
wuzewu 已提交
583 584
        if not self.processor:
            raise ValueError("This Module is not callable!")
W
wuzewu 已提交
585

W
wuzewu 已提交
586
    @property
W
wuzewu 已提交
587
    def is_runnable(self):
W
wuzewu 已提交
588 589
        return self.default_signature != None

走神的阿圆's avatar
走神的阿圆 已提交
590 591 592 593
    @property
    def code_version(self):
        return self._code_version

W
wuzewu 已提交
594
    def context(self,
595
                sign_name=None,
W
wuzewu 已提交
596
                for_test=False,
Z
Zeyu Chen 已提交
597
                trainable=True,
W
wuzewu 已提交
598
                regularizer=None,
599
                max_seq_len=128,
W
wuzewu 已提交
600
                learning_rate=1e-3):
601 602 603 604 605
        """
        Args:
            max_seq_len(int): maximum sequence length, this option is only
            available for BERT/ERNIE module
        """
W
wuzewu 已提交
606

607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635
        if sign_name:
            if sign_name not in self.signatures:
                raise KeyError(
                    "Module did not have a signature with name %s" % sign_name)
            signature = self.signatures[sign_name]
        else:
            inputs = [
                input for signature in self.signatures.values()
                for input in signature.inputs
            ]
            outputs = [
                output for signature in self.signatures.values()
                for output in signature.outputs
            ]
            feed_names = [
                feed_name for signature in self.signatures.values()
                for feed_name in signature.feed_names
            ]
            fetch_names = [
                fetch_name for signature in self.signatures.values()
                for fetch_name in signature.fetch_names
            ]
            signature = create_signature(
                name="hub_temp_signature",
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names,
                for_predict=False)
W
wuzewu 已提交
636

W
wuzewu 已提交
637
        program = self.program.clone(for_test=for_test)
W
wuzewu 已提交
638
        paddle_helper.remove_feed_fetch_op(program)
W
wuzewu 已提交
639 640

        if not for_test:
W
wuzewu 已提交
641
            paddle_helper.set_parameter_trainable(program, trainable)
W
wuzewu 已提交
642

W
wuzewu 已提交
643
            paddle_helper.set_parameter_learning_rate(program, learning_rate)
W
wuzewu 已提交
644

W
wuzewu 已提交
645
            paddle_helper.set_parameter_regularizer(program, regularizer)
W
wuzewu 已提交
646

Z
Zeyu Chen 已提交
647
            self._restore_parameter(program)
W
wuzewu 已提交
648

W
wuzewu 已提交
649 650
        self._recover_variable_info(program)

W
wuzewu 已提交
651
        paddle_helper.set_op_attr(program, is_test=for_test)
W
wuzewu 已提交
652 653 654 655 656 657 658 659 660 661 662 663 664 665
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            feed_dict[index] = program.global_block().var(var.name)
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = program.global_block().var(var.name)

        for index, var in enumerate(signature.outputs):
            fetch_dict[index] = program.global_block().var(var.name)
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = program.global_block().var(var.name)

666
        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
K
kinghuin 已提交
667
        if "bert" in self.name or self.name.startswith("ernie"):
668 669 670 671
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
K
kinghuin 已提交
672
                        max_seq_len, MAX_SEQ_LENGTH))
673
            logger.info(
674
                "Set maximum sequence length of input tensor to {}".format(
675
                    max_seq_len))
676 677
            if self.name.startswith("ernie_v2"):
                feed_list = [
678 679
                    "input_ids", "position_ids", "segment_ids", "input_mask",
                    "task_ids"
680 681 682 683 684 685
                ]
            else:
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask"
                ]
            for tensor_name in feed_list:
686 687 688 689 690 691 692
                seq_tensor_shape = [-1, max_seq_len, 1]
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                program.global_block().var(
                    feed_dict[tensor_name].name).desc.set_shape(
                        seq_tensor_shape)

693 694
        # record num parameters loaded by paddlehub
        num_param_loaded = 0
W
wuzewu 已提交
695
        for param in program.global_block().iter_parameters():
696 697 698
            num_param_loaded += 1
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)
W
wuzewu 已提交
699

W
wuzewu 已提交
700 701
        return feed_dict, fetch_dict, program

702
    def get_name_prefix(self):
W
wuzewu 已提交
703
        return self._name_prefix
704 705 706 707

    def get_var_name_with_prefix(self, var_name):
        return self.get_name_prefix() + var_name

W
wuzewu 已提交
708
    def _check_signatures(self):
W
wuzewu 已提交
709 710
        if not self.signatures:
            raise ValueError("Signatures should not be None")
W
wuzewu 已提交
711 712

        for key, sign in self.signatures.items():
W
wuzewu 已提交
713 714 715 716
            if not isinstance(sign, Signature):
                raise TypeError(
                    "Item in Signatures shoule be an instance of paddlehub.Signature"
                )
W
wuzewu 已提交
717 718 719

            for input in sign.inputs:
                _tmp_program = input.block.program
W
wuzewu 已提交
720 721 722 723
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )
W
wuzewu 已提交
724 725 726

            for output in sign.outputs:
                _tmp_program = output.block.program
W
wuzewu 已提交
727 728 729 730
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )