convert.py 10.4 KB
Newer Older
1
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
J
jiangjiajun 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14

15
from six import text_type as _text_type
S
SunAhong1993 已提交
16
from x2paddle import program
17
import argparse
J
jiangjiajun 已提交
18
import sys
19

J
jiangjiajun 已提交
20

21 22
def arg_parser():
    """Build the command-line parser for the ``x2paddle`` tool.

    Returns:
        argparse.ArgumentParser: parser whose options select the source
        framework (``--framework``), the input model file(s), the output
        directory (``--save_dir``) and a few framework-specific switches.
    """
    parser = argparse.ArgumentParser()
    # ``str`` is the text type on Python 3, which this tool requires
    # (python>=3.5); it is what six.text_type resolves to there.
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default=None,
        help="define model file path for tensorflow/onnx/pytorch "
        "(or the paddle model path for paddle2onnx)")
    parser.add_argument(
        "--prototxt",
        "-p",
        type=str,
        default=None,
        help="prototxt file of caffe model")
    parser.add_argument(
        "--weight",
        "-w",
        type=str,
        default=None,
        help="weight file of caffe model")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=str,
        default=None,
        help="path to save translated model")
    parser.add_argument(
        "--framework",
        "-f",
        type=str,
        default=None,
        help="define which deeplearning framework(tensorflow/caffe/onnx/pytorch/paddle2onnx)"
    )
    parser.add_argument(
        "--caffe_proto",
        "-c",
        type=str,
        default=None,
        help="optional: the .py file compiled by caffe proto file of caffe model"
    )
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of x2paddle")
    parser.add_argument(
        "--without_data_format_optimization",
        "-wo",
        action="store_true",
        default=False,
        help="tf model conversion without data format optimization")
    parser.add_argument(
        "--define_input_shape",
        "-d",
        action="store_true",
        default=False,
        help="define input shape for tf model")
    parser.add_argument(
        "--onnx_opset",
        "-oo",
        type=int,
        default=10,
        help="when paddle2onnx set onnx opset version to export")
    parser.add_argument(
        "--params_merge",
        "-pm",
        action="store_true",
        default=False,
        help="define whether merge the params")
    # Collected as raw strings; each occurrence of the option appends one
    # shape, which the pytorch converter parses as "[d0,d1,...]".
    parser.add_argument(
        "--input_shapes",
        "-is",
        action='append',
        default=None,
        help="define the inputs' shape for pytorch model, e.g. \"[1,3,224,224]\"; "
        "repeat the option once per input")

    return parser
J
jiangjiajun 已提交
99

100

101 102
def tf2paddle(model_path,
              save_dir,
              without_data_format_optimization=False,
              define_input_shape=False,
              params_merge=False):
    """Convert a TensorFlow 1.x model to PaddlePaddle.

    Args:
        model_path: path of the TensorFlow model, passed to ``TFDecoder``.
        save_dir: output directory handed to ``program.gen_model``.
        without_data_format_optimization: accepted for CLI compatibility but
            currently NOT used by this function (see NOTE below).
        define_input_shape: forwarded to ``TFDecoder``.
        params_merge: accepted for CLI compatibility but currently NOT used.

    Returns:
        None. Prints an ``[ERROR]`` message and returns early when TensorFlow
        cannot be imported or its major version is not 1.
    """
    # Set before tensorflow is imported so the logging level is picked up.
    import os
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'

    # check tensorflow installation and version
    try:
        import tensorflow as tf
    except ImportError:
        print(
            "[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
        )
        return
    # Compare the numeric major version: a plain string comparison treats
    # e.g. "10.0.0" as lying between "1.0.0" and "2.0.0".
    major = int(tf.__version__.split('.')[0])
    if major != 1:
        print(
            "[ERROR] 1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
        )
        return

    from x2paddle.decoder.tf_decoder import TFDecoder
    from x2paddle.op_mapper.tf_op_mapper_nhwc import TFOpMapperNHWC

    print("Now translating model from tensorflow to paddle.")
    model = TFDecoder(model_path, define_input_shape=define_input_shape)

    # NOTE(review): without_data_format_optimization and params_merge are
    # ignored here; only the NHWC mapper path is wired up. The mapper result
    # is not used directly -- presumably it records ops into the shared
    # module-level ``program`` that is built and emitted below. Confirm.
    mapper = TFOpMapperNHWC(model)
    program.build()
    program.gen_model(save_dir)
134 135


M
mamingjie-China 已提交
136
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
    """Convert a Caffe model to a PaddlePaddle inference model.

    Args:
        proto: path of the Caffe prototxt network definition.
        weight: path of the Caffe weight file.
        save_dir: directory passed to ``save_inference_model``.
        caffe_proto: optional path of the ``.py`` module compiled from the
            Caffe proto file; passed straight to ``CaffeDecoder`` (may be None).
        params_merge: forwarded to ``save_inference_model``.

    Returns:
        None. Prints an ``[ERROR]`` message and returns early when protobuf
        is missing or older than 3.6.
    """
    # check protobuf installation and version. An explicit check instead of
    # ``assert`` so it still runs under ``python -O``.
    try:
        import google.protobuf as gpb
    except ImportError:
        print("[ERROR] protobuf is not installed, use \"pip install protobuf\".")
        return
    major, minor = (int(part) for part in gpb.__version__.split('.')[:2])
    if (major, minor) < (3, 6):
        print('[ERROR] google.protobuf >= 3.6.0 is required')
        return

    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
    from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer

    print("Now translating model from caffe to paddle.")
    model = CaffeDecoder(proto, weight, caffe_proto)
    mapper = CaffeOpMapper(model)
    # Only ``mapper`` is saved afterwards, so these passes presumably rewrite
    # the mapper's graph in place (BN+Scale and op+activation merging).
    optimizer = CaffeOptimizer(mapper)
    optimizer.merge_bn_scale()
    optimizer.merge_op_activation()
    mapper.save_inference_model(save_dir, params_merge)
154 155


M
mamingjie-China 已提交
156
def onnx2paddle(model_path, save_dir, params_merge=False):
    """Convert an ONNX model to a PaddlePaddle inference model.

    Args:
        model_path: path of the ONNX model, passed to ``ONNXDecoder``.
        save_dir: directory passed to ``save_inference_model``.
        params_merge: forwarded to ``save_inference_model``.

    Returns:
        None. Prints an ``[ERROR]`` message and returns early when onnx
        cannot be imported or is not exactly version 1.6.0.
    """
    # check onnx installation and version. Only a failed import means
    # "not installed"; other errors are left to surface with a traceback.
    try:
        import onnx
    except ImportError:
        print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
        return
    # The converter is pinned to exactly onnx 1.6.0.
    if onnx.version.version != '1.6.0':
        print("[ERROR] onnx==1.6.0 is required")
        return
    print("Now translating model from onnx to paddle.")

    from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
    from x2paddle.decoder.onnx_decoder import ONNXDecoder
    from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
    model = ONNXDecoder(model_path)
    mapper = ONNXOpMapper(model)
    print("Model optimizing ...")
    # The optimizer object is never used again; presumably its constructor
    # applies the optimizations to ``mapper``. Kept bound, as in the original.
    optimizer = ONNXOptimizer(mapper)
    print("Model optimized.")

    print("Paddle model and code generating ...")
    mapper.save_inference_model(save_dir, params_merge)
    print("Paddle model and code generated.")


S
SunAhong1993 已提交
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
def pytorch2paddle(model_path, save_dir, input_shapes):
    # check pytorch installation and version
    try:
        import torch
        version = torch.__version__
        ver_part = version.split('.')
        print(ver_part)
        if int(ver_part[1]) < 5:
            print("[ERROR] pytorch>=1.5.0 is required")
            return
    except:
        print(
            "[ERROR] Pytorch is not installed, use \"pip install torch==1.5.0 torchvision\"."
        )
        return
    print("Now translating model from pytorch to paddle.")

    from x2paddle.decoder.pytorch_decoder import PyTorchDecoder
    from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper
    model = PyTorchDecoder(model_path)
    mapper = pytorch_op_mapper.PyTorchOpMapper(model)
    mapper.graph.build()
    print("Model optimizing ...")
    from x2paddle.optimizer.pytorch_optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer()
    graph_opt.optimize(mapper.graph)
    print("Model optimized.")
    if input_shapes is not None:
        real_input_shapes = list()
        for shape in input_shapes:
            sp = shape[1:-1].split(",")
            for i, s in enumerate(sp):
                sp[i] = int(s)
            real_input_shapes.append(sp)
    else:
        real_input_shapes = None
    mapper.graph.gen_model(save_dir, real_input_shapes)


J
Jason 已提交
222
def paddle2onnx(model_path, save_dir, opset_version=10):
    """Export a PaddlePaddle model to ONNX.

    Args:
        model_path: location of the Paddle model. The decoder is also given
            the names '__model__' and '__params__' (presumably the program
            and parameter files under model_path -- confirm in PaddleDecoder).
        save_dir: output location handed to the ONNX op mapper.
        opset_version: ONNX opset to target, forwarded as ``opset_number``.
    """
    from x2paddle.decoder import paddle_decoder
    from x2paddle.op_mapper.paddle2onnx import paddle_op_mapper

    decoder = paddle_decoder.PaddleDecoder(model_path, '__model__', '__params__')
    op_mapper = paddle_op_mapper.PaddleOpMapper()
    op_mapper.convert(decoder.program, save_dir, opset_number=opset_version)
C
update  
channingss 已提交
228 229


230
def main():
    """Run the ``x2paddle`` command-line tool.

    Parses ``sys.argv``, validates the options needed by the selected
    ``--framework``, checks for an installed paddlepaddle 1.x (>= 1.6), and
    dispatches to the matching converter. Invalid or missing options end the
    program through ``parser.error`` (usage message, exit status 2).
    """
    if len(sys.argv) < 2:
        print("Use \"x2paddle -h\" to print the help information")
        print("For more information, please follow our github repo below:)")
        print("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
        return

    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        import x2paddle
        print("x2paddle-{} with python>=3.5, paddlepaddle>=1.6.0\n".format(
            x2paddle.__version__))
        return

    # Validate with parser.error() rather than assert: asserts are stripped
    # under ``python -O`` and otherwise surface as a bare traceback.
    supported = "tensorflow/caffe/onnx/pytorch/paddle2onnx"
    if args.framework is None:
        parser.error(
            "--framework is not defined(support {})".format(supported))
    if args.framework not in supported.split("/"):
        parser.error("--framework only support {} now".format(supported))
    if args.save_dir is None:
        parser.error("--save_dir is not defined")

    # check paddlepaddle installation and version
    try:
        import paddle
    except ImportError:
        print(
            "[ERROR] paddlepaddle not installed, use \"pip install paddlepaddle\""
        )
        return
    print("paddle.__version__ = {}".format(paddle.__version__))
    # Version strings are not guaranteed to have exactly three parts, so
    # index into them instead of unpacking into three names.
    ver_part = paddle.__version__.split('.')
    if ver_part[:3] == ['0', '0', '0']:
        print("[WARNING] You are use develop version of paddlepaddle")
    elif int(ver_part[0]) != 1 or int(ver_part[1]) < 6:
        print("[ERROR] paddlepaddle>=1.6.0 is required")
        return

    # The boolean options are store_true flags, so they are passed through
    # directly.
    if args.framework == "tensorflow":
        if args.model is None:
            parser.error(
                "--model should be defined while translating tensorflow model")
        tf2paddle(args.model, args.save_dir,
                  args.without_data_format_optimization,
                  args.define_input_shape, args.params_merge)

    elif args.framework == "caffe":
        if args.prototxt is None or args.weight is None:
            parser.error(
                "--prototxt and --weight should be defined while translating caffe model"
            )
        caffe2paddle(args.prototxt, args.weight, args.save_dir,
                     args.caffe_proto, args.params_merge)

    elif args.framework == "onnx":
        if args.model is None:
            parser.error(
                "--model should be defined while translating onnx model")
        onnx2paddle(args.model, args.save_dir, args.params_merge)

    elif args.framework == "pytorch":
        if args.model is None:
            parser.error(
                "--model should be defined while translating pytorch model")
        pytorch2paddle(args.model, args.save_dir, args.input_shapes)

    else:  # "paddle2onnx" -- the only remaining supported value
        if args.model is None:
            parser.error(
                "--model should be defined while translating paddle model to onnx"
            )
        paddle2onnx(args.model, args.save_dir, args.onnx_opset)
302 303 304 305


# Run the command-line entry point when this file is executed as a script.
if __name__ == "__main__":
    main()