# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from six import text_type as _text_type
import argparse
J
jiangjiajun 已提交
17
import sys
18

J
jiangjiajun 已提交
19

20 21
def arg_parser():
    """Build the command-line parser for the x2paddle converter.

    All options are optional at the argparse level; the caller is expected
    to validate the combination required by the selected framework.

    Returns:
        argparse.ArgumentParser: parser exposing --model, --prototxt,
        --weight, --save_dir, --framework, --caffe_proto, --version,
        --without_data_format_optimization, --define_input_shape,
        --params_merge and --input_shapes.
    """
    parser = argparse.ArgumentParser()

    # Source-model locations and output directory.
    parser.add_argument(
        "--model",
        "-m",
        type=_text_type,
        default=None,
        help="define model file path for tensorflow or onnx")
    parser.add_argument(
        "--prototxt",
        "-p",
        type=_text_type,
        default=None,
        help="prototxt file of caffe model")
    parser.add_argument(
        "--weight",
        "-w",
        type=_text_type,
        default=None,
        help="weight file of caffe model")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=_text_type,
        default=None,
        help="path to save translated model")

    # Which source framework the model comes from.
    parser.add_argument(
        "--framework",
        "-f",
        type=_text_type,
        default=None,
        help="define which deeplearning framework(tensorflow/caffe/onnx/paddle2onnx)"
    )
    parser.add_argument(
        "--caffe_proto",
        "-c",
        type=_text_type,
        default=None,
        help="optional: the .py file compiled by caffe proto file of caffe model"
    )
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of x2paddle")

    # NOTE: this flag is a *string* ("True"/"False"), not a boolean switch;
    # the caller compares it against those two literals.
    parser.add_argument(
        "--without_data_format_optimization",
        "-wo",
        type=_text_type,
        default="True",
        help="tf model conversion without data format optimization")
    parser.add_argument(
        "--define_input_shape",
        "-d",
        action="store_true",
        default=False,
        help="define input shape for tf model")
    parser.add_argument(
        "--params_merge",
        "-pm",
        action="store_true",
        default=False,
        help="define whether merge the params")

    # May be given several times; each occurrence is appended to a list
    # (None when the option is never given).
    parser.add_argument(
        "--input_shapes",
        "-is",
        action='append',
        default=None,
        help="define the inputs' shape")
    return parser
J
jiangjiajun 已提交
91

C
Channingss 已提交
92

93 94
def tf2paddle(model_path,
              save_dir,
              without_data_format_optimization=False,
              define_input_shape=False,
              params_merge=False):
    """Translate a TensorFlow model into Paddle code/weights under save_dir.

    Prints an "[ERROR]" message and returns None without converting when
    TensorFlow is missing or its major version is not 1.

    Args:
        model_path: path of the TensorFlow model, handed to TFDecoder.
        save_dir: output directory, handed to program.gen_model.
        without_data_format_optimization: accepted for interface
            compatibility; not read by this function.
        define_input_shape: forwarded to TFDecoder.
        params_merge: accepted for interface compatibility; not read by
            this function.
    """
    # check tensorflow installation and version
    try:
        import os
        # Set before the import; presumably silences TensorFlow's native logs.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
        import tensorflow as tf
    except ImportError:
        # Only a failed import means "not installed"; other errors propagate
        # instead of being mislabelled.
        print(
            "[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
        )
        return

    # Compare the numeric major version: comparing version strings is
    # lexicographic and mis-orders multi-digit components.
    try:
        major = int(tf.__version__.split('.')[0])
    except ValueError:
        major = None
    if major != 1:
        print(
            "[ERROR] 1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
        )
        return

    from x2paddle import program
    from x2paddle.decoder.tf_decoder import TFDecoder
    from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
    from x2paddle.optimizer.tensorflow.bias import BiasOpt
    from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
    from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
    from x2paddle.optimizer.tensorflow.prelu import PReLUOpt

    print("Now translating model from tensorflow to paddle.")
    model = TFDecoder(model_path, define_input_shape=define_input_shape)
    # The mapper object is not used afterwards; it is constructed for its
    # effect on `program` (presumably registers the mapped ops -- verify).
    mapper = TFOpMapper(model)
    program.build()

    bias_opt = BiasOpt()
    transpose_opt = TransposeOpt()
    batch_norm_opt = BatchNormOpt()
    prelu_opt = PReLUOpt()
    # Pass order is kept exactly as in the original: bias, batch norm,
    # prelu, then transpose.
    bias_opt.run(program)
    batch_norm_opt.run(program)
    prelu_opt.run(program)
    transpose_opt.run(program)
    program.gen_model(save_dir)
135 136


M
mamingjie-China 已提交
137
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
    """Translate a Caffe model into a Paddle inference model under save_dir.

    Args:
        proto: Caffe prototxt file, handed to CaffeDecoder.
        weight: Caffe weight file, handed to CaffeDecoder.
        save_dir: output directory, handed to save_inference_model.
        caffe_proto: optional compiled caffe proto .py file, handed to
            CaffeDecoder.
        params_merge: forwarded to save_inference_model.

    Raises:
        AssertionError: if the installed google.protobuf is older than 3.6.
    """
    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
    from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
    import google.protobuf as gpb

    # Numeric (major, minor) comparison; equivalent to
    # "major == 3 and minor >= 6, or major > 3".
    ver_part = gpb.__version__.split('.')
    version_satisfy = (int(ver_part[0]), int(ver_part[1])) >= (3, 6)
    assert version_satisfy, '[ERROR] google.protobuf >= 3.6.0 is required'

    print("Now translating model from caffe to paddle.")
    model = CaffeDecoder(proto, weight, caffe_proto)
    mapper = CaffeOpMapper(model)
    optimizer = CaffeOptimizer(mapper)
    optimizer.merge_bn_scale()
    optimizer.merge_op_activation()
    mapper.save_inference_model(save_dir, params_merge)
155 156


M
mamingjie-China 已提交
157
def onnx2paddle(model_path, save_dir, params_merge=False):
    """Translate an ONNX model into a Paddle inference model under save_dir.

    Prints an "[ERROR]" message and returns None without converting when
    onnx is missing or older than 1.6.

    Args:
        model_path: path of the ONNX model, handed to ONNXDecoder.
        save_dir: output directory, handed to save_inference_model.
        params_merge: forwarded to save_inference_model.
    """
    # check onnx installation and version
    try:
        import onnx
    except ImportError:
        print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
        return

    # Compare numerically: as strings, "1.10.0" sorts before "1.6.0", which
    # would wrongly reject every onnx release from 1.10 on.
    try:
        major, minor = (int(part)
                        for part in onnx.version.version.split('.')[:2])
    except ValueError:
        # Unparseable version string: treat it as not meeting the requirement.
        major, minor = 0, 0
    if (major, minor) < (1, 6):
        print("[ERROR] onnx>=1.6.0 is required")
        return
    print("Now translating model from onnx to paddle.")

    from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
    from x2paddle.decoder.onnx_decoder import ONNXDecoder
    from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
    model = ONNXDecoder(model_path)
    mapper = ONNXOpMapper(model)

    print("Model optimizing ...")
    optimizer = ONNXOptimizer(mapper)
    optimizer.delete_redundance_code()
    print("Model optimized.")

    print("Paddle model and code generating ...")
    mapper.save_inference_model(save_dir, params_merge)
    print("Paddle model and code generated.")
C
Channingss 已提交
183 184


S
SunAhong1993 已提交
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
def pytorch2paddle(model_path, save_dir, input_shapes):
    """Translate a PyTorch model into a Paddle model under save_dir.

    Prints an "[ERROR]" message and returns None without converting when
    torch is missing or older than 1.5.

    Args:
        model_path: path of the PyTorch model, handed to PyTorchDecoder.
        save_dir: output directory, handed to graph.gen_model.
        input_shapes: None, or an iterable of bracketed, comma-separated
            shape strings such as "[1,3,224,224]"; each is converted to a
            list of ints before being handed to graph.gen_model.
    """
    # check pytorch installation and version
    try:
        import torch
    except ImportError:
        print(
            "[ERROR] Pytorch is not installed, use \"pip install torch==1.5.0 torchvision\"."
        )
        return

    # Compare (major, minor) together: checking the minor component alone
    # rejects 2.0-2.4 and accepts 0.5+.
    try:
        major, minor = (int(part)
                        for part in torch.__version__.split('.')[:2])
    except ValueError:
        # Unparseable version string: treat it as not meeting the requirement.
        major, minor = 0, 0
    if (major, minor) < (1, 5):
        print("[ERROR] pytorch>=1.5.0 is required")
        return
    print("Now translating model from pytorch to paddle.")

    from x2paddle.decoder.pytorch_decoder import PyTorchDecoder
    from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper
    model = PyTorchDecoder(model_path)
    mapper = pytorch_op_mapper.PyTorchOpMapper(model)
    mapper.graph.build()

    print("Model optimizing ...")
    from x2paddle.optimizer.pytorch_optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer()
    graph_opt.optimize(mapper.graph)
    print("Model optimized.")

    if input_shapes is not None:
        # shape[1:-1] drops the surrounding bracket characters.
        real_input_shapes = [[int(dim) for dim in shape[1:-1].split(",")]
                             for shape in input_shapes]
    else:
        real_input_shapes = None
    mapper.graph.gen_model(save_dir, real_input_shapes)


224
def main():
    """Command-line entry point: parse arguments and run the conversion.

    With no arguments, prints usage hints and returns. With --version,
    prints the x2paddle version and returns. Otherwise checks the
    paddlepaddle installation and dispatches on --framework.

    Raises:
        AssertionError: if --framework, --save_dir or a framework-specific
            required option is missing or invalid.
        Exception: if --framework names an unsupported framework.
    """
    if len(sys.argv) < 2:
        print("Use \"x2paddle -h\" to print the help information")
        print("For more information, please follow our github repo below:)")
        print("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
        return

    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        import x2paddle
        print("x2paddle-{} with python>=3.5, paddlepaddle>=1.6.0\n".format(
            x2paddle.__version__))
        return

    assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
    assert args.save_dir is not None, "--save_dir is not defined"

    try:
        import paddle
    except ImportError:
        print(
            "[ERROR] paddlepaddle not installed, use \"pip install paddlepaddle\""
        )
        # Stop here, as the version-error branch below does, instead of
        # falling through to a conversion that cannot succeed.
        return

    print("paddle.__version__ = {}".format(paddle.__version__))
    ver_part = paddle.__version__.split('.')
    if ver_part[:3] == ['0', '0', '0']:
        print("[WARNING] You are use develop version of paddlepaddle")
    else:
        # Only major/minor matter, so extra components (e.g. a 4-part
        # version) no longer break the check.
        try:
            major, minor = (int(part) for part in ver_part[:2])
        except ValueError:
            major, minor = None, None
        if major != 1 or minor < 6:
            print("[ERROR] paddlepaddle>=1.6.0 is required")
            return

    # NOTE(review): pytorch2paddle and --input_shapes are defined in this file
    # but no branch below dispatches to them -- confirm whether intended.
    if args.framework == "tensorflow":
        assert args.model is not None, "--model should be defined while translating tensorflow model"
        assert args.without_data_format_optimization in [
            "True", "False"
        ], "--the param without_data_format_optimization should be defined True or False"
        # The option is a string flag; turn it into a real boolean.
        without_data_format_optimization = (
            args.without_data_format_optimization == "True")
        tf2paddle(args.model, args.save_dir, without_data_format_optimization,
                  args.define_input_shape, args.params_merge)

    elif args.framework == "caffe":
        assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
        caffe2paddle(args.prototxt, args.weight, args.save_dir,
                     args.caffe_proto, args.params_merge)

    elif args.framework == "onnx":
        assert args.model is not None, "--model should be defined while translating onnx model"
        onnx2paddle(args.model, args.save_dir, args.params_merge)

    elif args.framework == "paddle2onnx":
        print("Paddle to ONNX tool has been migrated to the new github: https://github.com/PaddlePaddle/paddle2onnx")

    else:
        raise Exception(
            "--framework only support tensorflow/caffe/onnx/ now")
293 294 295


# Run the converter only when this file is executed as a script.
if __name__ == "__main__":
    main()