torch2paddle.py 2.9 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
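"""
Convert a torch parameter file to paddle model files.

Note: the torchfile package must be installed to use this tool.

The layers file lists one layer name per line, in the same order as the
weight/bias pairs stored in the torch file.

Usage: python torch2paddle.py -i torchfile.t7 -l layers.txt -o path/to/paddle_model
"""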

import os
import sys
import struct
import numpy as np
import torchfile
import cPickle as pickle
import argparse

Q
qijun 已提交
30

Z
zhangjinchao01 已提交
31 32 33
# save parameters
def save_layer_parameters(outfile, feats):
    """
    Write a sequence of arrays into a single paddle parameter file.

    The file starts with a header packed with the native struct format
    'iIQ': format version (0), bytes per value (4) and the total number of
    values. The float32 data of every array in ``feats`` follows,
    concatenated in order (each array flattened in C order).

    :param outfile: path of the file to create (overwritten if it exists).
    :param feats: iterable of array-like values to store.
    """
    version = 0
    value_size = 4
    # The header declares 4-byte values, so force float32: a float64 tensor
    # would otherwise be written with twice as many bytes as the header
    # claims. b"".join avoids quadratic concatenation, and tobytes()
    # replaces the deprecated tostring().
    data = b"".join(
        np.asarray(feat, dtype=np.float32).tobytes() for feat in feats)
    # Integer division keeps `size` an int on Python 3 ('Q' rejects floats).
    size = len(data) // value_size
    with open(outfile, 'wb') as fo:
        fo.write(struct.pack('iIQ', version, value_size, size))
        fo.write(data)

Q
qijun 已提交
44

Z
zhangjinchao01 已提交
45 46
def save_net_parameters(layers, params, output_path):
    """
    Write one weight file and one bias file per layer into output_path.

    :param layers: layer names, in the same order as the parameter pairs.
    :param params: indexable sequence alternating weight, bias, weight,
                   bias, ...; entry 2*i is the weight and entry 2*i+1 the
                   bias of layers[i].
    :param output_path: directory that receives '_<layer>.w0' (weights)
                        and '_<layer>.wbias' (biases) for every layer.
    """
    for i, layer in enumerate(layers):
        weight = params[2 * i]
        biases = params[2 * i + 1]
        weight_file = os.path.join(output_path, '_%s.w0' % layer)
        biases_file = os.path.join(output_path, '_%s.wbias' % layer)
        # print() with a single argument behaves the same on Python 2 and 3.
        print("Saving for layer %s." % layer)
        save_layer_parameters(weight_file, [weight])
        # Wrap the biases in a list just like the weights, so the bias array
        # is serialized as one unit rather than iterated element by element.
        save_layer_parameters(biases_file, [biases])

Q
qijun 已提交
55

Z
zhangjinchao01 已提交
56 57 58 59 60 61 62 63 64
def load_layer_parameters(filename):
    """
    Read the values stored in a paddle parameter file.

    The header is unpacked with the same native 'iIQ' layout that
    save_layer_parameters writes: version, bytes per value, value count.

    :param filename: path of the parameter file.
    :return: 1-D numpy array holding every value after the header; float32
             when the header declares 4-byte values, float64 otherwise.
    """
    header_format = 'iIQ'
    with open(filename, 'rb') as fn:
        # 'Q' is 8 bytes everywhere; the previous native 'L' is only 4 bytes
        # on Windows and 32-bit builds and then failed to unpack 8 bytes.
        # version and param_size are read to skip the header but are not
        # validated against the data.
        version, value_length, param_size = struct.unpack(
            header_format, fn.read(struct.calcsize(header_format)))
        dtype = 'float32' if value_length == 4 else 'float64'
        value = np.fromfile(fn, dtype)
    return value

Q
qijun 已提交
65

Z
zhangjinchao01 已提交
66 67 68 69 70 71
def main(argv):
    """
    Command-line entry point: convert a torch parameter file to paddle files.

    :param argv: command-line arguments, excluding the program name.
    :return: None. When -i, -l or -o is missing, argparse prints the usage
             line and an error message, then exits with status 2.
    """
    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description` to avoid becoming the program name.
    cmdparser = argparse.ArgumentParser(
        description="Convert torch parameter file to paddle model files.")
    cmdparser.add_argument(
        '-i', '--input', help='input filename of torch parameters')
    cmdparser.add_argument('-l', '--layers', help='list of layer names')
    cmdparser.add_argument(
        '-o', '--output', help='output file path of paddle model')

    args = cmdparser.parse_args(argv)
    if not (args.input and args.layers and args.output):
        # Exits non-zero so calling scripts can detect the misuse.
        cmdparser.error(
            'the -i/--input, -l/--layers and -o/--output arguments '
            'are all required')

    params = torchfile.load(args.input)
    with open(args.layers, 'r') as layer_file:
        # Skip blank lines (e.g. a trailing empty line): they would become
        # bogus layer names and shift every following weight/bias pair.
        layers = [line.strip() for line in layer_file if line.strip()]
    save_net_parameters(layers, params, args.output)

Z
zhangjinchao01 已提交
90 91 92

if __name__ == "__main__":
    # Forward every command-line argument after the script name to main().
    main(argv=sys.argv[1:])