load_model.py 9.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import logging
import os
import shutil
import sys
G
Guanghua Yu 已提交
20
import pkg_resources as pkg
21 22 23 24 25
import paddle

from . import get_logger
_logger = get_logger(__name__, level=logging.INFO)

G
Guanghua Yu 已提交
26 27 28
__all__ = [
    'load_inference_model', 'get_model_dir', 'load_onnx_model', 'export_onnx'
]
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99


def load_inference_model(path_prefix,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """Load an inference model and return its program, feeds and fetches.

    Args:
        path_prefix: Either a path ending in ``.onnx`` (delegated to
            ``load_onnx_model``) or the directory / path prefix handed to
            ``paddle.static.load_inference_model``.
        executor: Executor forwarded to ``paddle.static.load_inference_model``.
            Not used for ``.onnx`` inputs.
        model_filename: Optional model file name.
        params_filename: Optional parameters file name.

    Returns:
        A list ``[inference_program, feed_target_names, fetch_targets]``.

    Raises:
        AssertionError: If ``model_filename`` ends with ``.pdmodel`` and the
            matching ``.pdmodel`` or ``.pdiparams`` file is missing under
            ``path_prefix``.
    """
    # Load onnx model to Inference model.
    if path_prefix.endswith('.onnx'):
        program, feed_names, fetch_vars = load_onnx_model(path_prefix)
        return [program, feed_names, fetch_vars]

    # Load Inference model.
    # TODO: clean code
    has_model_name = model_filename is not None
    if has_model_name and model_filename.endswith('.pdmodel'):
        # Both files are looked up by the stem of ``model_filename``;
        # ``params_filename`` is not consulted in this branch.
        stem = '.'.join(model_filename.split('.')[:-1])
        pdmodel_file = os.path.join(path_prefix, stem + '.pdmodel')
        pdiparams_file = os.path.join(path_prefix, stem + '.pdiparams')
        assert os.path.exists(pdmodel_file), (
            'Please check {}, or fix model_filename parameter.'.format(
                pdmodel_file))
        assert os.path.exists(pdiparams_file), (
            'Please check {}, or fix params_filename parameter.'.format(
                pdiparams_file))
        load_kwargs = {'path_prefix': os.path.join(path_prefix, stem)}
    elif has_model_name and params_filename is not None:
        # Explicit file names: let paddle resolve them under ``path_prefix``.
        load_kwargs = {
            'path_prefix': path_prefix,
            'model_filename': model_filename,
            'params_filename': params_filename,
        }
    else:
        if has_model_name:
            stem = '.'.join(model_filename.split('.')[:-1])
        else:
            stem = 'model'
        # Prefer ``<path_prefix>/<stem>`` when that .pdmodel file exists,
        # otherwise treat ``path_prefix`` itself as the prefix.
        if os.path.exists(os.path.join(path_prefix, stem + '.pdmodel')):
            load_kwargs = {'path_prefix': os.path.join(path_prefix, stem)}
        else:
            load_kwargs = {'path_prefix': path_prefix}

    program, feed_names, fetch_vars = paddle.static.load_inference_model(
        executor=executor, **load_kwargs)
    return [program, feed_names, fetch_vars]


def get_model_dir(model_dir, model_filename, params_filename):
    """Normalize a model location into directory, model file and params file.

    Args:
        model_dir: Model directory, or a path ending in ``.onnx``. For an
            ONNX path the returned directory is the path with the ``.onnx``
            suffix replaced by ``_infer``; otherwise trailing ``/``
            characters are stripped.
        model_filename: Model file name, or ``None`` for ``model.pdmodel``.
        params_filename: Params file name, or ``None`` for
            ``model.pdiparams``.

    Returns:
        Tuple ``(model_dir, model_filename, params_filename)`` with the
        defaults filled in.

    Raises:
        NotImplementedError: If ``model_filename`` is given but
            ``params_filename`` is ``None``.
    """
    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first."
        )

    onnx_suffix = '.onnx'
    if model_dir.endswith(onnx_suffix):
        # Slice the suffix off. ``str.rstrip('.onnx')`` removes any trailing
        # run of the characters '.', 'o', 'n', 'x', which mangles names such
        # as ``yolox.onnx`` (-> ``yol``).
        updated_model_dir = model_dir[:-len(onnx_suffix)] + '_infer'
    else:
        updated_model_dir = model_dir.rstrip('/')

    updated_model_filename = (
        'model.pdmodel' if model_filename is None else model_filename)
    updated_params_filename = (
        'model.pdiparams' if params_filename is None else params_filename)
    return updated_model_dir, updated_model_filename, updated_params_filename
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117


def load_onnx_model(model_path, disable_feedback=False):
    """Convert an ONNX model to a Paddle inference model and load it.

    The converted model is stored next to the ONNX file in a directory named
    after the file with ``.onnx`` replaced by ``_infer``. If that directory
    already holds ``model.pdmodel`` and ``model.pdiparams`` they are loaded
    directly and no conversion happens.

    Args:
        model_path: Path to the ONNX file; must end with ``.onnx``.
        disable_feedback: When False, x2paddle's ``ConverterCheck`` is
            started with the "Start" and "Success" states around the
            conversion (presumably usage reporting -- confirm in x2paddle).

    Returns:
        Tuple ``(program, feed_target_names, fetch_targets)`` as returned by
        ``paddle.static.load_inference_model``.

    Raises:
        AssertionError: If ``model_path`` lacks the ``.onnx`` suffix, or no
            converted model is cached and ``model_path`` does not exist.

    Side effects:
        May install ``x2paddle`` / ``onnx`` through pip, switches paddle to
        static mode, writes the converted model to disk, and calls
        ``sys.exit(1)`` if the x2paddle conversion raises.
    """
    onnx_suffix = '.onnx'
    assert model_path.endswith(
        onnx_suffix
    ), '{} does not end with .onnx suffix and cannot be loaded.'.format(
        model_path)
    # Slice the suffix off. ``str.rstrip('.onnx')`` strips a character set,
    # not a suffix, and would turn e.g. ``yolox.onnx`` into ``yol``.
    inference_model_path = model_path[:-len(onnx_suffix)] + '_infer'
    exe = paddle.static.Executor(paddle.CPUPlace())

    cached_model = os.path.join(inference_model_path, 'model.pdmodel')
    cached_params = os.path.join(inference_model_path, 'model.pdiparams')
    if os.path.exists(cached_model) and os.path.exists(cached_params):
        val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
            os.path.join(inference_model_path, 'model'), exe)
        _logger.info('Loaded model from: {}'.format(inference_model_path))
        return val_program, feed_target_names, fetch_targets

    # onnx to paddle inference model.
    assert os.path.exists(
        model_path), 'Not found `{}`, please check model path.'.format(
            model_path)

    # NOTE(review): ``pip._internal`` is not a supported programmatic API;
    # kept as in the original, consider ``python -m pip`` via subprocess.
    try:
        pkg.require('x2paddle')
    except Exception:
        # ``except Exception`` (not bare) so Ctrl-C / SystemExit propagate.
        from pip._internal import main
        main(['install', 'x2paddle'])

    # check onnx installation and version
    try:
        pkg.require('onnx')
        import onnx
        # Compare (major, minor) as a tuple. The previous weighted sum
        # ``v0*100 + v1*10 + v2`` misordered versions such as 1.5.10 and
        # failed on patch components like ``0rc1``.
        major, minor = (int(part)
                        for part in onnx.version.version.split('.')[:2])
        if (major, minor) < (1, 6):
            _logger.error(
                "onnx>=1.6.0 is required, please use \"pip install onnx\".")
    except Exception:
        from pip._internal import main
        main(['install', 'onnx==1.12.0'])

    from x2paddle.decoder.onnx_decoder import ONNXDecoder
    from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
    from x2paddle.optimizer.optimizer import GraphOptimizer
    from x2paddle.utils import ConverterCheck
    time_info = int(time.time())
    if not disable_feedback:
        ConverterCheck(
            task="ONNX", time_info=time_info, convert_state="Start").start()

    # support distributed convert model: each rank converts into its own
    # scratch directory so concurrent ranks do not clobber each other.
    model_idx = paddle.distributed.get_rank(
    ) if paddle.distributed.get_world_size() > 1 else 0
    onnx2paddle_out_dir = os.path.join(inference_model_path,
                                       'onnx2paddle_{}'.format(model_idx))
    try:
        _logger.info("Now translating model from onnx to paddle.")
        model = ONNXDecoder(model_path)
        mapper = ONNXOpMapper(model)
        mapper.paddle_graph.build()
        graph_opt = GraphOptimizer(source_frame="onnx")
        graph_opt.optimize(mapper.paddle_graph)
        _logger.info("Model optimized.")
        mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
        _logger.info("Successfully exported Paddle static graph model!")
        if not disable_feedback:
            ConverterCheck(
                task="ONNX", time_info=time_info,
                convert_state="Success").start()
    except Exception as e:
        _logger.warning(e)
        _logger.error(
            "x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
        )
        sys.exit(1)

    generated_dir = os.path.join(onnx2paddle_out_dir, 'inference_model')
    if paddle.distributed.get_rank() == 0:
        # Only rank 0 publishes the converted files to the shared location.
        shutil.move(
            os.path.join(generated_dir, 'model.pdmodel'),
            os.path.join(inference_model_path, 'model.pdmodel'))
        shutil.move(
            os.path.join(generated_dir, 'model.pdiparams'),
            os.path.join(inference_model_path, 'model.pdiparams'))
        load_model_path = inference_model_path
    else:
        # Other ranks load from their own scratch output.
        load_model_path = generated_dir

    paddle.enable_static()
    val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
        os.path.join(load_model_path, 'model'), exe)
    _logger.info('Loaded model from: {}'.format(load_model_path))
    # Clean up the file storage directory
    shutil.rmtree(onnx2paddle_out_dir)
    return val_program, feed_target_names, fetch_targets


def export_onnx(model_dir,
                model_filename=None,
                params_filename=None,
                save_file_path='output.onnx',
                opset_version=13,
                deploy_backend='tensorrt'):
    """Export a Paddle inference model to ONNX using paddle2onnx.

    Args:
        model_dir: Directory containing the model and params files.
        model_filename: Model file name; falsy values fall back to
            ``model.pdmodel``.
        params_filename: Params file name; falsy values fall back to
            ``model.pdiparams``.
        save_file_path: Destination path of the ONNX file.
        opset_version: Opset version forwarded to paddle2onnx.
        deploy_backend: Deploy backend name forwarded to paddle2onnx.

    Side effects:
        Installs ``paddle2onnx==1.0.0rc3`` through pip when
        ``pkg.require('paddle2onnx')`` fails, then writes the ONNX file.
    """
    if not model_filename:
        model_filename = 'model.pdmodel'
    if not params_filename:
        params_filename = 'model.pdiparams'
    try:
        pkg.require('paddle2onnx')
    except Exception:
        # ``except Exception`` rather than a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not turned into a pip install.
        # NOTE(review): ``pip._internal`` is not a supported programmatic
        # API; kept as in the original.
        from pip._internal import main
        main(['install', 'paddle2onnx==1.0.0rc3'])
    import paddle2onnx
    paddle2onnx.command.c_paddle_to_onnx(
        model_file=os.path.join(model_dir, model_filename),
        params_file=os.path.join(model_dir, params_filename),
        save_file=save_file_path,
        opset_version=opset_version,
        enable_onnx_checker=True,
        deploy_backend=deploy_backend)
    _logger.info('Convert model to ONNX: {}'.format(save_file_path))