From b33e4d14f8b210727623dda53cbcdf836f7a0700 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Mon, 14 Sep 2020 15:20:56 +0800 Subject: [PATCH] Fix ModuleV1 infer issue --- paddlehub/__init__.py | 1 + paddlehub/compat/module/module_v1.py | 101 +++++++++++++++------ paddlehub/compat/module/module_v1_utils.py | 8 +- paddlehub/compat/type.py | 23 +++++ 4 files changed, 102 insertions(+), 31 deletions(-) create mode 100644 paddlehub/compat/type.py diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index 43037cd9..ec027d27 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -18,3 +18,4 @@ __version__ = '2.0.0a0' from paddlehub.module import Module from paddlehub.compat.module.processor import BaseProcessor +from paddlehub.compat.type import DataType diff --git a/paddlehub/compat/module/module_v1.py b/paddlehub/compat/module/module_v1.py index 9868ccfc..d3f40ac6 100644 --- a/paddlehub/compat/module/module_v1.py +++ b/paddlehub/compat/module/module_v1.py @@ -18,6 +18,7 @@ import os from typing import Tuple, List import paddle +from easydict import EasyDict from paddlehub.compat import paddle_utils from paddlehub.compat.module import module_v1_utils @@ -27,57 +28,70 @@ from paddlehub.utils import utils, log class ModuleV1(object): ''' ''' + def __init__(self, name: str = None, directory: str = None, version: str = None): if not directory: return - self.directory = directory desc_file = os.path.join(directory, 'module_desc.pb') self.desc = module_v1_utils.convert_module_desc(desc_file) + self.helper = self + self.signatures = self.desc.signatures + + self.directory = directory self._load_model() self._load_parameters() self._load_processor() self._load_assets() self._load_extra_info() - self._load_signatures() + self._generate_func() def _load_processor(self): python_path = os.path.join(self.directory, 'python') processor_name = self.desc.processor_info self.processor = utils.load_py_module(python_path, processor_name) + self.processor = self.processor.Processor(module=self) def _load_assets(self): - assets_path = os.path.join(self.directory, 'assets') self.assets = [] - for file in os.listdir(assets_path): - filepath = os.path.join(assets_path, file) + for file in os.listdir(self.assets_path()): + filepath = os.path.join(self.assets_path(), file) self.assets.append(filepath) def _load_parameters(self): global_block = self.program.global_block() + + # record num parameters loaded by PaddleHub + num_param_loaded = 0 + for param, attrs in self.desc.param_attrs.items(): name = self.desc.name_prefix + param if not name in global_block.vars: continue + num_param_loaded += 1 var = global_block.vars[name] - global_block.create_parameter(name=name, - shape=var.shape, - dtype=var.dtype, - type=var.type, - lod_level=var.lod_level, - error_clip=var.error_clip, - stop_gradient=var.stop_gradient, - is_data=var.is_data, - **attrs) + + global_block.create_parameter( + name=name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + error_clip=var.error_clip, + stop_gradient=var.stop_gradient, + is_data=var.is_data, + **attrs) + + log.logger.info('{} pretrained paramaters loaded by PaddleHub'.format(num_param_loaded)) def _load_extra_info(self): for key, value in self.desc.extra_info.items(): self.__dict__['get_{}'.format(key)] = value - def _load_signatures(self): + def _generate_func(self): for signature in self.desc.signatures: - self.__dict__[signature] = functools.partial(self.__call__, signature=signature) + self.__dict__[signature] = functools.partial(self.__call__, sign_name=signature) def _load_model(self): model_path = os.path.join(self.directory, 'model') @@ -91,7 +105,8 @@ class ModuleV1(object): continue op._set_attr('op_callstack', ['']) - def context(self, for_test: bool = False, trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]: + def context(self, signature: str = None, for_test: bool = False, + trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]: ''' ''' program = self.program.clone(for_test=for_test) @@ -100,29 +115,58 @@ class ModuleV1(object): # generate feed vars and fetch vars from signatures feed_dict = {} fetch_dict = {} - for info in self.desc.signatures.values(): - for feed_var in info.feed_vars: + varinfos = [self.desc.signatures[signature]] if signature else self.desc.signatures.values() + + for info in varinfos: + for feed_var in info.inputs: paddle_var = program.global_block().vars[feed_var.name] feed_dict[feed_var.alias] = paddle_var - for fetch_var in info.fetch_vars: + for fetch_var in info.outputs: paddle_var = program.global_block().vars[fetch_var.name] fetch_dict[fetch_var.alias] = paddle_var - # record num parameters loaded by PaddleHub - num_param_loaded = 0 for param in program.all_parameters(): - num_param_loaded += 1 param.trainable = trainable - log.logger.info('{} pretrained paramaters loaded by PaddleHub'.format(num_param_loaded)) - return feed_dict, fetch_dict, program - def __call__(self, signature, data, use_gpu: bool = False, batch_size: int = 1, **kwargs): + def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs): ''' ''' - ... + + def _get_reader_and_feeder(data_format, data, place): + def _reader(process_data): + for item in zip(*process_data): + yield item + + process_data = [] + feed_name_list = [] + for key in data_format: + process_data.append([value['processed'] for value in data[key]]) + feed_name_list.append(data_format[key]['feed_key']) + feeder = paddle.fluid.DataFeeder(feed_list=feed_name_list, place=place) + return functools.partial(_reader, process_data=process_data), feeder + + _, fetch_dict, program = self.context(signature=sign_name, for_test=True) + fetch_list = list([value for key, value in fetch_dict.items()]) + with paddle.static.program_guard(program): + result = [] + index = 0 + place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() + + exe = paddle.static.Executor(place=place) + data = self.processor.preprocess(sign_name=sign_name, data_dict=data) + data_format = self.processor.data_format(sign_name=sign_name) + reader, feeder = _get_reader_and_feeder(data_format, data, place) + reader = paddle.batch(reader, batch_size=batch_size) + for batch in reader(): + data_out = exe.run(feed=feeder.feed(batch), fetch_list=fetch_list, return_numpy=False) + sub_data = {key: value[index:index + len(batch)] for key, value in data.items()} + result += self.processor.postprocess(sign_name, data_out, sub_data, **kwargs) + index += len(batch) + + return result @classmethod def get_py_requirements(cls) -> List[str]: @@ -139,3 +183,6 @@ class ModuleV1(object): cls.name = desc.module_info.name cls.version = utils.Version(desc.module_info.version) return cls + + def assets_path(self): + return os.path.join(self.directory, 'assets') diff --git a/paddlehub/compat/module/module_v1_utils.py b/paddlehub/compat/module/module_v1_utils.py index 8b09aa71..76f7f646 100644 --- a/paddlehub/compat/module/module_v1_utils.py +++ b/paddlehub/compat/module/module_v1_utils.py @@ -33,12 +33,12 @@ def convert_signatures(signmaps): for sign, var in signmaps.items(): _dict[sign] = EasyDict() for fetch_var in var.fetch_desc: - _dict[sign].fetch_vars = list() - _dict[sign].fetch_vars.append(EasyDict(name=fetch_var.var_name, alias=fetch_var.alias)) + _dict[sign].outputs = list() + _dict[sign].outputs.append(EasyDict(name=fetch_var.var_name, alias=fetch_var.alias)) for feed_var in var.feed_desc: - _dict[sign].feed_vars = list() - _dict[sign].feed_vars.append(EasyDict(name=feed_var.var_name, alias=feed_var.alias)) + _dict[sign].inputs = list() + _dict[sign].inputs.append(EasyDict(name=feed_var.var_name, alias=feed_var.alias)) return _dict diff --git a/paddlehub/compat/type.py b/paddlehub/compat/type.py new file mode 100644 index 00000000..3b723824 --- /dev/null +++ b/paddlehub/compat/type.py @@ -0,0 +1,23 @@ +#coding:utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class DataType(object): + IMAGE = 0 + TEXT = 1 + AUDIO = 2 + VIDEO = 3 + INT = 4 + FLOAT = 5 -- GitLab