network.py 4.0 KB
Newer Older
Y
Yu Yang 已提交
1 2
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
Y
Yu Yang 已提交
3
from default_scope_funcs import new_var, find_var, get_cur_scope
Y
Yu Yang 已提交
4

Y
Yu Yang 已提交
5 6
__all__ = ['Network']  # Only expose Network

Y
Yu Yang 已提交
7 8

class NetworkFunctor(object):
    """
    Callable wrapper around one generated op creation function.

    Used internally by this module. For each call it:

    * converts string input/output arguments to Variables via ``new_var``
      (the original author notes these are created in the current scope);
    * records each such Variable in ``net.var_names`` under that string;
    * replaces every input/output argument with its registered name;
    * builds the op, appends it to ``net.net``, and returns the
      non-temporary output Variable(s).

    :param func: The op creation function generated in Python.
    :param net: The owning Network instance.
    """

    def __init__(self, func, net):
        self.func = func
        self.net = net

    def _to_var_name(self, var, role):
        """
        Convert a string or ``core.Variable`` argument to its registered name.

        A string is passed to ``new_var``. The resulting Variable is recorded
        in ``self.net.var_names`` with the string as its name.

        :param var: a variable name (string) or a ``core.Variable``.
        :param role: ``"Input"`` or ``"Output"``; used only in error messages.
        :return: the name stored for the Variable in ``self.net.var_names``.
        :raises TypeError: if ``var`` is neither a string nor a Variable.
        :raises KeyError: if ``var`` is a Variable not registered in this
            Network's ``var_names``.
        """
        if isinstance(var, basestring):
            created = new_var(var)
            self.net.var_names[created] = var
            var = created

        if not isinstance(var, core.Variable):
            raise TypeError(
                "%s of op creation must be string or variable" % role)

        try:
            return self.net.var_names[var]
        except KeyError:
            # The Variable was not created through this Network, so its name
            # is unknown here; re-raise with some context instead of a bare
            # KeyError carrying only the Variable object.
            raise KeyError("%s variable %r is not registered in this Network"
                           % (role, var))

    def __call__(self, *args, **kwargs):
        """
        Create the op (keyword arguments only) and add it to the net.

        Non-temporary outputs not given by the caller get a generated
        ``"<func name>@OUT@<unique integer>"`` name.

        :raises ValueError: if any positional argument is passed.
        :raises TypeError: if an input/output is neither a string nor a
            ``core.Variable``.
        :raises KeyError: if an input/output Variable is not registered in
            this Network's ``var_names``.
        :return: ``None`` if the op has no non-temporary outputs, the single
            output Variable if it has exactly one, otherwise a list of them.
        """
        if len(args) != 0:
            raise ValueError("Paddle must use keyword argument")

        for ipt in self.func.all_input_args:
            if ipt in kwargs:
                kwargs[ipt] = self._to_var_name(kwargs[ipt], "Input")

        notemp_outputs = self.func.all_not_temp_output_args
        for name in notemp_outputs:
            if name not in kwargs:
                # Name omitted outputs uniquely so they can be looked up
                # with find_var and returned below.
                unique_id = core.unique_integer()
                kwargs[name] = "%s@OUT@%d" % (self.func.__name__, unique_id)

        for opt in self.func.all_output_args:
            if opt in kwargs:
                kwargs[opt] = self._to_var_name(kwargs[opt], "Output")

        op = self.func(**kwargs)
        self.net.net.add_op(op)

        outputs = [find_var(kwargs[opt]) for opt in notemp_outputs]
        if not outputs:
            return None
        if len(outputs) == 1:
            return outputs[0]
        return outputs


class Network(object):
    """
    Builds a ``core.Net`` from op creation calls.

    Every op creation function in ``op_creations`` (names not starting with
    ``"__"``) is exposed as a method of the same name. Calling one with
    keyword arguments creates any named variables, builds the operator, and
    appends it to ``self.net``. The non-temporary output Variable(s) are
    returned, so they can be passed to later calls.

    For example:

    .. code-block:: python

        net = Network()
        out = net.add_two(X="a", Y="b")
        fc_out = net.fc(X=out, W="fc.w")

        net.run(...)
    """

    def __init__(self):
        self.net = core.Net.create()
        op_names = [
            name for name in dir(op_creations) if not name.startswith("__")
        ]
        # Maps each core.Variable created through this Network to its name.
        self.var_names = {}

        # TODO(yuyang18): This code can work, but do not generate a good
        # docstring, try to give a better way generate function in runtime
        # later.
        for name in op_names:
            functor = NetworkFunctor(getattr(op_creations, name), self)
            setattr(self, name, functor.__call__)
        # Whether self.net.complete_add_op() has already been called.
        self.__complete_add_op__ = False

    def infer_shape(self):
        """Finalize the net, then run shape inference in the current scope."""
        self.complete_add_op()
        scope = get_cur_scope()
        self.net.infer_shape(scope)

    def run(self, device_context):
        """Finalize the net, then run it in the current scope.

        :param device_context: passed through to ``core.Net.run``.
        """
        self.complete_add_op()
        scope = get_cur_scope()
        self.net.run(scope, device_context)

    def __str__(self):
        return str(self.net)

    def complete_add_op(self):
        """Call ``self.net.complete_add_op()`` at most once per Network."""
        if self.__complete_add_op__:
            return
        self.net.complete_add_op()
        self.__complete_add_op__ = True

Y
Yu Yang 已提交
125 126 127 128 129

if __name__ == '__main__':
    # Smoke demo: chain two ops, finalize the net and print it.
    demo = Network()
    summed = demo.add_two(X="a", Y="b")
    softmax_out = demo.fc(X=summed, W="fc.w", b="fc.b", activation="softmax")
    demo.complete_add_op()
    print(demo)