#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from six import text_type as _text_type
import argparse
def arg_parser():
    """Build the command-line parser for the model converter.

    Every option takes a single text value (``_text_type``) and defaults to
    ``None``. Which options are actually required depends on ``--framework``
    and is checked in ``main()``, not here.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet invoked).
    """
    # (long flag, short flag, help text). Registration order is preserved so
    # the generated --help output lists the options in this order.
    option_table = [
        ("--model", "-m", "model file path"),
        ("--prototxt", "-p", "prototxt file of caffe model"),
        ("--weight", "-w", "weight file of caffe model"),
        ("--save_dir", "-s", "path to save translated model"),
        ("--framework", "-f", "define which deeplearning framework"),
        ("--caffe_proto", "-c", "caffe proto file of caffe model"),
    ]

    cli = argparse.ArgumentParser()
    for long_flag, short_flag, help_text in option_table:
        cli.add_argument(long_flag,
                         short_flag,
                         type=_text_type,
                         default=None,
                         help=help_text)
    return cli

def tf2paddle(model_path, save_dir):
    """Translate a TensorFlow model into a PaddlePaddle model.

    The model at ``model_path`` is loaded with ``TFDecoder``. The decoded
    model is passed through ``TFOpMapper``, and the result is written out
    with the mapper's ``save_python_model(save_dir)``.

    Args:
        model_path: path of the TensorFlow model file handed to TFDecoder.
        save_dir: destination passed to ``save_python_model``.
    """
    # Imported here rather than at module level. Presumably this is so the
    # TensorFlow-side dependencies are only needed when this converter runs.
    from x2paddle.decoder.tf_decoder import TFDecoder
    from x2paddle.op_mapper.tf_op_mapper import TFOpMapper

    print("Now translating model from tensorflow to paddle.")
    op_mapper = TFOpMapper(TFDecoder(model_path))
    op_mapper.run()
    op_mapper.save_python_model(save_dir)


def caffe2paddle(proto, weight, save_dir, caffe_proto):
    """Translate a Caffe model into a PaddlePaddle model.

    ``CaffeDecoder`` loads the network description, weights and Caffe
    proto. ``CaffeOpMapper`` then converts the decoded model, and the result
    is written out with the mapper's ``save_python_model(save_dir)``.

    Args:
        proto: prototxt file of the Caffe model.
        weight: weight file of the Caffe model.
        save_dir: destination passed to ``save_python_model``.
        caffe_proto: caffe proto file forwarded to CaffeDecoder. It may be
            None when ``--caffe_proto`` is not given on the command line.
    """
    # Imported here rather than at module level. Presumably this is so the
    # Caffe-side dependencies are only needed when this converter runs.
    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper

    print("Now translating model from caffe to paddle.")
    decoder = CaffeDecoder(proto, weight, caffe_proto)
    converter = CaffeOpMapper(decoder)
    converter.run()
    converter.save_python_model(save_dir)


def main():
    """Command-line entry point: parse arguments and run one converter.

    Checks that the arguments required by the selected ``--framework`` are
    present, then calls ``tf2paddle`` or ``caffe2paddle``.

    Invalid or missing arguments are reported through ``parser.error``.
    That prints the usage text plus the message to stderr and exits with
    status 2.
    """
    parser = arg_parser()
    args = parser.parse_args()

    # Use parser.error instead of ``assert``. Assertions are stripped when
    # Python runs with -O, which would let None values reach the converters.
    if args.framework is None:
        parser.error("--framework is not defined (tensorflow/caffe)")
    if args.save_dir is None:
        parser.error("--save_dir is not defined")

    if args.framework == "tensorflow":
        if args.model is None:
            parser.error("--model should be defined while translating "
                         "tensorflow model")
        tf2paddle(args.model, args.save_dir)

    elif args.framework == "caffe":
        if args.prototxt is None or args.weight is None:
            parser.error("--prototxt and --weight should be defined while "
                         "translating caffe model")
        caffe2paddle(args.prototxt, args.weight, args.save_dir,
                     args.caffe_proto)

    else:
        parser.error("--framework only supports tensorflow/caffe now")


# Run the converter CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()