import logging
import unittest

import numpy as np

import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator, DynamicRecurrentOp

# For simplicity, use just one level of LoD (sequence offsets):
# four sequences of lengths 4, 3, 2 and 1 (10 time steps in total).
lod_py = [[0, 4, 7, 9, 10]]
input_dim = 30
num_sents = len(lod_py[0]) - 1
weight_dim = 15

def create_tensor(scope, name, shape, np_data):
    """Create variable `name` in `scope` and fill its tensor with `np_data`.

    Args:
        scope: a paddle ``core.Scope`` to create the variable in.
        name (str): name of the variable to create.
        shape (list[int]): dimensions set on the tensor before filling it.
        np_data (np.ndarray): values copied into the tensor on CPU.

    Returns:
        The initialized tensor.
    """
    tensor = scope.var(name).get_tensor()
    tensor.set_dims(shape)
    tensor.set(np_data, core.CPUPlace())
    return tensor


class PyRNNStep(object):
    """Random numpy fixtures for one RNN step.

    Attributes:
        x: input of shape (total_tokens, input_dim), where total_tokens is
           the last LoD offset (i.e. the summed length of all sequences).
        W, U: square weight matrices of shape (input_dim, input_dim).
        h_boot: initial hidden state, one row per sequence.
    """

    def __init__(self):
        self.x = np.random.normal(size=(lod_py[0][-1],
                                        input_dim)).astype("float32")
        self.W = np.random.normal(size=(input_dim, input_dim)).astype("float32")
        self.U = np.random.normal(size=(input_dim, input_dim)).astype("float32")
        self.h_boot = np.random.normal(size=(num_sents,
                                             input_dim)).astype("float32")


32 33 34 35 36 37 38 39 40 41 42
class DynamicRecurrentOpTest(unittest.TestCase):
    '''
    Test RNNOp

    equation:
        h_t = \sigma (W x_t + U h_{t-1})
    weights:
        - W
        - U
    vars:
        - x
43
    states:
44 45 46 47 48
        - h
    outputs:
       - h
    '''

49
    py = PyRNNStep()
50 51 52 53 54 55 56 57

    def forward(self):
        self.scope = core.Scope()
        self.create_global_variables()
        self.create_rnn_op()
        self.create_step_net()
        ctx = core.DeviceContext.create(core.CPUPlace())
        self.rnnop.run(self.scope, ctx)
58
        state = self.rnnop.get_state("h@state")
59 60 61 62 63 64
        print 'state size: ', state.size()

        step_inputs = self.rnnop.get_step_input("x")
        print "x size ", step_inputs.size()
        for i in range(step_inputs.size()):
            print "x %d" % i, np.array(step_inputs.read(i).get_dims())
65
        step_outputs = self.rnnop.get_step_output('h@state')
66
        print 'step_outputs.size ', step_outputs.size()
67
        output = self.scope.find_var("h@state").get_tensor()
68 69 70 71
        print 'output', np.array(output).shape

    def create_global_variables(self):
        # create inlink
72 73 74 75 76 77 78
        x_tensor = create_tensor(self.scope, "x", [num_sents, input_dim],
                                 self.py.x)
        x_tensor.set_lod(lod_py)
        create_tensor(self.scope, "W", [input_dim, input_dim], self.py.W)
        create_tensor(self.scope, "U", [input_dim, input_dim], self.py.U)
        create_tensor(self.scope, "h_boot", [num_sents, input_dim],
                      self.py.h_boot)
D
Dong Zhihong 已提交
79
        self.scope.var("step_scopes")
80
        self.scope.var("h@state")
81 82 83 84 85

    def create_rnn_op(self):
        # create RNNOp
        self.rnnop = DynamicRecurrentOp(
            # inputs
86 87 88
            inputs=["x"],
            initial_states=["h_boot"],
            step_net="step_unit",
89
            # outputs
90
            outputs=["h@state"],
91 92
            step_scopes="step_scopes",
            # attributes
93 94
            ex_states=["h@pre"],
            states=["h@state"])
95 96

    def create_step_net(self):
97
        step_unit = core.Net.create()
98 99 100
        x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
        h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
        sum_op = Operator("sum", X=["Wx", "Uh"], Out="sum")
101
        sig_op = Operator("sigmoid", X="sum", Y="h@state")
102 103

        for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
104 105 106
            step_unit.append_op(op)
        step_unit.complete_add_op(True)
        self.rnnop.set_step_unit(step_unit)
107 108 109 110 111 112 113

    def test_forward(self):
        print 'test recurrent op forward'
        pd_output = self.forward()
        print 'pd_output', pd_output


class RecurrentGradientOpTest(unittest.TestCase):
    """Smoke test: build the forward DynamicRecurrentOp and derive its
    backward op without running either."""

    # Shared numpy fixtures (input, weights, boot state).
    py = PyRNNStep()

    def create_forward_op(self):
        """Construct the DynamicRecurrentOp whose gradient is derived."""
        self.forward_op = DynamicRecurrentOp(
            # inputs
            inputs=["x"],
            initial_states=["h_boot"],
            step_net="step_unit",
            # outputs
            outputs=["h@state"],
            step_scopes="step_scopes",
            # attributes
            ex_states=["h@pre"],
            states=["h@state"])

    def create_gradient_op(self):
        """Derive the backward op; the no-grad variable set is empty."""
        no_grad_set = set()
        backward_op = core.DynamicRecurrentOp.backward(self.forward_op,
                                                       no_grad_set)

    def create_step_net(self):
        """Assemble the per-step network: h@state = sigmoid(W*x + U*h@pre)."""
        step_unit = core.Net.create()
        step_ops = [
            Operator("mul", X="x", Y="W", Out="Wx"),
            Operator("mul", X="h@pre", Y="U", Out="Uh"),
            Operator("sum", X=["Wx", "Uh"], Out="sum"),
            Operator("sigmoid", X="sum", Y="h@state"),
        ]
        for step_op in step_ops:
            step_unit.append_op(step_op)
        step_unit.complete_add_op(True)
        self.forward_op.set_step_unit(step_unit)

    def create_global_variables(self):
        """Populate the scope: inlink x (with LoD), weights, boot state
        and the output/step-scope placeholder variables."""
        inlink = create_tensor(self.scope, "x", [num_sents, input_dim],
                               self.py.x)
        inlink.set_lod(lod_py)
        create_tensor(self.scope, "W", [input_dim, input_dim], self.py.W)
        create_tensor(self.scope, "U", [input_dim, input_dim], self.py.U)
        create_tensor(self.scope, "h_boot", [num_sents, input_dim],
                      self.py.h_boot)
        self.scope.var("step_scopes")
        self.scope.var("h@state")

    def test_grad(self):
        self.scope = core.Scope()
        self.create_forward_op()
        self.create_global_variables()
        self.create_step_net()
        self.create_gradient_op()


if __name__ == '__main__':
    # Tests are deliberately skipped until the underlying op is fixed.
    exit(
        0
    )  # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
    unittest.main()