# network.py (3.0 KB)
# Blame: lines 1-74 committed by Yu Yang.
import paddle.v2.framework.core as core
from paddle.v2.framework.create_op_creation_methods import op_creations
from default_scope_funcs import create_var, get_var, get_cur_scope


class NetworkFunctor(object):
    """Bind one op-creation function to a Network so calling it adds the op.

    Calling the functor does the following:
    1. Resolves every input/output keyword argument to a variable name.
    2. Auto-names any unnamed non-temporary output.
    3. Calls the wrapped op-creation function with those names.
    4. Appends the resulting op to the owning network's core net.
    5. Returns the variable(s) of the op's non-temporary outputs.
    """

    def __init__(self, func, net):
        # func: an op-creation function. It must expose all_input_args,
        #   all_output_args and all_not_temp_output_args (presumably attached
        #   by create_op_creation_methods -- not visible here).
        # net: the owning Network. Note that self.net.net is the underlying
        #   core net, while self.net.var_name_map / generate_idx live on the
        #   Network wrapper itself.
        self.func = func
        self.net = net

    def _to_var_name(self, var, role):
        """Return the variable name for ``var`` (a str name or core.Variable).

        A string is passed to create_var and the resulting variable is
        recorded in the network's variable -> name map.

        :param var: variable name (string) or core.Variable.
        :param role: "Input" or "Output"; used only in error messages.
        :raises TypeError: if ``var`` is neither a string nor a core.Variable.
        :raises KeyError: if ``var`` is a core.Variable that is not registered
            in this Network's var_name_map.
        """
        if isinstance(var, basestring):
            var_name = var
            var = create_var(var_name)
            self.net.var_name_map[var] = var_name
        if not isinstance(var, core.Variable):
            raise TypeError(
                "%s of op creation must be string or variable" % role)
        try:
            return self.net.var_name_map[var]
        except KeyError:
            # The original bare KeyError(var) gave no hint about the cause.
            raise KeyError(
                "%s variable %r is not registered in this Network's "
                "var_name_map; pass it by name instead" % (role, var))

    def __call__(self, **kwargs):
        """Create the op, add it to the network, and return its outputs.

        Keyword arguments map op argument names to either a variable name
        (str) or a core.Variable. Unnamed non-temporary outputs get a
        generated name of the form ``<func name>@OUT@<idx>``.

        :returns: None when the op has no non-temporary outputs, the single
            output variable when it has one, otherwise a list of them.
        :raises TypeError: if an input/output is neither str nor Variable.
        """
        for ipt in self.func.all_input_args:
            if ipt in kwargs:
                kwargs[ipt] = self._to_var_name(kwargs[ipt], "Input")

        notemp_outputs = self.func.all_not_temp_output_args

        # Every non-temporary output needs a name so it can be fetched with
        # get_var below; use a per-network counter to keep names unique.
        for name in notemp_outputs:
            if name not in kwargs:
                kwargs[name] = "%s@OUT@%d" % (self.func.__name__,
                                              self.net.generate_idx)
                self.net.generate_idx += 1

        for opt in self.func.all_output_args:
            if opt in kwargs:
                kwargs[opt] = self._to_var_name(kwargs[opt], "Output")

        op = self.func(**kwargs)
        self.net.net.add_op(op)

        results = [get_var(kwargs[opt]) for opt in notemp_outputs]
        if not results:
            return None
        if len(results) == 1:
            return results[0]
        return results


class Network(object):
    """Python wrapper around a ``core.Net`` that exposes op creators as methods.

    For every name in ``dir(op_creations)`` that does not start with ``__``,
    an attribute of the same name is set on the instance. Calling it creates
    that op (via NetworkFunctor), adds it to the underlying net, and returns
    the op's non-temporary output variable(s).
    """

    def __init__(self):
        # The underlying core network that ops are appended to.
        self.net = core.Net.create()
        funcs = (func_name for func_name in dir(op_creations)
                 if not func_name.startswith("__"))
        # Counter NetworkFunctor uses to generate unique output names.
        self.generate_idx = 0
        # Maps core.Variable objects to the names they were created under.
        self.var_name_map = dict()

        for func_name in funcs:
            func = getattr(op_creations, func_name)
            impl = NetworkFunctor(func, self)
            setattr(self, func_name, impl.__call__)
        # Guards complete_add_op so it is forwarded to core only once.
        # NOTE: a name ending in "__" is NOT name-mangled, so this is a
        # plain instance attribute despite its dunder-like spelling.
        self.__complete_add_op__ = False

    def infer_shape(self):
        """Finish op addition (once), then delegate to core
        ``Net.infer_shape`` with the current scope."""
        self.complete_add_op()
        self.net.infer_shape(get_cur_scope())

    def run(self, device_context):
        """Finish op addition (once), then delegate to core ``Net.run``
        with the current scope and ``device_context``."""
        self.complete_add_op()
        self.net.run(get_cur_scope(), device_context)

    def __str__(self):
        """Return the string form of the underlying core net."""
        return str(self.net)

    def complete_add_op(self):
        """Forward ``complete_add_op`` to the core net the first time only;
        later calls are no-ops."""
        if not self.__complete_add_op__:
            self.net.complete_add_op()
            self.__complete_add_op__ = True

# Blame: lines 90-96 committed by Yu Yang.

if __name__ == '__main__':
    # Example usage: feed an add_two op into an fc op and dump the network.
    net = Network()
    out = net.add_two(X="a", Y="b")
    net.fc(X=out, W="fc.w", b="fc.b", activation="softmax")

    # Function-call form prints the same text under Python 2 and is also
    # valid Python 3 syntax (the old print statement is a SyntaxError there).
    print(net)