network.py 4.1 KB
Newer Older
Y
Yu Yang 已提交
1 2 3 4
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope

Y
Yu Yang 已提交
5 6
__all__ = ["Network"]  # NetworkFunctor is internal; Network is the only public name

Y
Yu Yang 已提交
7 8

class NetworkFunctor(object):
    """
    Network op creation function, used internally in this module.

    It converts string arguments into Variables (creating them with
    ``create_var`` when given a name), builds the operator with the wrapped
    op creation function, and adds that operator to the owning Network.

    It is a functor object, which means its instances are callable.

    :param func: The op creation function which is generated in Python.
    :param net: The Network instance that owns this functor.
    """

    def __init__(self, func, net):
        self.func = func
        self.net = net

    def _to_var_name(self, var, role):
        """
        Resolve one op argument to the variable name passed to ``self.func``.

        :param var: A variable name (string) or a ``core.Variable``.
        :param role: "Input" or "Output"; only used in error messages.
        :return: The name recorded for ``var`` in ``self.net.var_name_map``.
        :raises TypeError: if ``var`` is neither a string nor a Variable.
        :raises KeyError: if ``var`` is a Variable with no recorded name.
        """
        if isinstance(var, basestring):
            var_name = var
            var = create_var(var)
            # Remember the name so the Variable can be mapped back later.
            self.net.var_name_map[var] = var_name
        if not isinstance(var, core.Variable):
            raise TypeError(
                "%s of op creation must be string or variable" % role)
        try:
            return self.net.var_name_map[var]
        except KeyError:
            # var_name_map is only populated from string arguments, so a
            # Variable obtained some other way has no known name here.
            raise KeyError(
                "%s variable of op creation is not known to this Network; "
                "pass the variable name as a string instead" % role)

    def __call__(self, *args, **kwargs):
        """
        Create the operator from keyword arguments and add it to the Network.

        Inputs and outputs may be given as variable names or Variables. Every
        non-temporary output the caller omits gets a generated name of the
        form ``<func name>@OUT@<index>``, using ``net.generate_idx``.

        :return: None if the op has no non-temporary outputs, the single
            output Variable if it has exactly one, otherwise a list of them.
        :raises ValueError: if any positional argument is given.
        :raises TypeError: if an input/output is not a string or Variable.
        """
        if args:
            raise ValueError("Paddle must use keyword argument")

        for ipt in self.func.all_input_args:
            if ipt in kwargs:
                kwargs[ipt] = self._to_var_name(kwargs[ipt], "Input")

        notemp_outputs = self.func.all_not_temp_output_args
        for name in notemp_outputs:
            if name not in kwargs:
                kwargs[name] = (self.func.__name__ +
                                "@OUT@%d" % self.net.generate_idx)
                self.net.generate_idx += 1

        # Runs after name generation so generated names are created too.
        for opt in self.func.all_output_args:
            if opt in kwargs:
                kwargs[opt] = self._to_var_name(kwargs[opt], "Output")

        op = self.func(**kwargs)
        self.net.net.add_op(op)

        lst = [get_var(kwargs[opt]) for opt in notemp_outputs]
        if not lst:
            return None
        if len(lst) == 1:
            return lst[0]
        return lst


class Network(object):
    """
    The network concept. It saves users from manually creating operators and
    variables and combining them into a Net: calling ``Network.xxx`` creates
    the operator, creates its variables in the default scope, and adds the
    operator to ``self.net``.

    For example:

    .. code-block:: python

        net = Network()
        out = net.add_two(X="a", Y="b")
        fc_out = net.fc(X=out, W="fc.w")

        net.run(...)
    """

    def __init__(self):
        self.net = core.Net.create()
        op_names = [name for name in dir(op_creations)
                    if not name.startswith("__")]
        self.generate_idx = 0
        self.var_name_map = dict()

        # TODO(yuyang18): This code can work, but does not generate a good
        # docstring; find a better way to generate functions at runtime
        # later.
        for name in op_names:
            functor = NetworkFunctor(getattr(op_creations, name), self)
            # Each op creation function becomes a bound method on this net.
            setattr(self, name, functor.__call__)
        self.__complete_add_op__ = False

    def infer_shape(self):
        """Finish adding ops (once), then infer shapes in the current scope."""
        self.complete_add_op()
        scope = get_cur_scope()
        self.net.infer_shape(scope)

    def run(self, device_context):
        """Finish adding ops (once), then run the net in the current scope."""
        self.complete_add_op()
        scope = get_cur_scope()
        self.net.run(scope, device_context)

    def __str__(self):
        """Return the string form of the underlying net."""
        return str(self.net)

    def complete_add_op(self):
        """Call ``self.net.complete_add_op()`` at most once per Network."""
        if self.__complete_add_op__:
            return
        self.net.complete_add_op()
        self.__complete_add_op__ = True

Y
Yu Yang 已提交
124 125 126 127 128

if __name__ == '__main__':
    # Small demo: chain add_two into fc, finish the net, and print it.
    demo_net = Network()
    sum_out = demo_net.add_two(X="a", Y="b")
    demo_net.fc(X=sum_out, W="fc.w", b="fc.b", activation="softmax")
    demo_net.complete_add_op()
    print(demo_net)