convert.py 18.8 KB
Newer Older
S
SunAhong1993 已提交
1
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
J
jiangjiajun 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14

15
from six import text_type as _text_type
S
SunAhong1993 已提交
16
from x2paddle import program
W
wjj19950828 已提交
17
from x2paddle.utils import ConverterCheck
18
import argparse
J
jiangjiajun 已提交
19
import sys
W
WJJ1995 已提交
20
import logging
W
wjj19950828 已提交
21
import time
22

J
jiangjiajun 已提交
23

24 25
def _str2bool(value):
    """Parse the value of a boolean command-line option.

    Options such as ``--to_lite`` take an explicit value
    (``--to_lite True``). Without a converter argparse keeps the raw string,
    and every non-empty string -- including ``"False"`` -- is truthy, so
    ``--enable_onnx_checker False`` used to leave the checker enabled.

    Args:
        value: Raw option value; a ``str`` from the command line, or a
            ``bool`` which is returned unchanged.

    Returns:
        bool: ``True`` for true/t/yes/y/on/1, ``False`` for
        false/f/no/n/off/0 (case-insensitive, surrounding blanks ignored).

    Raises:
        argparse.ArgumentTypeError: For any other spelling; argparse reports
            it as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    if text in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def arg_parser():
    """Build the ``x2paddle`` command-line parser.

    String options use ``str`` (what ``six.text_type`` resolves to on the
    Python 3 versions this tool requires). Boolean options that take a value
    (``--enable_code_optim``, ``--enable_onnx_checker``,
    ``--disable_feedback``, ``--to_lite``) are parsed with ``_str2bool`` so
    that e.g. ``--to_lite False`` really yields ``False``.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default=None,
        help="define model file path for tensorflow or onnx")
    parser.add_argument(
        "--prototxt",
        "-p",
        type=str,
        default=None,
        help="prototxt file of caffe model")
    parser.add_argument(
        "--weight",
        "-w",
        type=str,
        default=None,
        help="weight file of caffe model")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=str,
        default=None,
        help="path to save translated model")
    parser.add_argument(
        "--framework",
        "-f",
        type=str,
        default=None,
        help="define which deeplearning framework(tensorflow/caffe/onnx/paddle2onnx)"
    )
    parser.add_argument(
        "--caffe_proto",
        "-c",
        type=str,
        default=None,
        help="optional: the .py file compiled by caffe proto file of caffe model"
    )
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of x2paddle")
    parser.add_argument(
        "--define_input_shape",
        "-d",
        action="store_true",
        default=False,
        help="define input shape for tf model")
    parser.add_argument(
        "--input_shape_dict",
        "-isd",
        type=str,
        default=None,
        help="define input shapes, e.g --input_shape_dict=\"{'image':[1, 3, 608, 608]}\" or "
        "--input_shape_dict=\"{'image':[1, 3, 608, 608], 'im_shape': [1, 2], 'scale_factor': [1, 2]}\"")
    parser.add_argument(
        "--convert_torch_project",
        "-tp",
        action='store_true',
        help="Convert the PyTorch Project.")
    parser.add_argument(
        "--project_dir",
        "-pd",
        type=str,
        default=None,
        help="define project folder path for pytorch")
    parser.add_argument(
        "--pretrain_model",
        "-pm",
        type=str,
        default=None,
        help="pretrain model file of pytorch model")
    parser.add_argument(
        "--enable_code_optim",
        "-co",
        type=_str2bool,
        default=False,
        help="Turn on code optimization")
    parser.add_argument(
        "--enable_onnx_checker",
        "-oc",
        type=_str2bool,
        default=True,
        help="Turn on onnx model checker")
    parser.add_argument(
        "--disable_feedback",
        "-df",
        type=_str2bool,
        default=False,
        help="Turn off feedback of model conversion.")
    parser.add_argument(
        "--to_lite",
        "-tl",
        type=_str2bool,
        default=False,
        help="convert to Paddle-Lite format")
    parser.add_argument(
        "--lite_valid_places",
        "-vp",
        type=str,
        default="arm",
        help="Specify the executable backend of the model")
    parser.add_argument(
        "--lite_model_type",
        "-mt",
        type=str,
        default="naive_buffer",
        help="The type of lite model")

    return parser
J
jiangjiajun 已提交
131

C
Channingss 已提交
132

W
WJJ1995 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
def convert2lite(save_dir,
                 lite_valid_places="arm",
                 lite_model_type="naive_buffer"):
    """Run Paddle-Lite's ``Opt`` optimizer on an exported Paddle model.

    The model is read from ``<save_dir>/inference_model`` and the optimized
    result is written to ``<save_dir>/opt``.

    Args:
        save_dir (str): Directory the Paddle model was exported to.
        lite_valid_places (str): Forwarded to ``Opt.set_valid_places``.
        lite_model_type (str): Forwarded to ``Opt.set_model_type``.
    """
    # Imported here so paddlelite is only required when a Lite conversion
    # is actually requested.
    from paddlelite.lite import Opt

    optimizer = Opt()
    optimizer.set_model_dir(save_dir + "/inference_model")
    optimizer.set_valid_places(lite_valid_places)
    optimizer.set_model_type(lite_model_type)
    optimizer.set_optimize_out(save_dir + "/opt")
    optimizer.run()


def tf2paddle(model_path,
              save_dir,
              define_input_shape=False,
              convert_to_lite=False,
              lite_valid_places="arm",
              lite_model_type="naive_buffer",
              disable_feedback=False):
    """Convert a TensorFlow 1.x model into a Paddle static graph model.

    Args:
        model_path (str): TensorFlow model file, handed to ``TFDecoder``.
        save_dir (str): Directory the Paddle model is written to.
        define_input_shape (bool): Forwarded to ``TFDecoder``.
        convert_to_lite (bool): Additionally run ``convert2lite`` on the
            exported model.
        lite_valid_places (str): Forwarded to ``convert2lite``.
        lite_model_type (str): Forwarded to ``convert2lite``.
        disable_feedback (bool): Skip the ``ConverterCheck`` reports.

    Returns:
        None. Returns early, after logging an ``[ERROR]`` line, when
        TensorFlow cannot be imported or its major version is not 1.
    """
    import os
    import re

    # for convert_id: one timestamp shared by every report of this run
    time_info = int(time.time())
    if not disable_feedback:
        ConverterCheck(
            task="TensorFlow", time_info=time_info,
            convert_state="Start").start()

    # check tensorflow installation and version
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
    try:
        import tensorflow as tf
    except ImportError:
        logging.info(
            "[ERROR] TensorFlow is not installed, use \"pip install TensorFlow\"."
        )
        return
    except Exception as e:
        # Installed but broken (e.g. incompatible protobuf): report the real
        # cause instead of claiming TensorFlow is missing.
        logging.info("[ERROR] Failed to import TensorFlow: {}".format(e))
        return
    # Compare the major version numerically; string comparison misorders
    # multi-digit components ('10.0.0' < '2.0.0').
    match = re.match(r"\s*(\d+)", tf.__version__)
    if match is None or int(match.group(1)) != 1:
        logging.info(
            "[ERROR] 1.0.0<=TensorFlow<2.0.0 is required, and v1.14.0 is recommended"
        )
        return

    from x2paddle.decoder.tf_decoder import TFDecoder
    from x2paddle.op_mapper.tf2paddle.tf_op_mapper import TFOpMapper

    logging.info("Now translating model from TensorFlow to Paddle.")
    model = TFDecoder(model_path, define_input_shape=define_input_shape)
    mapper = TFOpMapper(model)
    mapper.paddle_graph.build()
    logging.info("Model optimizing ...")
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(source_frame="tf")
    graph_opt.optimize(mapper.paddle_graph)
    logging.info("Model optimized!")
    mapper.paddle_graph.gen_model(save_dir)
    logging.info("Successfully exported Paddle static graph model!")
    if not disable_feedback:
        ConverterCheck(
            task="TensorFlow", time_info=time_info,
            convert_state="Success").start()
    if convert_to_lite:
        logging.info("Now translating model from Paddle to Paddle Lite ...")
        if not disable_feedback:
            ConverterCheck(
                task="TensorFlow", time_info=time_info,
                lite_state="Start").start()
        convert2lite(save_dir, lite_valid_places, lite_model_type)
        logging.info("Successfully exported Paddle Lite support model!")
        if not disable_feedback:
            ConverterCheck(
                task="TensorFlow", time_info=time_info,
                lite_state="Success").start()
    # for convert survey
    logging.info("================================================")
    logging.info("")
    logging.info(
        "Model Converted! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
    )
    logging.info("")
    logging.info("================================================")
215 216


W
WJJ1995 已提交
217 218 219 220 221 222
def caffe2paddle(proto_file,
                 weight_file,
                 save_dir,
                 caffe_proto,
                 convert_to_lite=False,
                 lite_valid_places="arm",
                 lite_model_type="naive_buffer",
                 disable_feedback=False):
    """Convert a Caffe model (prototxt + weights) into a Paddle model.

    Args:
        proto_file (str): Caffe ``.prototxt`` file, handed to
            ``CaffeDecoder``.
        weight_file (str): Caffe weight file, handed to ``CaffeDecoder``.
        save_dir (str): Directory the Paddle model is written to.
        caffe_proto: Optional compiled caffe proto ``.py`` file (may be
            ``None``), handed to ``CaffeDecoder``.
        convert_to_lite (bool): Additionally run ``convert2lite``.
        lite_valid_places (str): Forwarded to ``convert2lite``.
        lite_model_type (str): Forwarded to ``convert2lite``.
        disable_feedback (bool): Skip the ``ConverterCheck`` reports.

    Raises:
        AssertionError: If google.protobuf is older than 3.6. This is raised
            explicitly (not via ``assert``) so the check survives
            ``python -O``; the exception type is kept for compatibility.
    """
    import re

    # for convert_id: one timestamp shared by every report of this run
    time_info = int(time.time())
    if not disable_feedback:
        ConverterCheck(
            task="Caffe", time_info=time_info, convert_state="Start").start()
    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    from x2paddle.op_mapper.caffe2paddle.caffe_op_mapper import CaffeOpMapper
    import google.protobuf as gpb
    # Only major.minor matter for the >= 3.6 requirement; suffixes such as
    # 'rc1' on later components are ignored.
    match = re.match(r"(\d+)\.(\d+)", gpb.__version__)
    version_satisfy = (match is not None and
                       (int(match.group(1)), int(match.group(2))) >= (3, 6))
    if not version_satisfy:
        raise AssertionError('[ERROR] google.protobuf >= 3.6.0 is required')
    logging.info("Now translating model from caffe to paddle.")
    model = CaffeDecoder(proto_file, weight_file, caffe_proto)
    mapper = CaffeOpMapper(model)
    mapper.paddle_graph.build()
    logging.info("Model optimizing ...")
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(source_frame="caffe")
    graph_opt.optimize(mapper.paddle_graph)
    logging.info("Model optimized!")
    mapper.paddle_graph.gen_model(save_dir)
    logging.info("Successfully exported Paddle static graph model!")
    if not disable_feedback:
        ConverterCheck(
            task="Caffe", time_info=time_info, convert_state="Success").start()
    if convert_to_lite:
        logging.info("Now translating model from Paddle to Paddle Lite ...")
        if not disable_feedback:
            ConverterCheck(
                task="Caffe", time_info=time_info, lite_state="Start").start()
        convert2lite(save_dir, lite_valid_places, lite_model_type)
        logging.info("Successfully exported Paddle Lite support model!")
        if not disable_feedback:
            ConverterCheck(
                task="Caffe", time_info=time_info, lite_state="Success").start()
    # for convert survey
    logging.info("================================================")
    logging.info("")
    logging.info(
        "Model Converted! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
    )
    logging.info("")
    logging.info("================================================")
271 272


W
WJJ1995 已提交
273 274
def onnx2paddle(model_path,
                save_dir,
                input_shape_dict=None,
                convert_to_lite=False,
                lite_valid_places="arm",
                lite_model_type="naive_buffer",
                disable_feedback=False,
                enable_onnx_checker=True):
    """Convert an ONNX model into a Paddle static graph model.

    Args:
        model_path (str): ONNX model file, handed to ``ONNXDecoder``.
        save_dir (str): Directory the Paddle model is written to.
        input_shape_dict: Forwarded unchanged to ``ONNXDecoder``; from the
            CLI this is the raw ``--input_shape_dict`` string.
        convert_to_lite (bool): Additionally run ``convert2lite``.
        lite_valid_places (str): Forwarded to ``convert2lite``.
        lite_model_type (str): Forwarded to ``convert2lite``.
        disable_feedback (bool): Skip the ``ConverterCheck`` reports.
        enable_onnx_checker (bool): Forwarded to ``ONNXDecoder``.

    Returns:
        None. Returns early, after logging an ``[ERROR]`` line, when onnx
        cannot be imported or is older than 1.6.0.
    """
    import re

    # for convert_id: one timestamp shared by every report of this run
    time_info = int(time.time())
    if not disable_feedback:
        ConverterCheck(
            task="ONNX", time_info=time_info, convert_state="Start").start()
    # check onnx installation and version
    try:
        import onnx
    except ImportError:
        logging.info(
            "[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
        return
    except Exception as e:
        # Installed but failing to import: report the real cause.
        logging.info("[ERROR] Failed to import onnx: {}".format(e))
        return
    version = onnx.version.version
    # Tolerate rc/dev suffixes (e.g. '1.16.0.dev20231120') which used to
    # break the 3-way split and be misreported as "not installed".
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        logging.info(
            "[WARNING] Unable to parse onnx version {}, skip version check".
            format(version))
    elif tuple(int(part or 0) for part in match.groups()) < (1, 6, 0):
        logging.info("[ERROR] onnx>=1.6.0 is required")
        return
    logging.info("Now translating model from onnx to paddle.")

    from x2paddle.decoder.onnx_decoder import ONNXDecoder
    from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
    model = ONNXDecoder(model_path, input_shape_dict, enable_onnx_checker)
    mapper = ONNXOpMapper(model)
    mapper.paddle_graph.build()
    logging.info("Model optimizing ...")
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(source_frame="onnx")
    graph_opt.optimize(mapper.paddle_graph)
    logging.info("Model optimized.")
    mapper.paddle_graph.gen_model(save_dir)
    logging.info("Successfully exported Paddle static graph model!")
    if not disable_feedback:
        ConverterCheck(
            task="ONNX", time_info=time_info, convert_state="Success").start()
    if convert_to_lite:
        logging.info("Now translating model from Paddle to Paddle Lite ...")
        if not disable_feedback:
            ConverterCheck(
                task="ONNX", time_info=time_info, lite_state="Start").start()
        convert2lite(save_dir, lite_valid_places, lite_model_type)
        logging.info("Successfully exported Paddle Lite support model!")
        if not disable_feedback:
            ConverterCheck(
                task="ONNX", time_info=time_info, lite_state="Success").start()
    # for convert survey
    logging.info("================================================")
    logging.info("")
    logging.info(
        "Model Converted! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
    )
    logging.info("")
    logging.info("================================================")
C
Channingss 已提交
334 335


W
WJJ1995 已提交
336 337 338 339
def pytorch2paddle(module,
                   save_dir,
                   jit_type="trace",
                   input_examples=None,
                   enable_code_optim=False,
                   convert_to_lite=False,
                   lite_valid_places="arm",
                   lite_model_type="naive_buffer",
                   disable_feedback=False):
    """Convert a PyTorch module into a Paddle model.

    Args:
        module: The PyTorch module to convert, handed to the decoder.
        save_dir (str): Directory the Paddle model is written to.
        jit_type (str): ``"trace"`` selects ``TraceDecoder``; any other value
            selects ``ScriptDecoder``. Also forwarded to the optimizer and
            ``gen_model``.
        input_examples: Forwarded to the decoder (presumably example inputs
            used for tracing -- confirm against the decoders).
        enable_code_optim (bool): Forwarded to ``gen_model``.
        convert_to_lite (bool): Additionally run ``convert2lite``.
        lite_valid_places (str): Forwarded to ``convert2lite``.
        lite_model_type (str): Forwarded to ``convert2lite``.
        disable_feedback (bool): Skip the ``ConverterCheck`` reports.

    Returns:
        None. Returns early, after logging an ``[ERROR]`` line, when torch
        cannot be imported or is older than 1.5.0.
    """
    import re

    # for convert_id: one timestamp shared by every report of this run
    time_info = int(time.time())
    if not disable_feedback:
        ConverterCheck(
            task="PyTorch", time_info=time_info, convert_state="Start").start()
    # check pytorch installation and version
    try:
        import torch
    except ImportError:
        logging.info(
            "[ERROR] PyTorch is not installed, use \"pip install torch==1.6.0 torchvision\"."
        )
        return
    except Exception as e:
        # Installed but failing to import: report the real cause.
        logging.info("[ERROR] Failed to import PyTorch: {}".format(e))
        return
    version = str(torch.__version__)
    # Only the leading numeric major.minor.patch is compared, so local
    # ('1.7.0+cu101'), pre-release ('1.13.0a0+git...') and nightly
    # ('2.1.0.dev20230101') versions no longer read as "not installed".
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        logging.info(
            "[WARNING] Unable to parse PyTorch version {}, skip version check".
            format(version))
    else:
        torch_version = tuple(int(part or 0) for part in match.groups())
        if torch_version < (1, 5, 0):
            logging.info(
                "[ERROR] PyTorch>=1.5.0 is required, 1.6.0 is the most recommended"
            )
            return
        if torch_version > (1, 6, 0):
            logging.info("[WARNING] PyTorch==1.6.0 is recommended")
    logging.info("Now translating model from PyTorch to Paddle.")

    from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
    from x2paddle.op_mapper.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper

    if jit_type == "trace":
        model = TraceDecoder(module, input_examples)
    else:
        model = ScriptDecoder(module, input_examples)
    mapper = PyTorchOpMapper(model)
    mapper.paddle_graph.build()
    logging.info("Model optimizing ...")
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(source_frame="pytorch", jit_type=jit_type)
    graph_opt.optimize(mapper.paddle_graph)
    logging.info("Model optimized!")
    mapper.paddle_graph.gen_model(
        save_dir, jit_type=jit_type, enable_code_optim=enable_code_optim)
    logging.info("Successfully exported Paddle static graph model!")
    if not disable_feedback:
        ConverterCheck(
            task="PyTorch", time_info=time_info,
            convert_state="Success").start()
    if convert_to_lite:
        logging.info("Now translating model from Paddle to Paddle Lite ...")
        if not disable_feedback:
            ConverterCheck(
                task="PyTorch", time_info=time_info, lite_state="Start").start()
        convert2lite(save_dir, lite_valid_places, lite_model_type)
        logging.info("Successfully exported Paddle Lite support model!")
        if not disable_feedback:
            ConverterCheck(
                task="PyTorch", time_info=time_info,
                lite_state="Success").start()
    # for convert survey
    logging.info("================================================")
    logging.info("")
    logging.info(
        "Model Converted! Fill this survey to help X2Paddle better, https://iwenjuan.baidu.com/?code=npyd51 "
    )
    logging.info("")
    logging.info("================================================")
S
SunAhong1993 已提交
413 414


415
def main():
    """Command-line entry point: parse ``sys.argv`` and run one conversion.

    Prints a short hint when called without arguments. Handles
    ``--version`` and validates the required options via ``parser.error``,
    which prints usage and exits with status 2 (unlike ``assert``, it is not
    stripped under ``python -O``). Checks the Python and paddlepaddle
    versions, then dispatches to the PyTorch project converter or to
    ``tf2paddle`` / ``caffe2paddle`` / ``onnx2paddle``.

    Raises:
        Exception: If ``--framework`` names an unsupported framework.
    """
    logging.basicConfig(level=logging.INFO)
    if len(sys.argv) < 2:
        logging.info("Use \"x2paddle -h\" to print the help information")
        logging.info(
            "For more information, please follow our github repo below:)")
        logging.info("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
        return

    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        import x2paddle
        logging.info("x2paddle-{} with python>=3.5, paddlepaddle>=2.0.0\n".
                     format(x2paddle.__version__))
        return

    if not args.convert_torch_project and args.framework is None:
        parser.error(
            "--framework is not defined(support tensorflow/caffe/onnx)")
    if args.save_dir is None:
        parser.error("--save_dir is not defined")

    # Tuple comparison also accepts a future 4.x (the old per-component
    # check required minor >= 5 regardless of major).
    if sys.version_info[:2] < (3, 5):
        logging.info("[ERROR] python>=3.5 is required")
        return

    try:
        import paddle
    except ImportError:
        # Best effort as before: log and keep going.
        logging.info(
            "[ERROR] paddlepaddle not installed, use \"pip install paddlepaddle\""
        )
    except Exception as e:
        logging.info("[ERROR] Failed to import paddlepaddle: {}".format(e))
    else:
        import re
        logging.info("paddle.__version__ = {}".format(paddle.__version__))
        match = re.match(r"(\d+)\.(\d+)\.(\d+)", paddle.__version__)
        if match is not None and match.groups() == ("0", "0", "0"):
            logging.info(
                "[WARNING] You are use develop version of paddlepaddle")
        elif match is not None and int(match.group(1)) < 2:
            # ">= 2.0.0" as the message says; 3.x and later are accepted.
            logging.info("[ERROR] paddlepaddle>=2.0.0 is required")
            return

    if args.convert_torch_project:
        if args.project_dir is None:
            parser.error(
                "--project_dir should be defined while translating pytorch project"
            )
        from x2paddle.project_convertor.pytorch.convert import main as convert_torch
        convert_torch(args)
    elif args.framework == "tensorflow":
        if args.model is None:
            parser.error(
                "--model should be defined while translating tensorflow model")
        tf2paddle(
            args.model,
            args.save_dir,
            args.define_input_shape,
            convert_to_lite=args.to_lite,
            lite_valid_places=args.lite_valid_places,
            lite_model_type=args.lite_model_type,
            disable_feedback=args.disable_feedback)
    elif args.framework == "caffe":
        if args.prototxt is None or args.weight is None:
            parser.error(
                "--prototxt and --weight should be defined while translating caffe model"
            )
        caffe2paddle(
            args.prototxt,
            args.weight,
            args.save_dir,
            args.caffe_proto,
            convert_to_lite=args.to_lite,
            lite_valid_places=args.lite_valid_places,
            lite_model_type=args.lite_model_type,
            disable_feedback=args.disable_feedback)
    elif args.framework == "onnx":
        if args.model is None:
            parser.error(
                "--model should be defined while translating onnx model")
        onnx2paddle(
            args.model,
            args.save_dir,
            input_shape_dict=args.input_shape_dict,
            convert_to_lite=args.to_lite,
            lite_valid_places=args.lite_valid_places,
            lite_model_type=args.lite_model_type,
            disable_feedback=args.disable_feedback,
            enable_onnx_checker=args.enable_onnx_checker)
    elif args.framework == "paddle2onnx":
        logging.info(
            "Paddle to ONNX tool has been migrated to the new github: https://github.com/PaddlePaddle/paddle2onnx"
        )
    else:
        raise Exception(
            "--framework only support tensorflow/caffe/onnx now")
506 507 508


# Script entry point: hand control to the CLI only when this module is
# executed directly, never when it is imported.
if __name__ == "__main__":
    main()