convert.py 2.8 KB
Newer Older
R
Renwb1991 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
#!/usr/bin/env python

import os
import sys
import numpy as np
import argparse

from kaffe import KaffeError, print_stderr
from kaffe.paddle import Transformer


def fatal_error(msg):
    """Report an unrecoverable error on stderr and terminate the process.

    Args:
        msg: message to print via ``print_stderr``.

    Raises:
        SystemExit: always, with status -1 (seen as 255 by POSIX shells).
    """
    print_stderr(msg)
    # Use sys.exit rather than the bare exit() builtin: exit() is injected by
    # the site module for interactive use and is missing under ``python -S``
    # or in frozen executables.
    sys.exit(-1)


def validate_arguments(args):
    """Check that the parsed command-line arguments form a usable request.

    Terminates the process through ``fatal_error`` when:
      * no model definition path is given (``convert`` always hands it to
        the Transformer, so nothing can be converted without it);
      * a data output path is given without an input ``.caffemodel``;
      * an input ``.caffemodel`` is given without a data output path;
      * neither a code nor a data output path is given.

    Args:
        args: namespace produced by ``main``'s argument parser; the
            attributes ``def_path``, ``caffemodel``, ``data_output_path``
            and ``code_output_path`` are read.
    """
    if args.def_path is None:
        fatal_error('No model definition (.prototxt) path provided.')
    if (args.data_output_path is not None) and (args.caffemodel is None):
        fatal_error('No input data path provided.')
    if (args.caffemodel is not None) and (args.data_output_path is None):
        fatal_error('No output data path provided.')
    if (args.code_output_path is None) and (args.data_output_path is None):
        fatal_error('No output path specified.')


def convert(def_path, caffemodel_path, data_output_path, code_output_path,
            phase):
    """Convert a Caffe model into Paddle weights and/or network source code.

    Args:
        def_path: path to the Caffe model definition (.prototxt).
        caffemodel_path: path to the trained weights (.caffemodel), or None
            to skip weight conversion.
        data_output_path: file the converted weights are written to with
            ``np.save``; only used when ``caffemodel_path`` is not None.
        code_output_path: file the generated source is written to; source
            generation is skipped when this is falsy.
        phase: phase forwarded to ``Transformer`` (the CLI offers 'test' or
            'train').

    Returns:
        0 on success. A ``KaffeError`` terminates the process via
        ``fatal_error``; the trailing ``return 1`` is only a fallback.
    """
    try:
        transformer = Transformer(def_path, caffemodel_path, phase=phase)
        if caffemodel_path is not None:
            print_stderr('Converting data...')
            data = transformer.transform_data()
            print_stderr('Saving data...')
            # NOTE(review): if ``data`` is a dict, np.save stores it as a
            # pickled object array and readers need allow_pickle=True.
            with open(data_output_path, 'wb') as data_out:
                np.save(data_out, data)
        if code_output_path:
            print_stderr('Saving source...')
            # Generate before opening the file so a KaffeError here does not
            # leave an empty/truncated output file behind.
            source = transformer.transform_source()
            # The file is binary, so text must be encoded. On Python 2 a
            # plain str is already bytes and is written unchanged.
            if not isinstance(source, bytes):
                source = source.encode('utf-8')
            with open(code_output_path, 'wb') as src_out:
                src_out.write(source)
        print_stderr('set env variable before using converted model '
                     'if used custom_layers:')
        custom_pk_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'kaffe')
        print_stderr('export CAFFE2FLUID_CUSTOM_LAYERS=%s' % (custom_pk_path))
        print_stderr('Done.')
        return 0
    except KaffeError as err:
        fatal_error('Error encountered: {}'.format(err))

    return 1


def main():
    """Parse command-line arguments, validate them and run the conversion.

    Returns:
        The status returned by ``convert`` (0 on success). The
        ``__main__`` guard uses it as the process exit status.
    """
    parser = argparse.ArgumentParser()
    # '--def-path' matches the other hyphenated flags; '--def_path' is kept
    # so existing command lines keep working.
    parser.add_argument(
        '--def-path',
        '--def_path',
        dest='def_path',
        help='Model definition (.prototxt) path')
    parser.add_argument('--caffemodel', help='Model data (.caffemodel) path')
    parser.add_argument('--data-output-path', help='Converted data output path')
    parser.add_argument(
        '--code-output-path', help='Save generated source to this path')
    parser.add_argument(
        '-p',
        '--phase',
        default='test',
        help='The phase to convert: test (default) or train')
    args = parser.parse_args()
    validate_arguments(args)
    return convert(args.def_path, args.caffemodel, args.data_output_path,
                   args.code_output_path, args.phase)


if __name__ == '__main__':
    # Run the CLI and propagate convert()'s status as the exit code.
    sys.exit(main())