module.py 25.4 KB
Newer Older
S
Steffy-zxf 已提交
1
# coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19 20 21 22 23

import os
import time
import sys
import functools
W
wuzewu 已提交
24 25
import inspect
import importlib
W
wuzewu 已提交
26
import shutil
W
wuzewu 已提交
27 28 29 30

import paddle
import paddle.fluid as fluid

W
wuzewu 已提交
31 32
from paddlehub.common import utils
from paddlehub.common import paddle_helper
W
wuzewu 已提交
33
from paddlehub.common.dir import CACHE_HOME
S
shenyuhan 已提交
34
from paddlehub.common.lock import lock
W
wuzewu 已提交
35 36
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
W
wuzewu 已提交
37 38
from paddlehub.module import module_desc_pb2
from paddlehub.module.manager import default_module_manager
W
wuzewu 已提交
39 40
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
W
wuzewu 已提交
41

Z
Zeyu Chen 已提交
42
# PaddleHub module dir name
# Names of the files and sub-directories looked up inside a module directory
# (see ModuleHelper, which joins each of them onto the module directory).
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# PaddleHub var prefix
# %-style format string with a single "%s" placeholder.
HUB_VAR_PREFIX = "@HUB_%s@"
W
wuzewu 已提交
50

W
wuzewu 已提交
51
# Maps "<defining module>.<name of the frame applying the decorator>" to the
# name of the function decorated with @runnable.
_module_runnable_func = {}


def runnable(func):
    """Register ``func`` as a runnable entry point and return a wrapper.

    The registry key is built from ``func.__module__`` and the name of the
    frame that applies the decorator; when the decorator is used inside a
    class body, that frame name is the class name, which is the key format
    ``Module._get_func_name`` looks up.

    Args:
        func: the function being decorated.

    Returns:
        A wrapper that forwards all arguments to ``func`` and keeps its
        ``__name__`` / ``__doc__``.
    """
    # Index 3 of a frame record is the function (code object) name.
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_runnable_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


走神的阿圆's avatar
走神的阿圆 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76
# Maps "<defining module>.<name of the frame applying the decorator>" to the
# name of the function decorated with @serving.
_module_serving_func = {}


def serving(func):
    """Register ``func`` as a serving entry point and return a wrapper.

    The registry key is built from ``func.__module__`` and the name of the
    frame that applies the decorator; when the decorator is used inside a
    class body, that frame name is the class name, which is the key format
    ``Module._get_func_name`` looks up.

    Args:
        func: the function being decorated.

    Returns:
        A wrapper that forwards all arguments to ``func`` and keeps its
        ``__name__`` / ``__doc__``.
    """
    # Index 3 of a frame record is the function (code object) name.
    mod = func.__module__ + "." + inspect.stack()[1][3]
    _module_serving_func[mod] = func.__name__

    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return _wrapper


W
wuzewu 已提交
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
def moduleinfo(name, version, author, author_email, summary, type):
    """Class decorator that attaches module metadata to a ``Module`` subclass.

    The values are stored as the class attributes ``_name``, ``_version``,
    ``_author``, ``_author_email``, ``_summary`` and ``_type``, which the
    ``Module`` properties of the same (un-prefixed) names read back.

    Args:
        name, version, author, author_email, summary, type: metadata values
            stored on the decorated class unchanged.

    Returns:
        A decorator that returns the class it was given.

    Raises:
        RuntimeError: if the decorated class is not a subclass of ``Module``.
    """

    def _wrapper(cls):
        if not issubclass(cls, Module):
            raise RuntimeError(
                "moduleinfo can only decorate subclasses of Module, got %r" %
                (cls, ))
        cls._name = name
        cls._version = version
        cls._author = author
        cls._author_email = author_email
        cls._summary = summary
        cls._type = type
        return cls

    return _wrapper


W
wuzewu 已提交
92
class Module(fluid.dygraph.Layer):
    """Base class and factory for PaddleHub modules.

    Instantiating ``Module`` itself never yields a plain ``Module``:
    ``__new__`` delegates to :meth:`init_with_name` or
    :meth:`init_with_directory`, which return either a ``ModuleV1`` (when the
    directory contains ``module_desc.pb``) or an instance of the ``Module``
    subclass found in the directory's ``module.py``.
    """

    def __new__(cls,
                name=None,
                directory=None,
                module_dir=None,
                version=None,
                **kwargs):
        """Create the module object.

        Args:
            name: module name; when given on ``Module`` itself the module is
                installed through the default module manager first.
            directory: path of a module directory.
            module_dir: deprecated alias of ``directory``; may also be a
                list/tuple whose first item is the directory.
            version: module version, forwarded to :meth:`init_with_name`.
            **kwargs: forwarded to the loader / the concrete module class.

        Raises:
            RuntimeError: if ``Module`` is instantiated directly without
                ``name``, ``directory`` or ``module_dir``.
        """
        if cls.__name__ == "Module":
            if name:
                module = cls.init_with_name(
                    name=name, version=version, **kwargs)
            elif directory:
                module = cls.init_with_directory(directory=directory, **kwargs)
            elif module_dir:
                logger.warning(
                    "Parameter module_dir is deprecated, please use directory to specify the path"
                )
                if isinstance(module_dir, (list, tuple)):
                    directory = module_dir[0]
                    # Read for parity with the historical (dir, version)
                    # pair; the value is not used further in this method.
                    version = module_dir[1]
                else:
                    directory = module_dir
                module = cls.init_with_directory(directory=directory, **kwargs)
            else:
                # Previously this path fell through to an UnboundLocalError.
                raise RuntimeError(
                    "Module requires one of the parameters name, directory or module_dir"
                )
            CacheUpdater("update_cache", module.name, module.version).start()
        else:
            if not name and not directory:
                # A subclass created without arguments is loaded from the
                # directory that holds the file defining that subclass.
                directory = os.path.dirname(
                    os.path.abspath(sys.modules[cls.__module__].__file__))
                module = Module.init_with_directory(
                    directory=directory, **kwargs)
            else:
                module = fluid.dygraph.Layer.__new__(cls)

        return module

    def __init__(self,
                 name=None,
                 directory=None,
                 module_dir=None,
                 version=None,
                 **kwargs):
        """Initialise the module once.

        Resolves the registered runnable / serving function names for this
        class, stores ``directory``, calls :meth:`_initialize` with
        ``**kwargs`` and builds ``self.model_runner`` from
        ``self.pretrained_model_path``.
        """
        # Avoid module being initialized multiple times
        if "_is_initialize" in self.__dict__ and self._is_initialize:
            return

        super(Module, self).__init__()
        _run_func_name = self._get_func_name(self.__class__,
                                             _module_runnable_func)
        self._run_func = getattr(self,
                                 _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__,
                                                      _module_serving_func)
        self._directory = directory
        self._initialize(**kwargs)
        self._is_initialize = True
        self._code_version = "v2"
        self.model_runner = fluid.dygraph.StaticModelRunner(
            self.pretrained_model_path)

    @property
    def pretrained_model_path(self):
        """Return ``self.default_pretrained_model_path``.

        That attribute is not defined in this class.
        """
        return self.default_pretrained_model_path

    def _get_func_name(self, current_cls, module_func_dict):
        """Look up the function name registered for a class or its bases.

        Args:
            current_cls: class to start the search from.
            module_func_dict: registry keyed by ``"<module>.<class name>"``.

        Returns:
            The registered function name of ``current_cls`` or, failing
            that, of the first base class (searched depth-first, in
            ``__bases__`` order) that has one; ``None`` if none is found.
        """
        mod = current_cls.__module__ + "." + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]
        # Search every base, not only the first one, so that a decorated
        # method on a later base of a multiple-inheritance class is found.
        for base_class in current_cls.__bases__:
            func_name = self._get_func_name(base_class, module_func_dict)
            if func_name is not None:
                return func_name
        return None

    @classmethod
    def init_with_name(cls, name, version=None, **kwargs):
        """Install the module called ``name`` and load it.

        The installation runs under an exclusive lock on the file
        ``CACHE_HOME/<name>``; the lock is released and the file closed even
        when installation fails.

        Raises:
            RuntimeError: with the installer's message if installation fails.
        """
        log_msg = "Installing %s module" % name
        if version:
            log_msg += "-%s" % version
        extra = {"command": "install"}

        with open(os.path.join(CACHE_HOME, name), "a") as fp_lock:
            lock.flock(fp_lock, lock.LOCK_EX)
            try:
                logger.info(log_msg)
                result, tips, module_dir = default_module_manager.install_module(
                    module_name=name, module_version=version, extra=extra)
                if not result:
                    logger.error(tips)
                    raise RuntimeError(tips)
                logger.info(tips)
            finally:
                lock.flock(fp_lock, lock.LOCK_UN)

        return cls.init_with_directory(directory=module_dir[0], **kwargs)

    @classmethod
    def init_with_directory(cls, directory, **kwargs):
        """Load the module stored in ``directory``.

        If the directory contains ``module_desc.pb`` it is checked with
        ``ModuleChecker`` and loaded as a ``ModuleV1``. Otherwise
        ``<basename>.module`` is imported (with the parent directory
        temporarily at the front of ``sys.path``) and the first ``Module``
        subclass defined in that ``module.py`` is instantiated.

        Raises:
            RuntimeError: if ``module.py`` defines no ``Module`` subclass.
        """
        desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
        if os.path.exists(desc_file):
            checker = ModuleChecker(directory)
            checker.check()
            return ModuleV1(directory=directory, **kwargs)

        if directory.endswith(os.sep):
            directory = directory[:-1]
        dirname, basename = os.path.split(directory)
        _module_path = os.path.realpath(os.path.join(directory, "module.py"))

        user_module = None
        sys.path.insert(0, dirname)
        try:
            _module = importlib.import_module("{}.module".format(basename))
            for _, _item in inspect.getmembers(_module, inspect.isclass):
                # Filter on the type first so that unrelated classes whose
                # defining module has no __file__ are never inspected.
                if not issubclass(_item, Module):
                    continue
                _file = os.path.realpath(
                    sys.modules[_item.__module__].__file__)
                if _file.startswith(_module_path):
                    user_module = _item(directory=directory, **kwargs)
                    break
        finally:
            # Undo the sys.path change even if the import or the module
            # constructor raised; the entry may already be gone if the
            # imported code edited sys.path itself.
            try:
                sys.path.remove(dirname)
            except ValueError:
                pass

        if user_module is None:
            raise RuntimeError(
                "No subclass of Module is defined in %s" % _module_path)
        return user_module

    @property
    def run_func(self):
        """Bound method registered with ``@runnable``, or ``None``."""
        return self._run_func

    @property
    def directory(self):
        """Directory passed to ``__init__``."""
        return self._directory

    @property
    def author(self):
        return self.__class__._author

    @property
    def author_email(self):
        return self.__class__._author_email

    @property
    def summary(self):
        return self.__class__._summary

    @property
    def type(self):
        return self.__class__._type

    @property
    def version(self):
        return self.__class__._version

    @property
    def name(self):
        return self.__class__._name

    @property
    def code_version(self):
        return self._code_version

    @property
    def is_runnable(self):
        """``True`` when a ``@runnable`` function was found for this class."""
        return self._run_func is not None

    @property
    def serving_func_name(self):
        """Name of the function registered with ``@serving``, or ``None``."""
        return self._serving_func_name

    def _initialize(self):
        """Hook called from ``__init__``; does nothing in the base class."""
        pass

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``self.model_runner``."""
        return self.model_runner(*args, **kwargs)
W
wuzewu 已提交
260

W
wuzewu 已提交
261

262
class ModuleHelper(object):
    """Builds the paths of the well-known entries of a module directory."""

    def __init__(self, directory):
        self.directory = directory

    def _under_directory(self, entry):
        """Join ``entry`` onto the module directory."""
        return os.path.join(self.directory, entry)

    def module_desc_path(self):
        """Path of the serialized module description file."""
        return self._under_directory(MODULE_DESC_PBNAME)

    def model_path(self):
        """Path of the model sub-directory."""
        return self._under_directory(MODEL_DIRNAME)

    def processor_path(self):
        """Path of the sub-directory holding the processor source."""
        return self._under_directory(PYTHON_DIR)

    def processor_name(self):
        """Return the ``PROCESSOR_NAME`` constant."""
        return PROCESSOR_NAME

    def assets_path(self):
        """Path of the assets sub-directory."""
        return self._under_directory(ASSETS_DIRNAME)
W
wuzewu 已提交
280

W
wuzewu 已提交
281

W
wuzewu 已提交
282 283
class ModuleV1(Module):
    """Module restored from a serialized ``module_desc.pb`` (code version "v1").

    The constructor parses the description file, loads the inference program
    from the model directory, then restores the processor, assets,
    signatures, extra info, parameters and per-variable ``stop_gradient``
    flags recorded in the description.
    """

    def __init__(self, name=None, directory=None, module_dir=None,
                 version=None):
        """Load the module stored in ``directory``.

        Does nothing when ``directory`` is empty, or when the instance has
        already been initialised (Python calls ``__init__`` a second time on
        the object returned by ``Module.__new__``; without this guard the
        whole model would be loaded twice).
        """
        if not directory:
            return
        if self.__dict__.get("_is_initialize"):
            return
        super(ModuleV1, self).__init__(name, directory, module_dir, version)
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        self.default_signature = None
        self.processor = None
        self.extra_info = {}
        self._code_version = "v1"

        # parse desc
        self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
        self._desc = module_desc_pb2.ModuleDesc()
        with open(self.module_desc_path, "rb") as desc_file:
            self._desc.ParseFromString(desc_file.read())

        # The module info is read before the processor is created because
        # the processor receives this module object.
        self._load_module_info()

        # cache data
        self.last_call_name = None
        self.cache_feed_dict = None
        self.cache_fetch_dict = None
        self.cache_program = None

        self.helper = ModuleHelper(directory)
        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        # Blank out the recorded call stacks of every op that carries one.
        for block in self.program.blocks:
            for op in block.ops:
                if "op_callstack" in op.all_attrs():
                    op._set_attr("op_callstack", [""])
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()
        self._generate_extra_info()
        self._restore_parameter(self.program)
        self._recover_variable_info(self.program)

    def _load_module_info(self):
        """Copy name/author/version/type/summary from the description."""
        info = self.desc.attr.map.data['module_info'].map.data
        self._name = utils.from_module_attr_to_pyobj(info['name'])
        self._author = utils.from_module_attr_to_pyobj(info['author'])
        self._author_email = utils.from_module_attr_to_pyobj(
            info['author_email'])
        self._version = utils.from_module_attr_to_pyobj(info['version'])
        self._type = utils.from_module_attr_to_pyobj(info['type'])
        self._summary = utils.from_module_attr_to_pyobj(info['summary'])

    @property
    def serving_func_name(self):
        """Default signature name from the description, or ``None`` if empty."""
        serving_func_name = self.desc.attr.map.data['default_signature'].s
        return serving_func_name if serving_func_name != "" else None

    @property
    def desc(self):
        """The parsed ``ModuleDesc`` message."""
        return self._desc

    @property
    def author(self):
        return self._author

    @property
    def author_email(self):
        return self._author_email

    @property
    def summary(self):
        return self._summary

    @property
    def type(self):
        return self._type

    @property
    def version(self):
        return self._version

    @property
    def name(self):
        return self._name

    def _dump_processor(self):
        """Write the processor's source file into the processor directory.

        The file name is an md5 derived from the source text and the current
        time; it is recorded under ``processor_info`` in the description.
        """
        pymodule = inspect.getmodule(self.processor)
        pycode = inspect.getsource(pymodule)
        processor_path = self.helper.processor_path()
        processor_md5 = utils.md5(pycode)
        processor_md5 += str(time.time())
        processor_name = utils.md5(processor_md5)
        output_file = os.path.join(processor_path, processor_name + ".py")
        utils.mkdir(processor_path)
        with open(output_file, "w") as source_file:
            source_file.write(pycode)
        utils.from_pyobj_to_module_attr(
            processor_name, self.desc.attr.map.data['processor_info'])

    def _load_processor(self):
        """Import the processor module and instantiate its ``Processor``.

        ``self.processor`` is ``None`` when the processor directory does not
        exist.
        """
        processor_path = self.helper.processor_path()
        if os.path.exists(processor_path):
            # Avoid stacking duplicate entries when several modules share a
            # path or the module is loaded repeatedly.
            if processor_path not in sys.path:
                sys.path.append(processor_path)
            processor_name = utils.from_module_attr_to_pyobj(
                self.desc.attr.map.data['processor_info'])
            self.processor = __import__(processor_name).Processor(module=self)
        else:
            self.processor = None

    def _dump_assets(self):
        """Copy every file in ``self.assets`` into the assets directory."""
        assets_dir = self.helper.assets_path()
        utils.mkdir(assets_dir)
        for asset in self.assets:
            filename = os.path.basename(asset)
            newfile = os.path.join(assets_dir, filename)
            shutil.copyfile(asset, newfile)

    def _load_assets(self):
        """Fill ``self.assets`` with the paths of the assets directory entries."""
        assets_path = self.helper.assets_path()
        self.assets = [
            os.path.join(assets_path, entry)
            for entry in os.listdir(assets_path)
        ]

    def _restore_parameter(self, program):
        """Re-create as parameters the variables listed under ``param_attrs``.

        Entries whose prefixed name is not a variable of the program's
        global block are skipped.
        """
        global_block = program.global_block()
        param_attrs = self.desc.attr.map.data['param_attrs']
        for key, param_attr in param_attrs.map.data.items():
            param = paddle_helper.from_module_attr_to_param(param_attr)
            param['name'] = self.get_var_name_with_prefix(key)
            if (param['name'] not in global_block.vars):
                continue
            var = global_block.var(param['name'])
            global_block.create_parameter(
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                error_clip=var.error_clip,
                stop_gradient=var.stop_gradient,
                is_data=var.is_data,
                **param)

    def _recover_variable_info(self, program):
        """Restore the recorded ``stop_gradient`` flag of each variable.

        Variables that are absent from their recorded block are ignored.
        """
        var_infos = self.desc.attr.map.data['var_infos']
        for var_info in var_infos.map.data:
            info = var_infos.map.data[var_info].map.data
            idx = utils.from_module_attr_to_pyobj(info['block_id'])
            stop_gradient = utils.from_module_attr_to_pyobj(
                info['stop_gradient'])
            block = program.blocks[idx]
            var_name = self.get_var_name_with_prefix(var_info)
            if var_name in block.vars:
                var = block.vars[var_name]
                var.stop_gradient = stop_gradient

    def get_extra_info(self, key):
        """Return the extra info stored under ``key``, or ``None``."""
        return self.extra_info.get(key, None)

    def _generate_extra_info(self):
        """Expose each extra-info key as an instance attribute ``get_<key>``."""
        for key in self.extra_info:
            self.__dict__["get_%s" % key] = functools.partial(
                self.get_extra_info, key=key)

    def _generate_sign_attr(self):
        """Expose each signature as an instance attribute that calls the module."""
        self._check_signatures()
        for sign in self.signatures:
            self.__dict__[sign] = functools.partial(
                self.__call__, sign_name=sign)

    def _find_asset(self, keyword):
        """Return the first asset path containing ``keyword``, or ``None``."""
        for assets_file in self.assets:
            if keyword in assets_file:
                return assets_file
        return None

    def get_vocab_path(self):
        """Return the first asset path containing ``vocab.txt``, or ``None``."""
        return self._find_asset("vocab.txt")

    def get_word_dict_path(self):
        """Return the first asset path containing ``dict.wordseg.pickle``, or ``None``."""
        return self._find_asset("dict.wordseg.pickle")

    def get_spm_path(self):
        """Return the first asset path containing ``spm_cased_simp_sampled.model``, or ``None``."""
        return self._find_asset("spm_cased_simp_sampled.model")

    def _recover_from_desc(self):
        """Rebuild signatures, module info, extra info and name prefix."""
        global_vars = self.program.global_block().vars

        # recover signature
        for sign, module_var in self.desc.sign2var.items():
            inputs = []
            outputs = []
            feed_names = []
            fetch_names = []
            for var in module_var.feed_desc:
                inputs.append(global_vars[var.var_name])
                feed_names.append(var.alias)

            for var in module_var.fetch_desc:
                outputs.append(global_vars[var.var_name])
                fetch_names.append(var.alias)

            self.signatures[sign] = create_signature(
                sign,
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names)

        # recover default signature
        default_signature_name = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data['default_signature'])
        self.default_signature = self.signatures[
            default_signature_name].name if default_signature_name else None

        # recover module info
        self._load_module_info()

        # recover extra info
        extra_info = self.desc.attr.map.data['extra_info']
        self.extra_info = {}
        for key, value in extra_info.map.data.items():
            self.extra_info[key] = utils.from_module_attr_to_pyobj(value)

        # recover name prefix
        self._name_prefix = utils.from_module_attr_to_pyobj(
            self.desc.attr.map.data["name_prefix"])

    def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
        """Run the signature ``sign_name`` on ``data`` through the processor.

        Args:
            sign_name: name of the signature to run.
            data: passed to ``processor.preprocess`` as ``data_dict``.
            use_gpu: request GPU execution; forced to ``False`` unless the
                ``CUDA_VISIBLE_DEVICES`` environment variable is set and its
                first character parses as an integer.
            batch_size: number of samples per executor run.
            **kwargs: forwarded to ``processor.postprocess``.

        Returns:
            The concatenation of the ``postprocess`` results of all batches.

        Raises:
            ValueError: if the module has no processor.
        """
        self.check_processor()

        def _get_reader_and_feeder(data_format, data, place):
            def _reader(process_data):
                for item in zip(*process_data):
                    yield item

            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])
            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
            return functools.partial(_reader, process_data=process_data), feeder

        # The context of the last used signature is cached so that repeated
        # calls with the same signature do not clone the program again.
        if self.last_call_name != sign_name:
            self.last_call_name = sign_name
            self.cache_feed_dict, self.cache_fetch_dict, self.cache_program = self.context(
                sign_name, for_test=True)
        fetch_dict = self.cache_fetch_dict
        program = self.cache_program

        # fetch_dict holds each variable under both its index and its alias;
        # the set removes the duplicates.
        fetch_list = list(set(fetch_dict.values()))
        with fluid.program_guard(program):
            result = []
            index = 0
            try:
                _places = os.environ["CUDA_VISIBLE_DEVICES"]
                int(_places[0])
            except (KeyError, IndexError, ValueError):
                # Variable unset, empty, or not starting with a digit.
                use_gpu = False

            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()

            exe = fluid.Executor(place=place)
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
            reader = paddle.batch(reader, batch_size=batch_size)
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
                # Slice of the preprocessed data that belongs to this batch.
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result

    def check_processor(self):
        """Raise ``ValueError`` if the module has no processor."""
        if not self.processor:
            raise ValueError("This Module is not callable!")

    @property
    def is_runnable(self):
        """``True`` when the description names a default signature."""
        return self.default_signature is not None

    @property
    def code_version(self):
        return self._code_version

    def context(self,
                sign_name=None,
                for_test=False,
                trainable=True,
                regularizer=None,
                max_seq_len=128,
                learning_rate=1e-3):
        """Clone the program and return its feed/fetch variables.

        Args:
            sign_name: signature to use; when empty, a temporary signature
                combining the inputs and outputs of all signatures is built.
            for_test: passed to ``Program.clone`` and ``set_op_attr``; when
                true, ``trainable``, ``learning_rate`` and ``regularizer``
                are not applied.
            trainable: passed to ``set_parameter_trainable``.
            regularizer: passed to ``set_parameter_regularizer``.
            max_seq_len(int): maximum sequence length, this option is only
                available for BERT/ERNIE module
            learning_rate: passed to ``set_parameter_learning_rate``.

        Returns:
            ``(feed_dict, fetch_dict, program)``; each dict maps both the
            positional index and, when non-empty, the alias of a signature
            variable to the corresponding variable of the cloned program.

        Raises:
            KeyError: if ``sign_name`` is not a known signature.
            ValueError: if ``max_seq_len`` is outside ``[1, 512]`` for a
                BERT/ERNIE module.
        """

        if sign_name:
            if sign_name not in self.signatures:
                raise KeyError(
                    "Module did not have a signature with name %s" % sign_name)
            signature = self.signatures[sign_name]
        else:
            all_signs = list(self.signatures.values())
            signature = create_signature(
                name="hub_temp_signature",
                inputs=[var for sign in all_signs for var in sign.inputs],
                outputs=[var for sign in all_signs for var in sign.outputs],
                feed_names=[
                    alias for sign in all_signs for alias in sign.feed_names
                ],
                fetch_names=[
                    alias for sign in all_signs for alias in sign.fetch_names
                ],
                for_predict=False)

        program = self.program.clone(for_test=for_test)
        paddle_helper.remove_feed_fetch_op(program)

        if not for_test:
            paddle_helper.set_parameter_trainable(program, trainable)

            paddle_helper.set_parameter_learning_rate(program, learning_rate)

            paddle_helper.set_parameter_regularizer(program, regularizer)

            self._restore_parameter(program)

        self._recover_variable_info(program)

        paddle_helper.set_op_attr(program, is_test=for_test)
        global_block = program.global_block()
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            feed_dict[index] = global_block.var(var.name)
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = global_block.var(var.name)

        for index, var in enumerate(signature.outputs):
            fetch_dict[index] = global_block.var(var.name)
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = global_block.var(var.name)

        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
        if "bert" in self.name or self.name.startswith("ernie"):
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
                        max_seq_len, MAX_SEQ_LENGTH))
            logger.info(
                "Set maximum sequence length of input tensor to {}".format(
                    max_seq_len))
            if self.name.startswith("ernie_v2"):
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask",
                    "task_ids"
                ]
            else:
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask"
                ]
            for tensor_name in feed_list:
                seq_tensor_shape = [-1, max_seq_len, 1]
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                program.global_block().var(
                    feed_dict[tensor_name].name).desc.set_shape(
                        seq_tensor_shape)

        # record num parameters loaded by paddlehub
        num_param_loaded = sum(
            1 for _ in program.global_block().iter_parameters())
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)

        return feed_dict, fetch_dict, program

    def get_name_prefix(self):
        """Return the name prefix recovered from the description."""
        return self._name_prefix

    def get_var_name_with_prefix(self, var_name):
        """Return ``var_name`` with the module's name prefix prepended."""
        return self.get_name_prefix() + var_name

    def _check_signatures(self):
        """Validate ``self.signatures``.

        Raises:
            ValueError: if there are no signatures, or a signature variable
                belongs to a program other than ``self.program``.
            TypeError: if an entry is not a ``Signature``.
        """
        if not self.signatures:
            raise ValueError("Signatures should not be None")

        for sign in self.signatures.values():
            if not isinstance(sign, Signature):
                raise TypeError(
                    "Item in Signatures shoule be an instance of paddlehub.Signature"
                )

            # Inputs are checked before outputs, as before.
            for variable in list(sign.inputs) + list(sign.outputs):
                _tmp_program = variable.block.program
                if not self.program == _tmp_program:
                    raise ValueError(
                        "All input and outputs variables in signature should come from the same Program"
                    )