module.py 21.2 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
W
wuzewu 已提交
20
from paddle_hub.tools.downloader import default_downloader
W
wuzewu 已提交
21 22
from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2
W
wuzewu 已提交
23
from paddle_hub.module import check_info_pb2
W
wuzewu 已提交
24
from paddle_hub.module.signature import Signature, create_signature
W
wuzewu 已提交
25
from paddle_hub.module.checker import ModuleChecker
Z
Zeyu Chen 已提交
26
from paddle_hub.io.reader import yaml_reader
W
wuzewu 已提交
27
from paddle_hub import version
W
wuzewu 已提交
28
from paddle_hub.module.base_processor import BaseProcessor
W
wuzewu 已提交
29
from shutil import copyfile
W
wuzewu 已提交
30
import os
31
import time
W
wuzewu 已提交
32
import sys
W
wuzewu 已提交
33
import functools
W
wuzewu 已提交
34 35 36 37 38 39
import paddle
import paddle.fluid as fluid

__all__ = ['Module', 'create_module']


def create_module(sign_arr,
                  module_dir,
                  processor=None,
                  assets=None,
                  module_info=None,
                  exe=None):
    """Build a Module from signatures and serialize it to ``module_dir``.

    Args:
        sign_arr: a Signature or a list of Signatures (normalized with
            ``utils.to_list``).
        module_dir: directory the module is written to.
        processor: optional BaseProcessor subclass bundled with the module.
        assets: optional asset path or list of paths copied into the module.
        module_info: optional path to a yaml file with module metadata.
        exe: optional fluid Executor used to save the inference model; a CPU
            executor is created by ``Module.serialize_to_path`` when omitted.
    """
    sign_arr = utils.to_list(sign_arr)
    module = Module(
        signatures=sign_arr,
        processor=processor,
        assets=assets,
        module_info=module_info)
    module.serialize_to_path(path=module_dir, exe=exe)


# Names of the files / sub-directories inside a serialized module directory.
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# Template of the prefix prepended to every variable name of a module;
# ``%s`` is replaced by the module name (see Module.get_name_prefix).
HUB_VAR_PREFIX = "@HUB_%s@"


class ModuleHelper:
    """Resolve the paths of the standard entries inside a module directory."""

    def __init__(self, module_dir):
        self.module_dir = module_dir

    def module_desc_path(self):
        """Path of the serialized ModuleDesc protobuf file."""
        return os.path.join(self.module_dir, MODULE_DESC_PBNAME)

    def model_path(self):
        """Directory holding the saved fluid inference model."""
        return os.path.join(self.module_dir, MODEL_DIRNAME)

    def processor_path(self):
        """Directory holding the dumped processor source file."""
        return os.path.join(self.module_dir, PYTHON_DIR)

    def processor_name(self):
        """Return the fixed processor name constant ``PROCESSOR_NAME``."""
        return PROCESSOR_NAME

    def assets_path(self):
        """Directory holding the module's asset files."""
        return os.path.join(self.module_dir, ASSETS_DIRNAME)


class Module:
    """A PaddleHub module: a fluid inference program plus its signatures,
    module metadata, asset files and an optional processor class.

    The constructor uses exactly one initialization source, checked in this
    order: ``url`` (download, uncompress, then load), ``module_dir`` (load a
    serialized module from disk) or ``signatures`` (build a new module in
    memory).
    """

    def __init__(self,
                 url=None,
                 module_dir=None,
                 signatures=None,
                 module_info=None,
                 assets=None,
                 processor=None):
        """Create a module from one source.

        Args:
            url: archive URL; downloaded and uncompressed into ``.``.
            module_dir: directory of a serialized module.
            signatures: list of Signature objects for a new module.
            module_info: optional yaml file with module metadata (only used
                together with ``signatures``).
            assets: optional asset path or list of paths (only used together
                with ``signatures``).
            processor: optional BaseProcessor subclass (only used together
                with ``signatures``).

        Raises:
            ValueError: if none of ``url``, ``module_dir`` or ``signatures``
                is given.
        """
        self.desc = module_desc_pb2.ModuleDesc()
        self.program = None
        self.assets = []
        self.helper = None
        self.signatures = {}
        # NOTE(review): this instance attribute shadows the
        # ``default_signature`` method defined on the class, so on an
        # instance ``default_signature`` is always this value (None).
        self.default_signature = None
        self.module_info = None
        self.processor = None
        self.name = "temp"
        if url:
            self._init_with_url(url=url)
        elif module_dir:
            self._init_with_module_file(module_dir=module_dir)
        elif signatures:
            if processor:
                assert issubclass(
                    processor, BaseProcessor
                ), "processor should be sub class of hub.BaseProcessor"
            if assets:
                self.assets = utils.to_list(assets)
                # Validate every entry of the normalized list (iterating the
                # raw argument would walk a single path string char by char).
                for asset in self.assets:
                    utils.check_path(asset)
            self.processor = processor
            self._generate_module_info(module_info)
            self._init_with_signature(signatures=signatures)
        else:
            raise ValueError("Error! HubModule Can't init with nothing")

    def _init_with_url(self, url):
        """Download and uncompress ``url`` into the current directory, then
        load the module from the extracted directory."""
        utils.check_url(url)
        _, _, module_dir = default_downloader.download_file_and_uncompress(
            url, save_path=".")
        self._init_with_module_file(module_dir)

    def _dump_processor(self):
        """Copy the source of the processor's Python module into the
        processor directory and record the generated file name in the desc.

        The file name is ``utils.md5`` of (``utils.md5`` of the source plus the
        current time), presumably so that each dump gets a fresh name.
        """
        import inspect
        pymodule = inspect.getmodule(self.processor)
        pycode = inspect.getsource(pymodule)
        processor_path = self.helper.processor_path()
        processor_md5 = utils.md5(pycode) + str(time.time())
        processor_name = utils.md5(processor_md5)
        output_file = os.path.join(processor_path, processor_name + ".py")
        utils.mkdir(processor_path)
        with open(output_file, "w") as fout:
            fout.write(pycode)
        utils.from_pyobj_to_flexible_data(
            processor_name, self.desc.extra_info.map.data['processor_info'])

    def _load_processor(self):
        """Import the dumped processor file and instantiate its ``Processor``
        class bound to this module; set ``self.processor`` to None when the
        processor directory does not exist."""
        processor_path = self.helper.processor_path()
        if not os.path.exists(processor_path):
            self.processor = None
            return
        # Make the dumped file importable by its generated module name.
        sys.path.append(processor_path)
        processor_name = utils.from_flexible_data_to_pyobj(
            self.desc.extra_info.map.data['processor_info'])
        self.processor = __import__(processor_name).Processor(module=self)

    def _dump_assets(self):
        """Copy every file in ``self.assets`` into the module's assets dir."""
        assets_path = self.helper.assets_path()
        utils.mkdir(assets_path)
        for asset in self.assets:
            newfile = os.path.join(assets_path, os.path.basename(asset))
            copyfile(asset, newfile)

    def _load_assets(self):
        """Set ``self.assets`` to the paths of the files in the assets dir
        (an empty list when that directory does not exist)."""
        assets_path = self.helper.assets_path()
        self.assets = []
        if not os.path.isdir(assets_path):
            return
        for filename in os.listdir(assets_path):
            self.assets.append(os.path.join(assets_path, filename))

    def _init_with_module_file(self, module_dir):
        """Load a serialized module from ``module_dir``.

        Exits the process with status 1 if ModuleChecker rejects the
        directory.
        """
        checker = ModuleChecker(module_dir)
        if not checker.check():
            logger.error("module check fail")
            sys.exit(1)

        self.helper = ModuleHelper(module_dir)
        with open(self.helper.module_desc_path(), "rb") as fi:
            self.desc.ParseFromString(fi.read())

        exe = fluid.Executor(fluid.CPUPlace())
        self.program, _, _ = fluid.io.load_inference_model(
            self.helper.model_path(), executor=exe)
        self._recovery_parameter(self.program)
        self._recover_variable_info(self.program)
        self._load_processor()
        self._load_assets()
        self._recover_from_desc()
        self._generate_sign_attr()

    def _init_with_signature(self, signatures):
        """Build a new module in memory from a list of signatures."""
        self._process_signatures(signatures)
        self._check_signatures()
        self._generate_desc()
        self._generate_sign_attr()

    def _init_with_program(self, program):
        pass

    def _process_signatures(self, signatures):
        """Index ``signatures`` by name and take the module program from the
        first input variable of the first signature.

        Raises:
            ValueError: if two signatures share the same name.
        """
        self.signatures = {}
        self.program = signatures[0].inputs[0].block.program
        for sign in signatures:
            if sign.name in self.signatures:
                raise ValueError(
                    "Error! signature array contains repeat signature %s" %
                    sign.name)
            self.signatures[sign.name] = sign

    def _recovery_parameter(self, program):
        """Re-create fluid Parameters from the attributes saved in the desc.

        Saved attributes are keyed by unprefixed variable name; entries whose
        prefixed name is not in the global block are skipped.

        NOTE(review): the ``program`` argument is not used - parameters are
        always re-created in ``self.program``'s global block, including when
        ``context()`` passes in a cloned program. Confirm whether this is
        intended before changing it.
        """
        global_block = self.program.global_block()
        param_attrs = self.desc.extra_info.map.data['param_attrs']
        for key, param_attr in param_attrs.map.data.items():
            param = paddle_helper.from_flexible_data_to_param(param_attr)
            param['name'] = self.get_var_name_with_prefix(key)
            if (param['name'] not in global_block.vars):
                continue
            var = global_block.var(param['name'])
            global_block.create_parameter(
                **param,
                shape=var.shape,
                dtype=var.dtype,
                type=var.type,
                lod_level=var.lod_level,
                error_clip=var.error_clip,
                stop_gradient=var.stop_gradient,
                is_data=var.is_data)

    def _recover_variable_info(self, program):
        """Restore the saved ``stop_gradient`` flag of every variable of
        ``program`` that still exists in its recorded block."""
        var_infos = self.desc.extra_info.map.data['var_infos']
        for raw_name, var_info in var_infos.map.data.items():
            idx = utils.from_flexible_data_to_pyobj(
                var_info.map.data['block_id'])
            stop_gradient = utils.from_flexible_data_to_pyobj(
                var_info.map.data['stop_gradient'])
            block = program.blocks[idx]
            var_name = self.get_var_name_with_prefix(raw_name)
            if var_name in block.vars:
                block.vars[var_name].stop_gradient = stop_gradient

    def _generate_module_info(self, module_info=None):
        """Load module metadata from a yaml file (or use an empty dict) and
        expose author/author_email/summary/type/version/name as attributes,
        each defaulting to 'UNKNOWN'.

        Exits the process with status 1 if ``module_info`` is not a yaml
        file.
        """
        if not module_info:
            self.module_info = {}
        else:
            if not utils.is_yaml_file(module_info):
                logger.critical("module info file should in yaml format")
                sys.exit(1)
            self.module_info = yaml_reader.read(module_info)
        self.author = self.module_info.get('author', 'UNKNOWN')
        self.author_email = self.module_info.get('author_email', 'UNKNOWN')
        self.summary = self.module_info.get('summary', 'UNKNOWN')
        self.type = self.module_info.get('type', 'UNKNOWN')
        self.version = self.module_info.get('version', 'UNKNOWN')
        self.name = self.module_info.get('name', 'UNKNOWN')

    def _generate_sign_attr(self):
        """Expose each signature as an instance attribute that calls
        ``self(sign_name=<signature name>, ...)``."""
        self._check_signatures()
        for sign in self.signatures:
            self.__dict__[sign] = functools.partial(
                self.__call__, sign_name=sign)

    def _recover_from_desc(self):
        """Rebuild the signatures and module metadata from ``self.desc``."""
        # recover signature
        global_block = self.program.global_block()
        for sign, module_var in self.desc.sign2var.items():
            inputs = [
                global_block.vars[var.var_name] for var in module_var.feed_desc
            ]
            feed_names = [var.alias for var in module_var.feed_desc]
            outputs = [
                global_block.vars[var.var_name]
                for var in module_var.fetch_desc
            ]
            fetch_names = [var.alias for var in module_var.fetch_desc]
            self.signatures[sign] = create_signature(
                sign,
                inputs=inputs,
                outputs=outputs,
                feed_names=feed_names,
                fetch_names=fetch_names)

        # recover module info
        module_info = self.desc.extra_info.map.data['module_info']
        for field in ('name', 'author', 'author_email', 'version', 'type',
                      'summary'):
            setattr(self, field,
                    utils.from_flexible_data_to_pyobj(
                        module_info.map.data[field]))

    def _generate_desc(self):
        """Fill ``self.desc`` from the program, signatures and metadata.

        ``serialize_to_path`` calls this again after construction, so the
        per-signature feed/fetch lists are cleared before being rebuilt to
        avoid appending duplicate entries.
        """
        # save fluid Parameter
        extra_info = self.desc.extra_info
        extra_info.type = module_desc_pb2.MAP
        param_attrs = extra_info.map.data['param_attrs']
        param_attrs.type = module_desc_pb2.MAP
        for param in self.program.global_block().iter_parameters():
            paddle_helper.from_param_to_flexible_data(
                param, param_attrs.map.data[param.name])

        # save Variable Info (stop_gradient flag and owning block index)
        var_infos = extra_info.map.data['var_infos']
        var_infos.type = module_desc_pb2.MAP
        for block in self.program.blocks:
            for var in block.vars.values():
                var_info = var_infos.map.data[var.name]
                var_info.type = module_desc_pb2.MAP
                utils.from_pyobj_to_flexible_data(
                    var.stop_gradient, var_info.map.data['stop_gradient'])
                utils.from_pyobj_to_flexible_data(
                    block.idx, var_info.map.data['block_id'])

        # save signature info
        for sign in self.signatures.values():
            module_var = self.desc.sign2var[sign.name]
            del module_var.feed_desc[:]
            del module_var.fetch_desc[:]
            for index, input_var in enumerate(sign.inputs):
                feed_var = module_var.feed_desc.add()
                feed_var.var_name = self.get_var_name_with_prefix(
                    input_var.name)
                feed_var.alias = sign.feed_names[index]

            for index, output_var in enumerate(sign.outputs):
                fetch_var = module_var.fetch_desc.add()
                fetch_var.var_name = self.get_var_name_with_prefix(
                    output_var.name)
                fetch_var.alias = sign.fetch_names[index]

        # save module info
        module_info = extra_info.map.data['module_info']
        module_info.type = module_desc_pb2.MAP
        for field in ('name', 'version', 'author', 'author_email', 'type',
                      'summary'):
            utils.from_pyobj_to_flexible_data(
                getattr(self, field), module_info.map.data[field])

    def __call__(self, sign_name, data, **kwargs):
        """Run signature ``sign_name`` on ``data`` through the processor.

        ``data`` goes through ``processor.preprocess``; the processed values
        are fed on CPU in batches of 2, and each batch's outputs are passed
        to ``processor.postprocess`` (together with ``kwargs``). Returns the
        concatenation of all postprocess results.
        """
        self.check_processor()

        def _get_reader_and_feeder(data_format, data, place):
            process_data = []
            feed_name_list = []
            for key in data_format:
                process_data.append([value['processed'] for value in data[key]])
                feed_name_list.append(data_format[key]['feed_key'])

            def _reader():
                for item in zip(*process_data):
                    yield item

            feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
            return _reader, feeder

        feed_dict, fetch_dict, program = self.context(sign_name, for_test=True)
        # TODO(wuzewu): more option
        # fetch_dict maps both positional indices and aliases to the same
        # variables; dedupe while keeping signature output order so that the
        # order of ``data_out`` is deterministic.
        fetch_list = list(dict.fromkeys(fetch_dict.values()))
        with fluid.program_guard(program):
            result = []
            index = 0
            place = fluid.CPUPlace()
            exe = fluid.Executor(place=place)
            data = self.processor.preprocess(
                sign_name=sign_name, data_dict=data)
            data_format = self.processor.data_format(sign_name=sign_name)
            reader, feeder = _get_reader_and_feeder(data_format, data, place)
            reader = paddle.batch(reader, batch_size=2)
            for batch in reader():
                data_out = exe.run(
                    feed=feeder.feed(batch),
                    fetch_list=fetch_list,
                    return_numpy=False)
                sub_data = {
                    key: value[index:index + len(batch)]
                    for key, value in data.items()
                }
                result += self.processor.postprocess(sign_name, data_out,
                                                     sub_data, **kwargs)
                index += len(batch)

        return result

    def check_processor(self):
        assert self.processor, "this module couldn't be call"

    def context(self,
                sign_name,
                for_test=False,
                trainable=False,
                regularizer=None,
                learning_rate=1e-3):
        """Return ``(feed_dict, fetch_dict, program)`` for ``sign_name``.

        ``program`` is a clone of the module program with feed/fetch ops
        removed. Unless ``for_test``, the trainable flag, learning rate and
        regularizer are applied to its parameters. ``feed_dict`` and
        ``fetch_dict`` map both the positional index and (when non-empty) the
        alias of each signature input/output to the variable in ``program``.
        """
        assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name
        signature = self.signatures[sign_name]

        program = self.program.clone(for_test=for_test)
        paddle_helper.remove_feed_fetch_op(program)

        if not for_test:
            # NOTE(review): with the defaults above these "Default" checks
            # are always true; only an explicit "Default" skips a setting.
            if trainable != "Default":
                paddle_helper.set_parameter_trainable(program, trainable)

            if learning_rate != "Default":
                paddle_helper.set_parameter_learning_rate(
                    program, learning_rate)

            if regularizer != "Default":
                paddle_helper.set_parameter_regularizer(program, regularizer)

            self._recovery_parameter(program)

        self._recover_variable_info(program)

        paddle_helper.set_op_attr(program, is_test=for_test)
        # TODO(wuzewu): return feed_list and fetch_list directly
        global_block = program.global_block()
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            feed_dict[index] = global_block.var(var.name)
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = global_block.var(var.name)

        for index, var in enumerate(signature.outputs):
            fetch_dict[index] = global_block.var(var.name)
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = global_block.var(var.name)

        for param in self.program.global_block().iter_parameters():
            logger.debug("%s %s" % (param.name, param.optimize_attr))

        return feed_dict, fetch_dict, program

    def get_name_prefix(self):
        """Return the variable-name prefix for this module's name."""
        return HUB_VAR_PREFIX % self.name

    def get_var_name_with_prefix(self, var_name):
        """Return ``var_name`` with this module's prefix prepended."""
        return self.get_name_prefix() + var_name

    def parameters(self):
        pass

    def parameter_attrs(self):
        pass

    def default_signature(self):
        # NOTE(review): shadowed on every instance by the
        # ``self.default_signature`` attribute set in __init__.
        return self.default_signature

    def _check_signatures(self):
        """Assert that there is at least one signature, that each is a
        Signature, and that all their variables belong to ``self.program``."""
        assert self.signatures, "signature array should not be None"

        for key, sign in self.signatures.items():
            assert isinstance(sign,
                              Signature), "sign_arr should be list of Signature"

            for input in sign.inputs:
                _tmp_program = input.block.program
                assert self.program == _tmp_program, "all the variable should come from the same program"

            for output in sign.outputs:
                _tmp_program = output.block.program
                assert self.program == _tmp_program, "all the variable should come from the same program"

    def serialize_to_path(self, path=None, exe=None):
        """Write the module to ``path`` (default ``./<self.name>``).

        Saves the inference model, prefixes every variable name and every
        parameter file with the module prefix, dumps the processor source
        and the assets, generates check info and finally writes the desc.
        """
        self._check_signatures()
        self._generate_desc()
        # create module path for saving
        if path is None:
            path = os.path.join(".", self.name)
        self.helper = ModuleHelper(path)
        utils.mkdir(self.helper.module_dir)

        logger.info("hub version is %s" % version.hub_version)
        logger.info("module proto version is %s" % version.module_proto_version)
        logger.info("paddle version is %s" % paddle.__version__)

        # Dedupe while keeping signature order so the saved model's
        # feed/fetch order is deterministic.
        feeded_var_names = list(
            dict.fromkeys(input_var.name
                          for sign in self.signatures.values()
                          for input_var in sign.inputs))
        target_vars = list(
            dict.fromkeys(output_var
                          for sign in self.signatures.values()
                          for output_var in sign.outputs))

        # save inference program
        program = self.program.clone()
        if not exe:
            exe = fluid.Executor(place=fluid.CPUPlace())
        model_path = self.helper.model_path()
        utils.mkdir(model_path)
        fluid.io.save_inference_model(
            model_path,
            feeded_var_names=feeded_var_names,
            target_vars=target_vars,
            main_program=program,
            executor=exe)

        # Rename every not-yet-prefixed variable in the saved program.
        prefix = self.get_name_prefix()
        model_file = os.path.join(model_path, "__model__")
        with open(model_file, "rb") as fin:
            program_desc_str = fin.read()
        rename_program = fluid.framework.Program.parse_from_string(
            program_desc_str)
        varlist = {
            var: block
            for block in rename_program.blocks for var in block.vars
            if prefix not in var
        }
        for old_name, block in varlist.items():
            block._rename_var(old_name, self.get_var_name_with_prefix(old_name))
        with open(model_file, "wb") as fout:
            fout.write(rename_program.desc.serialize_to_string())

        # Rename the saved parameter files to match the prefixed names.
        for filename in os.listdir(model_path):
            if (filename == "__model__" or prefix in filename):
                continue
            os.rename(
                os.path.join(model_path, filename),
                os.path.join(model_path,
                             self.get_var_name_with_prefix(filename)))

        # create processor file
        if self.processor:
            self._dump_processor()

        # create assets
        self._dump_assets()

        # create check info
        checker = ModuleChecker(self.helper.module_dir)
        checker.generate_check_info()

        # Serialize module_desc pb
        module_pb = self.desc.SerializeToString()
        with open(self.helper.module_desc_path(), "wb") as f:
            f.write(module_pb)