convert.py 9.4 KB
Newer Older
S
SunAhong1993 已提交
1
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
J
jiangjiajun 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14

15
from six import text_type as _text_type
S
SunAhong1993 已提交
16
from x2paddle import program
17
import argparse
J
jiangjiajun 已提交
18
import sys
19

J
jiangjiajun 已提交
20

21 22
def arg_parser():
    """Build the command-line argument parser for the x2paddle tool.

    Returns:
        argparse.ArgumentParser: parser exposing the source model paths,
        the output directory, the source framework and the target Paddle
        model type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        "-m",
        type=_text_type,
        default=None,
        help="define model file path for tensorflow or onnx")
    parser.add_argument(
        "--prototxt",
        "-p",
        type=_text_type,
        default=None,
        help="prototxt file of caffe model")
    parser.add_argument(
        "--weight",
        "-w",
        type=_text_type,
        default=None,
        help="weight file of caffe model")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=_text_type,
        default=None,
        help="path to save translated model")
    parser.add_argument(
        "--framework",
        "-f",
        # "paddle2onnx" is accepted only so that main() can tell users the
        # tool moved to its own repository; without it in `choices`,
        # argparse rejects the value and that branch of main() is dead.
        choices=['tensorflow', 'caffe', 'onnx', 'paddle2onnx'],
        help="define which deeplearning framework(tensorflow/caffe/onnx)")
    parser.add_argument(
        "--caffe_proto",
        "-c",
        type=_text_type,
        default=None,
        help="optional: the .py file compiled by caffe proto file of caffe model"
    )
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of x2paddle")
    parser.add_argument(
        "--define_input_shape",
        "-d",
        action="store_true",
        default=False,
        help="define input shape for tf model")
    parser.add_argument(
        "--paddle_type",
        "-pt",
        choices=['dygraph', 'static'],
        default="dygraph",
        help="define the paddle model type after converting(dygraph/static)")

    return parser
J
jiangjiajun 已提交
79

C
Channingss 已提交
80

81 82
def tf2paddle(model_path,
              save_dir,
              define_input_shape=False,
              paddle_type="dygraph"):
    """Convert a TensorFlow model into a PaddlePaddle model.

    Args:
        model_path: path of the TensorFlow model file to convert.
        save_dir: directory passed to ``gen_model`` for the generated model.
        define_input_shape: forwarded to ``TFDecoder``.
        paddle_type: "dygraph" selects the dygraph op mapper; any other
            value selects the static one. Also forwarded to the optimizer.

    Returns:
        None. If TensorFlow is missing or not a 1.x release, an error is
        printed and the function returns without converting.
    """
    import os

    # Silence TensorFlow's C++ logging before tensorflow is first imported.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'

    # check tensorflow installation and version
    try:
        import tensorflow as tf
    except ImportError:
        print(
            "[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
        )
        return
    # Compare the major version numerically: plain string comparison
    # mis-orders multi-digit components (e.g. "10.0" < "2.0.0").
    major = int(tf.__version__.split('.')[0])
    if major != 1:
        print(
            "[ERROR] 1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
        )
        return

    from x2paddle.decoder.tf_decoder import TFDecoder
    if paddle_type == "dygraph":
        from x2paddle.op_mapper.dygraph.tf2paddle.tf_op_mapper import TFOpMapper
    else:
        from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper

    print("Now translating model from tensorflow to paddle.")
    model = TFDecoder(model_path, define_input_shape=define_input_shape)
    mapper = TFOpMapper(model)
    mapper.paddle_graph.build()
    # Both paddle types use the same optimizer entry point; only the
    # paddle_type argument differs.
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(source_frame="tf", paddle_type=paddle_type)
    graph_opt.optimize(mapper.paddle_graph)
    mapper.paddle_graph.gen_model(save_dir)
121 122


S
SunAhong1993 已提交
123
def caffe2paddle(proto, weight, save_dir, caffe_proto, paddle_type):
    """Convert a Caffe model into a PaddlePaddle model.

    Args:
        proto: path of the Caffe ``.prototxt`` file.
        weight: path of the Caffe weight file.
        save_dir: directory passed to ``gen_model`` for the generated model.
        caffe_proto: optional .py file compiled from caffe's proto file
            (the ``--caffe_proto`` CLI option); may be None.
        paddle_type: "dygraph" selects the dygraph op mapper; any other
            value selects the static one. Also forwarded to the optimizer.

    Returns:
        None. If google.protobuf is missing or older than 3.6, an error is
        printed and the function returns without converting.
    """
    # Check protobuf before any x2paddle module is imported, so the
    # requirement is reported explicitly and consistently with the other
    # converters. An ``assert`` would be stripped under ``python -O``.
    try:
        import google.protobuf as gpb
    except ImportError:
        print(
            "[ERROR] google.protobuf is not installed, use \"pip install protobuf\"."
        )
        return
    major, minor = (int(part) for part in gpb.__version__.split('.')[:2])
    if (major, minor) < (3, 6):
        print('[ERROR] google.protobuf >= 3.6.0 is required')
        return

    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    if paddle_type == "dygraph":
        from x2paddle.op_mapper.dygraph.caffe2paddle.caffe_op_mapper import CaffeOpMapper
    else:
        from x2paddle.op_mapper.static.caffe2paddle.caffe_op_mapper import CaffeOpMapper

    print("Now translating model from caffe to paddle.")
    model = CaffeDecoder(proto, weight, caffe_proto)
    mapper = CaffeOpMapper(model)
    mapper.paddle_graph.build()
    print("Model optimizing ...")
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(source_frame="caffe", paddle_type=paddle_type)
    graph_opt.optimize(mapper.paddle_graph)
    print("Model optimized.")
    mapper.paddle_graph.gen_model(save_dir)
146 147


S
SunAhong1993 已提交
148
def onnx2paddle(model_path, save_dir, paddle_type):
    """Convert an ONNX model into a PaddlePaddle model.

    Args:
        model_path: path of the ONNX model file to convert.
        save_dir: directory passed to ``gen_model`` for the generated model.
        paddle_type: "dygraph" selects the dygraph op mapper; any other
            value selects the static one.

    Returns:
        None. If onnx is missing or older than 1.6, an error is printed and
        the function returns without converting.
    """
    # check onnx installation and version
    try:
        import onnx
    except ImportError:
        print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
        return
    # Compare numerically: as strings, "1.10.0" < "1.6.0" is True, which
    # wrongly rejected every onnx release from 1.10 onwards.
    major, minor = (int(part) for part in onnx.version.version.split('.')[:2])
    if (major, minor) < (1, 6):
        print("[ERROR] onnx>=1.6.0 is required")
        return
    print("Now translating model from onnx to paddle.")

    from x2paddle.decoder.onnx_decoder import ONNXDecoder
    if paddle_type == "dygraph":
        from x2paddle.op_mapper.dygraph.onnx2paddle.onnx_op_mapper import ONNXOpMapper
    else:
        from x2paddle.op_mapper.static.onnx2paddle.onnx_op_mapper import ONNXOpMapper
    model = ONNXDecoder(model_path)
    mapper = ONNXOpMapper(model)
    mapper.paddle_graph.build()
    mapper.paddle_graph.gen_model(save_dir)
C
Channingss 已提交
170 171


S
SunAhong1993 已提交
172
def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None):
    """Convert a PyTorch module into a PaddlePaddle dygraph model.

    Args:
        module: the PyTorch module to convert, handed to the decoder.
        save_dir: directory passed to ``gen_model`` for the generated model.
        jit_type: "trace" uses ``TraceDecoder``; any other value uses
            ``ScriptDecoder``. Also forwarded to the optimizer and
            ``gen_model``.
        input_examples: example inputs handed to the decoder.

    Returns:
        None. If torch is missing or older than 1.5, an error is printed
        and the function returns without converting.
    """
    # check pytorch installation and version
    try:
        import torch
    except ImportError:
        print(
            "[ERROR] Pytorch is not installed, use \"pip install torch==1.5.0 torchvision\"."
        )
        return
    # Compare (major, minor) together: checking only the minor number
    # wrongly rejected torch 2.x releases such as 2.1.
    major, minor = (int(part) for part in torch.__version__.split('.')[:2])
    if (major, minor) < (1, 5):
        print("[ERROR] pytorch>=1.5.0 is required")
        return
    print("Now translating model from pytorch to paddle.")

    from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
    from x2paddle.op_mapper.dygraph.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper

    decoder_cls = TraceDecoder if jit_type == "trace" else ScriptDecoder
    model = decoder_cls(module, input_examples)
    mapper = PyTorchOpMapper(model)
    mapper.paddle_graph.build()
    print("Model optimizing ...")
    from x2paddle.optimizer.optimizer import GraphOptimizer
    graph_opt = GraphOptimizer(
        source_frame="pytorch", paddle_type="dygraph", jit_type=jit_type)
    graph_opt.optimize(mapper.paddle_graph)
    print("Model optimized.")
    mapper.paddle_graph.gen_model(save_dir, jit_type=jit_type)
S
SunAhong1993 已提交
205 206


207
def main():
    """Command-line entry point of x2paddle.

    Parses ``sys.argv``, checks the Python and paddlepaddle versions and
    dispatches to the converter selected by ``--framework``. When called
    with no arguments it only prints usage hints and returns.
    """
    if len(sys.argv) < 2:
        print("Use \"x2paddle -h\" to print the help information")
        print("For more information, please follow our github repo below:)")
        print("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
        return

    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        import x2paddle
        # Keep the advertised requirement in sync with the check below.
        print("x2paddle-{} with python>=3.5, paddlepaddle>=2.0.0\n".format(
            x2paddle.__version__))
        return

    # parser.error (not assert) so validation survives `python -O`.
    if args.save_dir is None:
        parser.error("--save_dir is not defined")

    if sys.version_info[:2] < (3, 5):
        print("[ERROR] python>=3.5 is required")
        return

    try:
        import paddle
    except ImportError:
        # NOTE(review): conversion continues without paddle, as it did
        # before; confirm whether this should return instead.
        print(
            "[ERROR] paddlepaddle not installed, use \"pip install paddlepaddle\""
        )
    else:
        print("paddle.__version__ = {}".format(paddle.__version__))
        ver_parts = paddle.__version__.split('.')
        if ver_parts[:3] == ['0', '0', '0']:
            print("[WARNING] You are use develop version of paddlepaddle")
        elif int(ver_parts[0]) < 2:
            # `< 2` rather than `!= 2`, so paddle 3.x is not rejected.
            print("[ERROR] paddlepaddle>=2.0.0 is required")
            return

    if args.framework == "tensorflow":
        if args.model is None:
            parser.error(
                "--model should be defined while translating tensorflow model")
        tf2paddle(args.model, args.save_dir, args.define_input_shape,
                  args.paddle_type)

    elif args.framework == "caffe":
        if args.prototxt is None or args.weight is None:
            parser.error(
                "--prototxt and --weight should be defined while translating caffe model"
            )
        caffe2paddle(args.prototxt, args.weight, args.save_dir,
                     args.caffe_proto, args.paddle_type)
    elif args.framework == "onnx":
        if args.model is None:
            parser.error(
                "--model should be defined while translating onnx model")
        onnx2paddle(args.model, args.save_dir, args.paddle_type)
    elif args.framework == "paddle2onnx":
        print(
            "Paddle to ONNX tool has been migrated to the new github: https://github.com/PaddlePaddle/paddle2onnx"
        )

    else:
        parser.error("--framework only support tensorflow/caffe/onnx now")


if __name__ == "__main__":
    main()