# network.py (3.9 KB) -- change history: committed by Yu Yang
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope

Y
Yu Yang 已提交
5 6
# Public API of this module: `from <module> import *` exports only Network.
# NetworkFunctor is an internal helper and is intentionally not listed.
__all__ = ['Network']  # Only expose Network

Y
Yu Yang 已提交
7 8

class NetworkFunctor(object):
    """
    Network Op Creation Function. Used internally in this module.
    It converts string inputs to Variables. If a variable is not created
    before, it is just created in scope.

    It is a functor object, which means the instances are callable.

    :param func: The op creation function which is generated in Python.
    :param net: The Network instance.
    """

    def __init__(self, func, net):
        self.func = func
        self.net = net

    def _to_var_names(self, arg_names, kwargs, role):
        """
        Replace, in place, every entry of `kwargs` listed in `arg_names` by
        the name of its variable in the current scope.

        :param arg_names: Argument names of the op to look for in `kwargs`.
        :param kwargs: Keyword arguments of the op creation call (mutated).
        :param role: "Input" or "Output"; only used in the error message.
        :raises TypeError: If a value is neither a string nor a
            core.Variable.
        """
        for arg_name in arg_names:
            if arg_name not in kwargs:
                continue
            var = kwargs[arg_name]
            # A string is turned into a variable through create_var first.
            if isinstance(var, basestring):
                var = create_var(var)
            if not isinstance(var, core.Variable):
                raise TypeError(
                    "%s of op creation must be string or variable" % role)
            # The op creation function receives variable names, not objects.
            kwargs[arg_name] = get_cur_scope().get_var_name(var)

    def __call__(self, *args, **kwargs):
        """
        Create the operator described by `kwargs` and add it to the network.

        :param kwargs: Arguments of the op creation function. Inputs and
            outputs may be given as strings or as core.Variable instances.
        :return: None if the op has no non-temporary output, the single
            variable if it has exactly one, otherwise a list of variables.
        :raises ValueError: If any positional argument is passed.
        :raises TypeError: If an input or output is neither a string nor a
            core.Variable.
        """
        if args:
            raise ValueError("Paddle must use keyword argument")

        self._to_var_names(self.func.all_input_args, kwargs, "Input")

        notemp_outputs = self.func.all_not_temp_output_args

        # Give every non-temporary output the caller did not name a
        # generated name: <op function name>@OUT@<unique integer>.
        for name in notemp_outputs:
            if name not in kwargs:
                suffix = "@OUT@%d" % core.unique_integer()
                kwargs[name] = self.func.__name__ + suffix

        self._to_var_names(self.func.all_output_args, kwargs, "Output")

        op = self.func(**kwargs)

        self.net.net.add_op(op)

        results = [get_var(kwargs[name]) for name in notemp_outputs]
        if not results:
            return None
        if len(results) == 1:
            return results[0]
        return results


class Network(object):
    """
    The network concept. It saves the user from manually creating operators,
    creating variables, and combining them into a Net. Just calling
    Network.xxx creates the operator, creates variables in the default scope,
    and adds the operator into `self.net`.

    For example:

    ..  code-block:: python

        net = Network()
        out = net.add_two(X="a", Y="b")
        fc_out = net.fc(X="out", W="fc.w")

        net.run(...)
    """

    def __init__(self):
        """
        Create the underlying net and attach one callable attribute per op
        creation function found in `op_creations`.
        """
        self.net = core.Net.create()

        # TODO(yuyang18): This code can work, but do not generate a good
        # docstring, try to give a better way generate function in runtime
        # later.
        for func_name in dir(op_creations):
            # Skip the double-underscore attributes of op_creations.
            if func_name.startswith("__"):
                continue
            func = getattr(op_creations, func_name)
            impl = NetworkFunctor(func, self)
            setattr(self, func_name, impl.__call__)
        # Tracks whether self.net.complete_add_op() has already been called.
        self.__complete_add_op__ = False

    def infer_shape(self):
        """
        Complete the net if needed, then call its infer_shape with the
        current scope.
        """
        self.complete_add_op()
        self.net.infer_shape(get_cur_scope())

    def run(self, device_context):
        """
        Complete the net if needed, then run it in the current scope.

        :param device_context: Passed through to the underlying net's run.
        """
        self.complete_add_op()
        self.net.run(get_cur_scope(), device_context)

    def __str__(self):
        """Return the string form of the underlying net."""
        return str(self.net)

    def complete_add_op(self):
        """
        Call complete_add_op on the underlying net the first time this is
        invoked; later calls do nothing.
        """
        if self.__complete_add_op__:
            return
        self.net.complete_add_op()
        self.__complete_add_op__ = True

Y
Yu Yang 已提交
118 119 120 121 122

# Small demo run when the module is executed as a script: build a two-op
# network, complete it, and print it.
if __name__ == '__main__':
    net = Network()
    out = net.add_two(X="a", Y="b")
    # `out` is the variable returned by the previous op; the other arguments
    # are given by name.
    fc_out = net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")
    net.complete_add_op()
    # Parenthesized form prints the same text under the Python 2 print
    # statement and is also valid as a Python 3 function call.
    print(net)