convert.py 7.7 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14

15 16
from six import text_type as _text_type
import argparse
J
jiangjiajun 已提交
17
import sys
18

J
jiangjiajun 已提交
19

20 21 22 23 24 25
def arg_parser():
    """Build the argument parser for the x2paddle command line.

    String-valued options are registered first (each defaults to None),
    followed by the on/off switches (each defaults to False).

    Returns:
        argparse.ArgumentParser: the configured, not-yet-run parser.
    """
    cli = argparse.ArgumentParser()

    # (long flag, short flag, help text) for options that take a string value.
    valued_options = (
        ("--model", "-m", "define model file path for tensorflow or onnx"),
        ("--prototxt", "-p", "prototxt file of caffe model"),
        ("--weight", "-w", "weight file of caffe model"),
        ("--save_dir", "-s", "path to save translated model"),
        ("--framework", "-f",
         "define which deeplearning framework(tensorflow/caffe/onnx)"),
        ("--caffe_proto", "-c",
         "optional: the .py file compiled by caffe proto file of caffe model"),
    )
    for long_name, short_name, description in valued_options:
        cli.add_argument(long_name,
                         short_name,
                         type=_text_type,
                         default=None,
                         help=description)

    # (long flag, short flag, help text) for boolean switches.
    switches = (
        ("--version", "-v", "get version of x2paddle"),
        ("--without_data_format_optimization", "-wo",
         "tf model conversion without data format optimization"),
        ("--define_input_shape", "-d", "define input shape for tf model"),
    )
    for long_name, short_name, description in switches:
        cli.add_argument(long_name,
                         short_name,
                         action="store_true",
                         default=False,
                         help=description)

    return cli
J
jiangjiajun 已提交
73

74

75 76 77 78
def tf2paddle(model_path,
              save_dir,
              without_data_format_optimization=False,
              define_input_shape=False):
J
jiangjiajun 已提交
79 80
    # check tensorflow installation and version
    try:
81 82
        import os
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
J
jiangjiajun 已提交
83 84 85 86 87 88 89 90
        import tensorflow as tf
        version = tf.__version__
        if version >= '2.0.0' or version < '1.0.0':
            print(
                "1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
            )
            return
    except:
S
SunAhong1993 已提交
91
        print("Tensorflow is not installed, use \"pip install tensorflow\".")
J
jiangjiajun 已提交
92 93
        return

J
jiangjiajun 已提交
94
    from x2paddle.decoder.tf_decoder import TFDecoder
J
jiangjiajun 已提交
95
    from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
96
    from x2paddle.op_mapper.tf_op_mapper_nhwc import TFOpMapperNHWC
J
jiangjiajun 已提交
97
    from x2paddle.optimizer.tf_optimizer import TFOptimizer
J
jiangjiajun 已提交
98 99

    print("Now translating model from tensorflow to paddle.")
100
    model = TFDecoder(model_path, define_input_shape=define_input_shape)
J
jiangjiajun 已提交
101 102 103 104 105 106
    mapper = TFOpMapperNHWC(model)
    optimizer = TFOptimizer(mapper)
    optimizer.delete_redundance_code()
    optimizer.strip_graph()
    #        optimizer.merge_activation()
    #        optimizer.merge_bias()
J
jiangjiajun 已提交
107
    mapper.save_inference_model(save_dir)
108 109


S
SunAhong1993 已提交
110
def caffe2paddle(proto, weight, save_dir, caffe_proto):
    """Translate a Caffe model into a Paddle inference model.

    Args:
        proto: path of the Caffe prototxt network definition.
        weight: path of the Caffe weight file.
        save_dir: directory the translated Paddle model is saved to.
        caffe_proto: optional path of the .py file compiled from the caffe
            proto file; forwarded to CaffeDecoder unchanged.

    Raises:
        AssertionError: if the installed google.protobuf is older than 3.6.0.
    """
    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
    from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
    import google.protobuf as gpb

    # protobuf >= 3.6.0 is required. The minor component is only parsed
    # when the major version is exactly 3.
    pb_parts = gpb.__version__.split('.')
    pb_major = int(pb_parts[0])
    pb_ok = pb_major > 3 or (pb_major == 3 and int(pb_parts[1]) >= 6)
    assert pb_ok, 'google.protobuf >= 3.6.0 is required'

    print("Now translating model from caffe to paddle.")
    mapper = CaffeOpMapper(CaffeDecoder(proto, weight, caffe_proto))

    # Graph-level optimizer passes applied before saving.
    caffe_optimizer = CaffeOptimizer(mapper)
    caffe_optimizer.merge_bn_scale()
    caffe_optimizer.merge_op_activation()

    mapper.save_inference_model(save_dir)
128 129


C
update  
channingss 已提交
130 131 132 133 134 135 136 137 138 139 140
def onnx2paddle(model_path, save_dir):
    """Translate an ONNX model into a Paddle inference model.

    Args:
        model_path: path of the ONNX model file, passed to ONNXDecoder.
        save_dir: directory the translated Paddle model is saved to; it is
            also handed to ONNXOpMapper.

    Returns:
        None. Prints a message and returns early when onnx is not installed
        or its version is not exactly 1.5.0.
    """
    # check onnx installation and version. Only a missing package is
    # reported as "not installed"; other failures propagate with their real
    # traceback instead of being hidden behind a misleading hint.
    try:
        import onnx
    except ImportError:
        print("onnx is not installed, use \"pip install onnx==1.5.0\".")
        return
    if onnx.version.version != '1.5.0':
        print("onnx==1.5.0 is required")
        return

    print("Now translating model from onnx to paddle.")

    from x2paddle.decoder.onnx_decoder import ONNXDecoder
    model = ONNXDecoder(model_path)

    from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
    mapper = ONNXOpMapper(model, save_dir)

    from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
    optimizer = ONNXOptimizer(mapper)

    optimizer.delete_redundance_code()
    mapper.save_inference_model(save_dir)


156
def main():
    """Command-line entry point of x2paddle.

    With no arguments, prints usage hints and returns. Otherwise parses the
    command line, prints the version when ``--version`` is given, checks the
    installed paddlepaddle version, validates the framework-specific
    arguments, and dispatches to tf2paddle, caffe2paddle or onnx2paddle.

    Missing or invalid arguments are reported through ``parser.error``
    (usage message, exit status 2) rather than ``assert``, because
    assertions are stripped when Python runs with ``-O``.
    """
    if len(sys.argv) < 2:
        print("Use \"x2paddle -h\" to print the help information")
        print("For more information, please follow our github repo below:)")
        print("\nGithub: https://github.com/PaddlePaddle/X2Paddle.git\n")
        return

    parser = arg_parser()
    args = parser.parse_args()

    if args.version:
        import x2paddle
        print("x2paddle-{} with python>=3.5, paddlepaddle>=1.6.1\n".format(
            x2paddle.__version__))
        return

    # Only a missing package is reported as "not installed"; version parsing
    # below is kept outside the try so it can no longer be misreported.
    try:
        import paddle
    except ImportError:
        print("paddlepaddle not installed, use \"pip install paddlepaddle\"")
        return

    import re
    # Keep only the leading digits of each dotted component so builds such as
    # "1.8.0.post97" or "2.0.0rc0" parse; missing components count as 0.
    numbers = []
    for part in paddle.__version__.split('.')[:3]:
        match = re.match(r'\d+', part)
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (3 - len(numbers))
    v0, v1, v2 = numbers

    if v0 == 0 and v1 == 0 and v2 == 0:
        print(
            "You have installed paddlepaddle-dev? We're not sure it's working for x2paddle!"
        )
        print(
            "==================paddlepaddle>=1.6.1 is strongly recommended================="
        )
    elif v0 != 1 or (v1, v2) < (6, 1):
        # Enforce the advertised minimum of 1.6.1 (1.6.0 was accepted before).
        print("paddlepaddle>=1.6.1 is required")
        return

    if args.framework is None:
        parser.error(
            "--framework is not defined(support tensorflow/caffe/onnx)")
    if args.save_dir is None:
        parser.error("--save_dir is not defined")

    if args.framework == "tensorflow":
        if args.model is None:
            parser.error(
                "--model should be defined while translating tensorflow model")
        # Both switches are store_true options, so they are already bools.
        tf2paddle(args.model, args.save_dir,
                  args.without_data_format_optimization,
                  args.define_input_shape)

    elif args.framework == "caffe":
        if args.prototxt is None or args.weight is None:
            parser.error(
                "--prototxt and --weight should be defined while translating caffe model"
            )
        caffe2paddle(args.prototxt, args.weight, args.save_dir,
                     args.caffe_proto)
    elif args.framework == "onnx":
        if args.model is None:
            parser.error(
                "--model should be defined while translating onnx model")
        onnx2paddle(args.model, args.save_dir)

    else:
        parser.error("--framework only support tensorflow/caffe/onnx now")
212 213 214 215


if __name__ == '__main__':
    # Run the command-line interface when this file is executed as a script.
    main()