create_programs.py 2.9 KB
Newer Older
R
eol  
rensilin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#!/usr/bin/env python
#-*- coding:utf-8 -*-

from __future__ import print_function, division
import os
import sys
import paddle
from paddle import fluid
import yaml

def print_help(this_name):
    """Print the command-line usage of this script to stdout.

    Args:
        this_name: the script path as invoked (typically ``argv[0]``). Its
            directory is used to build the path of an ``example.py`` that is
            shown in the usage example.
    """
    example_path = os.path.join(os.path.dirname(this_name), 'example.py')
    usage_lines = (
        "Usage: {} <network building filename> [model_dir]\n".format(this_name),
        "    example: {} {}".format(this_name, example_path),
    )
    for line in usage_lines:
        print(line)

R
rensilin 已提交
18 19 20


def inference_warpper(filename):
    """Build the inference network (without loss and optimizer).

    Executes the Python file ``filename`` and calls, with no arguments, the
    ``inference`` function it defines.

    Args:
        filename: path of the file which defines the real ``inference``
            function.

    Returns:
        Whatever ``inference()`` returns; ``main`` unpacks it as a pair
        ``(inputs, outputs)``, each a ``list<Variable>``.

    Raises:
        KeyError: if the file does not define ``inference``.
    """
    # Read raw bytes and let compile() decode them: it honours a PEP 263
    # coding declaration (defaulting to UTF-8), whereas a text-mode open()
    # would decode with the locale's default encoding.
    with open(filename, 'rb') as f:
        code = f.read()
    compiled = compile(code, filename, 'exec')

    # NOTE: this executes arbitrary code from ``filename``; only pass trusted
    # files. ``__file__`` is provided so the network file can locate
    # resources relative to itself, as it could when imported normally.
    scope = {'__file__': filename}
    exec(compiled, scope)
    if 'inference' not in scope:
        raise KeyError("'inference' is not defined in {}".format(filename))
    return scope['inference']()
R
eol  
rensilin 已提交
37 38 39 40 41 42 43 44 45 46 47

def main(argv):
    """Build the startup/main/test programs and save them to a model dir.

    Args:
        argv: argument list in ``sys.argv`` form:
            ``[script, network_build_file, (optional) model_dir]``.
            ``model_dir`` defaults to ``./model``.

    Side effects:
        Writes ``startup_program``, ``main_program`` and ``test_program``
        (serialized program descs) plus ``model.yaml`` into ``model_dir``,
        creating the directory (and any missing parents) if needed.
        Prints usage and exits with status 1 when the network file is
        missing or not given.
    """
    if len(argv) < 2 or not os.path.exists(argv[1]):
        print_help(argv[0])
        # sys.exit rather than the site-provided builtin ``exit``, which is
        # absent under ``python -S`` and in some embedded interpreters.
        sys.exit(1)
    network_build_file = argv[1]
    model_dir = argv[2] if len(argv) > 2 else './model'

    main_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(main_program, startup_program):
        inputs, outputs = inference_warpper(network_build_file)

        # Clone before any label/loss/backward ops are added, so the test
        # program contains only the inference network.
        test_program = main_program.clone(for_test=True)

        labels = []
        losses = []
        for output in outputs:
            # One float32 label per output with the output's shape, scored by
            # the mean of the squared error.
            label = fluid.layers.data(name='label_' + output.name, shape=output.shape, dtype='float32')
            loss = fluid.layers.square_error_cost(input=output, label=label)
            loss = fluid.layers.mean(loss, name='loss_' + output.name)
            labels.append(label)
            losses.append(loss)

        loss_all = fluid.layers.sum(losses)
        # Only ``backward`` is called (not ``minimize``), so presumably only
        # gradient ops are appended to main_program — confirm against the
        # Paddle version in use.
        optimizer = fluid.optimizer.SGD(learning_rate=1.0)
        optimizer.backward(loss_all)

    # os.makedirs also creates missing parent directories; the isdir guard
    # keeps this Python 2 compatible (no ``exist_ok``).
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    programs = {
        'startup_program': startup_program,
        'main_program': main_program,
        'test_program': test_program,
    }
    for save_path, program in programs.items():
        # The serialized desc is binary data: write it in binary mode (text
        # mode rejects bytes on Python 3 and mangles newlines on Windows).
        with open(os.path.join(model_dir, save_path), 'wb') as f:
            f.write(program.desc.serialize_to_string())

    model_desc_path = os.path.join(model_dir, 'model.yaml')
    model_desc = {
        'inputs': [{"name": var.name, "shape": var.shape} for var in inputs],
        'outputs': [
            {"name": var.name, "shape": var.shape, "label_name": label.name, "loss_name": loss.name}
            for var, label, loss in zip(outputs, labels, losses)
        ],
        'loss_all': loss_all.name,
    }

    # With an explicit ``encoding`` PyYAML writes encoded bytes to the
    # stream, so the file must be opened in binary mode.
    with open(model_desc_path, 'wb') as f:
        yaml.safe_dump(model_desc, f, encoding='utf-8', allow_unicode=True)


if __name__ == '__main__':
    # Script entry point: hand the full argv (script name included) to main.
    main(argv=sys.argv)