#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import logging
import os
import shutil
import sys

import pkg_resources as pkg
import paddle

from . import get_logger
_logger = get_logger(__name__, level=logging.INFO)

__all__ = [
    'load_inference_model', 'get_model_dir', 'load_onnx_model', 'export_onnx'
]


def load_inference_model(path_prefix,
                         executor,
                         model_filename=None,
                         params_filename=None):
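    """Load an inference model for the given executor.

    ``path_prefix`` may point to an ONNX model (a path ending in ``.onnx``,
    which is converted via ``load_onnx_model``) or to a directory holding a
    Paddle static graph model (``.pdmodel``/``.pdiparams`` files). Returns
    ``[inference_program, feed_target_names, fetch_targets]``.
    """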
    # If the path points to an ONNX model, convert and load it as a Paddle inference model.
    if path_prefix.endswith('.onnx'):
        inference_program, feed_target_names, fetch_targets = load_onnx_model(
            path_prefix)
        return [inference_program, feed_target_names, fetch_targets]
    # Load Inference model.
    # TODO: clean code
    if model_filename is not None and model_filename.endswith('.pdmodel'):
        model_name = '.'.join(model_filename.split('.')[:-1])
        assert os.path.exists(
            os.path.join(path_prefix, model_name + '.pdmodel')
        ), 'Please check {}, or fix model_filename parameter.'.format(
            os.path.join(path_prefix, model_name + '.pdmodel'))
        assert os.path.exists(
            os.path.join(path_prefix, model_name + '.pdiparams')
        ), 'Please check {}, or fix params_filename parameter.'.format(
            os.path.join(path_prefix, model_name + '.pdiparams'))
        model_path_prefix = os.path.join(path_prefix, model_name)
        [inference_program, feed_target_names, fetch_targets] = (
            paddle.static.load_inference_model(
                path_prefix=model_path_prefix, executor=executor))
    elif model_filename is not None and params_filename is not None:
        [inference_program, feed_target_names, fetch_targets] = (
            paddle.static.load_inference_model(
                path_prefix=path_prefix,
                executor=executor,
                model_filename=model_filename,
                params_filename=params_filename))
    else:
        model_name = '.'.join(model_filename.split('.')
                              [:-1]) if model_filename is not None else 'model'
        if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
            model_path_prefix = os.path.join(path_prefix, model_name)
            [inference_program, feed_target_names, fetch_targets] = (
                paddle.static.load_inference_model(
                    path_prefix=model_path_prefix, executor=executor))
        else:
            [inference_program, feed_target_names, fetch_targets] = (
                paddle.static.load_inference_model(
                    path_prefix=path_prefix, executor=executor))

    return [inference_program, feed_target_names, fetch_targets]


def get_model_dir(model_dir, model_filename, params_filename):
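    """Normalize the model directory and file names.

    For an ONNX model path, returns the ``<name>_infer`` directory that
    ``load_onnx_model`` writes the converted model into; otherwise a trailing
    slash is stripped. ``model_filename`` and ``params_filename`` default to
    ``model.pdmodel`` and ``model.pdiparams``.
    """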
    if model_dir.endswith('.onnx'):
        updated_model_dir = model_dir[:-len('.onnx')] + '_infer'
    else:
        updated_model_dir = model_dir.rstrip('/')

    if model_filename is None:
        updated_model_filename = 'model.pdmodel'
    else:
        updated_model_filename = model_filename

    if params_filename is None:
        updated_params_filename = 'model.pdiparams'
    else:
        updated_params_filename = params_filename

    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "Parameters saved in separate files are not supported. Please merge them into a single binary file first."
        )
    return updated_model_dir, updated_model_filename, updated_params_filename


def load_onnx_model(model_path, disable_feedback=False):
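    """Convert an ONNX model to a Paddle inference model and load it.

    The converted model is cached in a ``<name>_infer`` directory derived from
    ``model_path``; a cached copy is reused on later calls, otherwise the ONNX
    graph is translated with x2paddle. Set ``disable_feedback`` to True to
    skip x2paddle's ConverterCheck usage reporting.
    """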
    assert model_path.endswith(
        '.onnx'
    ), '{} does not end with .onnx suffix and cannot be loaded.'.format(
        model_path)
    inference_model_path = model_path[:-len('.onnx')] + '_infer'
    exe = paddle.static.Executor(paddle.CPUPlace())
    if os.path.exists(os.path.join(
            inference_model_path, 'model.pdmodel')) and os.path.exists(
                os.path.join(inference_model_path, 'model.pdiparams')):
        val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
            os.path.join(inference_model_path, 'model'), exe)
        _logger.info('Loaded model from: {}'.format(inference_model_path))
        return val_program, feed_target_names, fetch_targets
    else:
        # onnx to paddle inference model.
        try:
            pkg.require('x2paddle')
        except Exception:
            from pip._internal import main
            main(['install', 'x2paddle'])
        try:
            from x2paddle.decoder.onnx_decoder import ONNXDecoder
            from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
            from x2paddle.optimizer.optimizer import GraphOptimizer
            from x2paddle.utils import ConverterCheck
        except Exception:
            _logger.error(
                "x2paddle is not installed, please use \"pip install x2paddle\"."
            )
            sys.exit(1)
        time_info = int(time.time())
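        # Report the conversion start for x2paddle usage feedback unless disabled.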
        if not disable_feedback:
            ConverterCheck(
                task="ONNX", time_info=time_info, convert_state="Start").start()
        # check onnx installation and version
        try:
            pkg.require('onnx')
            import onnx
            # onnx >= 1.6.0 is required; compare version components numerically.
            version = tuple(int(v) for v in onnx.version.version.split('.')[:3])
            if version < (1, 6, 0):
                _logger.error(
                    "onnx>=1.6.0 is required, please use \"pip install onnx\".")
        except Exception:
            from pip._internal import main
            main(['install', 'onnx==1.12.0'])

        # Support distributed conversion: each rank writes to its own directory.
        model_idx = paddle.distributed.get_rank(
        ) if paddle.distributed.get_world_size() > 1 else 0
        try:
            _logger.info("Now translating model from onnx to paddle.")
            model = ONNXDecoder(model_path)
            mapper = ONNXOpMapper(model)
            mapper.paddle_graph.build()
            graph_opt = GraphOptimizer(source_frame="onnx")
            graph_opt.optimize(mapper.paddle_graph)
            _logger.info("Model optimized.")
            onnx2paddle_out_dir = os.path.join(
                inference_model_path, 'onnx2paddle_{}'.format(model_idx))
            mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
            _logger.info("Successfully exported Paddle static graph model!")
            if not disable_feedback:
                ConverterCheck(
                    task="ONNX", time_info=time_info,
                    convert_state="Success").start()
        except Exception:
            _logger.error(
                "x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
            )
            sys.exit(1)

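        # Only rank 0 moves the converted model into the shared cache directory;
        # other ranks load from their own onnx2paddle_<rank> output.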
        if paddle.distributed.get_rank() == 0:
            shutil.move(
                os.path.join(onnx2paddle_out_dir, 'inference_model',
                             'model.pdmodel'),
                os.path.join(inference_model_path, 'model.pdmodel'))
            shutil.move(
                os.path.join(onnx2paddle_out_dir, 'inference_model',
                             'model.pdiparams'),
                os.path.join(inference_model_path, 'model.pdiparams'))
            load_model_path = inference_model_path
        else:
            load_model_path = os.path.join(onnx2paddle_out_dir,
                                           'inference_model')

        paddle.enable_static()
        val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
            os.path.join(load_model_path, 'model'), exe)
        _logger.info('Loaded model from: {}'.format(load_model_path))
        # Clean up the temporary conversion directory.
        shutil.rmtree(onnx2paddle_out_dir)
        return val_program, feed_target_names, fetch_targets


def export_onnx(model_dir,
                model_filename=None,
                params_filename=None,
                save_file_path='output.onnx',
                opset_version=13,
                deploy_backend='tensorrt'):
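    """Export a Paddle inference model to an ONNX file with paddle2onnx."""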
    if not model_filename:
        model_filename = 'model.pdmodel'
    if not params_filename:
        params_filename = 'model.pdiparams'
    try:
        pkg.require('paddle2onnx')
    except Exception:
        from pip._internal import main
        main(['install', 'paddle2onnx==1.0.0rc3'])
    import paddle2onnx
    paddle2onnx.command.c_paddle_to_onnx(
        model_file=os.path.join(model_dir, model_filename),
        params_file=os.path.join(model_dir, params_filename),
        save_file=save_file_path,
        opset_version=opset_version,
        enable_onnx_checker=True,
        deploy_backend=deploy_backend)
    _logger.info('Converted model to ONNX: {}'.format(save_file_path))
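
# A minimal usage sketch (hypothetical paths; assumes a static graph model
# saved as ./mobilenet/model.pdmodel and ./mobilenet/model.pdiparams):
#
#     paddle.enable_static()
#     exe = paddle.static.Executor(paddle.CPUPlace())
#     program, feeds, fetches = load_inference_model('./mobilenet', exe)
#     export_onnx('./mobilenet', save_file_path='mobilenet.onnx')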