#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import logging
import os
import shutil
import sys
import pkg_resources as pkg
import paddle

from . import get_logger
_logger = get_logger(__name__, level=logging.INFO)

__all__ = [
    'load_inference_model', 'get_model_dir', 'load_onnx_model', 'export_onnx'
]


def load_inference_model(path_prefix,
                         executor,
                         model_filename=None,
                         params_filename=None):
    """Load an inference model from ``path_prefix``.

    Supports three layouts:
      * an ONNX file (``path_prefix`` ends with ``.onnx``), converted to a
        Paddle inference model via ``load_onnx_model``;
      * a Paddle inference model with explicit model/params file names;
      * a directory containing a default ``model.pdmodel``/``model.pdiparams``
        pair (or whatever prefix ``model_filename`` implies).

    Args:
        path_prefix (str): Directory of the model files, or path to an
            ``.onnx`` file.
        executor: Paddle executor used to run the load.
        model_filename (str|None): Name of the model file, e.g. ``model.pdmodel``.
        params_filename (str|None): Name of the parameters file.

    Returns:
        list: ``[inference_program, feed_target_names, fetch_targets]``.
    """
    # ONNX model: convert (or load the cached conversion) first.
    if path_prefix.endswith('.onnx'):
        inference_program, feed_target_names, fetch_targets = load_onnx_model(
            path_prefix)
        return [inference_program, feed_target_names, fetch_targets]
    # Load Inference model.
    # TODO: clean code
    if model_filename is not None and model_filename.endswith('.pdmodel'):
        model_name = '.'.join(model_filename.split('.')[:-1])
        assert os.path.exists(
            os.path.join(path_prefix, model_name + '.pdmodel')
        ), 'Please check {}, or fix model_filename parameter.'.format(
            os.path.join(path_prefix, model_name + '.pdmodel'))
        assert os.path.exists(
            os.path.join(path_prefix, model_name + '.pdiparams')
        ), 'Please check {}, or fix params_filename parameter.'.format(
            os.path.join(path_prefix, model_name + '.pdiparams'))
        model_path_prefix = os.path.join(path_prefix, model_name)
        [inference_program, feed_target_names,
         fetch_targets] = (paddle.static.load_inference_model(
             path_prefix=model_path_prefix, executor=executor))
    elif model_filename is not None and params_filename is not None:
        # Explicit (non-.pdmodel) file names: pass them straight through.
        [inference_program, feed_target_names,
         fetch_targets] = (paddle.static.load_inference_model(
             path_prefix=path_prefix,
             executor=executor,
             model_filename=model_filename,
             params_filename=params_filename))
    else:
        # Fall back to the default 'model' prefix when file names are absent.
        model_name = '.'.join(model_filename.split('.')
                              [:-1]) if model_filename is not None else 'model'
        if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
            model_path_prefix = os.path.join(path_prefix, model_name)
            [inference_program, feed_target_names,
             fetch_targets] = (paddle.static.load_inference_model(
                 path_prefix=model_path_prefix, executor=executor))
        else:
            [inference_program, feed_target_names,
             fetch_targets] = (paddle.static.load_inference_model(
                 path_prefix=path_prefix, executor=executor))

    return [inference_program, feed_target_names, fetch_targets]


def get_model_dir(model_dir, model_filename, params_filename):
    """Normalize the model directory and model/params file names.

    Args:
        model_dir (str): Model directory, or path to an ``.onnx`` file (in
            which case the directory becomes ``<name>_infer``).
        model_filename (str|None): Model file name; defaults to
            ``model.pdmodel`` when None.
        params_filename (str|None): Parameters file name; defaults to
            ``model.pdiparams`` when None.

    Returns:
        tuple: ``(updated_model_dir, updated_model_filename,
        updated_params_filename)``.

    Raises:
        NotImplementedError: If ``model_filename`` is given without
            ``params_filename`` (parameters saved in separate files are not
            supported).
    """
    if model_dir.endswith('.onnx'):
        # NOTE: strip the suffix by slicing; ``rstrip('.onnx')`` treats its
        # argument as a character SET and would corrupt names such as
        # 'common.onnx' -> 'commo'.
        updated_model_dir = model_dir.rstrip()[:-len('.onnx')] + '_infer'
    else:
        updated_model_dir = model_dir.rstrip('/')

    if model_filename is None:
        updated_model_filename = 'model.pdmodel'
    else:
        updated_model_filename = model_filename

    if params_filename is None:
        updated_params_filename = 'model.pdiparams'
    else:
        updated_params_filename = params_filename

    if params_filename is None and model_filename is not None:
        raise NotImplementedError(
            "NOT SUPPORT parameters saved in separate files. Please convert it to single binary file first."
        )
    return updated_model_dir, updated_model_filename, updated_params_filename


def load_onnx_model(model_path,
                    disable_feedback=False,
                    enable_onnx_checker=True):
    """Convert an ONNX model to a Paddle inference model and load it.

    The converted model is cached in ``<model_path minus .onnx>_infer``; if a
    cached ``model.pdmodel``/``model.pdiparams`` pair already exists there it
    is loaded directly. Conversion uses x2paddle (pinned to 1.3.9) and
    onnx>=1.6.0, both installed on the fly when missing.

    Args:
        model_path (str): Path to the ``.onnx`` file.
        disable_feedback (bool): Skip x2paddle's usage feedback when True.
        enable_onnx_checker (bool): Run the ONNX checker while decoding.

    Returns:
        tuple: ``(program, feed_target_names, fetch_targets)``.
    """
    assert model_path.endswith(
        '.onnx'
    ), '{} does not end with .onnx suffix and cannot be loaded.'.format(
        model_path)
    # NOTE: strip the suffix by slicing; ``rstrip('.onnx')`` treats its
    # argument as a character SET and would corrupt paths such as
    # 'common.onnx' -> 'commo'.
    inference_model_path = model_path.rstrip()[:-len('.onnx')] + '_infer'
    exe = paddle.static.Executor(paddle.CPUPlace())
    # Reuse a previously converted model when present.
    if os.path.exists(os.path.join(
            inference_model_path, 'model.pdmodel')) and os.path.exists(
                os.path.join(inference_model_path, 'model.pdiparams')):
        val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
            os.path.join(inference_model_path, 'model'), exe)
        _logger.info('Loaded model from: {}'.format(inference_model_path))
        return val_program, feed_target_names, fetch_targets
    else:
        # onnx to paddle inference model.
        assert os.path.exists(
            model_path), 'Not found `{}`, please check model path.'.format(
                model_path)
        # Pin x2paddle to the exact supported version.
        try:
            import x2paddle
            version = x2paddle.__version__
            v0, v1, v2 = version.split('.')
            version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
            if version_sum != 139:
                _logger.warning(
                    "x2paddle==1.3.9 is required, please use \"pip install x2paddle==1.3.9\"."
                )
                os.system('python -m pip install -U x2paddle==1.3.9')
        except Exception:
            os.system('python -m pip install -U x2paddle==1.3.9')
        # check onnx installation and version
        try:
            pkg.require('onnx')
            import onnx
            version = onnx.version.version
            v0, v1, v2 = version.split('.')
            version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
            if version_sum < 160:
                _logger.error(
                    "onnx>=1.6.0 is required, please use \"pip install onnx\".")
        except Exception:
            from pip._internal import main
            main(['install', 'onnx==1.12.0'])

        from x2paddle.decoder.onnx_decoder import ONNXDecoder
        from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
        from x2paddle.optimizer.optimizer import GraphOptimizer
        from x2paddle.utils import ConverterCheck
        time_info = int(time.time())
        if not disable_feedback:
            ConverterCheck(
                task="ONNX", time_info=time_info,
                convert_state="Start").start()
        # support distributed convert model
        model_idx = paddle.distributed.get_rank(
        ) if paddle.distributed.get_world_size() > 1 else 0
        try:
            _logger.info("Now translating model from onnx to paddle.")
            model = ONNXDecoder(model_path, enable_onnx_checker)
            mapper = ONNXOpMapper(model)
            mapper.paddle_graph.build()
            graph_opt = GraphOptimizer(source_frame="onnx")
            graph_opt.optimize(mapper.paddle_graph)
            _logger.info("Model optimized.")
            onnx2paddle_out_dir = os.path.join(
                inference_model_path, 'onnx2paddle_{}'.format(model_idx))
            mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
            _logger.info("Successfully exported Paddle static graph model!")
            if not disable_feedback:
                ConverterCheck(
                    task="ONNX", time_info=time_info,
                    convert_state="Success").start()
        except Exception as e:
            _logger.warning(e)
            _logger.error(
                "x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
            )
            sys.exit(1)

        # Rank 0 moves the converted files into the shared cache directory;
        # other ranks load from their own per-rank conversion output.
        if paddle.distributed.get_rank() == 0:
            shutil.move(
                os.path.join(onnx2paddle_out_dir, 'inference_model',
                             'model.pdmodel'),
                os.path.join(inference_model_path, 'model.pdmodel'))
            shutil.move(
                os.path.join(onnx2paddle_out_dir, 'inference_model',
                             'model.pdiparams'),
                os.path.join(inference_model_path, 'model.pdiparams'))
            load_model_path = inference_model_path
        else:
            load_model_path = os.path.join(onnx2paddle_out_dir,
                                           'inference_model')

        paddle.enable_static()
        val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
            os.path.join(load_model_path, 'model'), exe)
        _logger.info('Loaded model from: {}'.format(load_model_path))
        # Clean up the file storage directory
        shutil.rmtree(
            os.path.join(inference_model_path, 'onnx2paddle_{}'.format(
                model_idx)))
        return val_program, feed_target_names, fetch_targets


def export_onnx(model_dir,
                model_filename=None,
                params_filename=None,
                save_file_path='output.onnx',
                opset_version=13,
                deploy_backend='tensorrt'):
    """Export a Paddle inference model to ONNX via paddle2onnx.

    Args:
        model_dir (str): Directory containing the inference model files.
        model_filename (str|None): Model file name; defaults to
            ``model.pdmodel``.
        params_filename (str|None): Parameters file name; defaults to
            ``model.pdiparams``.
        save_file_path (str): Destination path of the exported ONNX file.
        opset_version (int): ONNX opset version to target.
        deploy_backend (str): Deployment backend hint passed to paddle2onnx.
    """
    if not model_filename:
        model_filename = 'model.pdmodel'
    if not params_filename:
        params_filename = 'model.pdiparams'
    # Ensure a recent-enough paddle2onnx is installed.
    try:
        import paddle2onnx
        version = paddle2onnx.__version__
        # NOTE(review): lexicographic version compare -- adequate for the
        # pinned range, but would misorder e.g. '1.0.10' vs '1.0.2'.
        if version < '1.0.1':
            os.system('python -m pip install -U paddle2onnx==1.0.3')
    except Exception:
        from pip._internal import main
        main(['install', 'paddle2onnx==1.0.3'])
    import paddle2onnx
    # The calibration cache is written next to the exported ONNX file.
    # NOTE: use os.path.dirname instead of rstrip(basename) -- rstrip treats
    # its argument as a character SET, not a suffix.
    paddle2onnx.command.c_paddle_to_onnx(
        model_file=os.path.join(model_dir, model_filename),
        params_file=os.path.join(model_dir, params_filename),
        save_file=save_file_path,
        opset_version=opset_version,
        enable_onnx_checker=True,
        deploy_backend=deploy_backend,
        calibration_file=os.path.join(
            os.path.dirname(save_file_path), 'calibration.cache'))
    _logger.info('Convert model to ONNX: {}'.format(save_file_path))