# test_rnn_layer.py
# Copyright PaddlePaddle contributors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import difflib
import unittest

import paddle.trainer_config_helpers as conf_helps
import paddle.v2.activation as activation
import paddle.v2.data_type as data_type
import paddle.v2.layer as layer
from paddle.trainer_config_helpers.config_parser_utils import \
    parse_network_config as parse_network


class RNNTest(unittest.TestCase):
    def test_simple_rnn(self):
        dict_dim = 10
        word_dim = 8
        hidden_dim = 8

        def parse_old_rnn():
            def step(y):
                mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
                out = conf_helps.fc_layer(
                    input=[y, mem],
                    size=hidden_dim,
                    act=activation.Tanh(),
                    bias_attr=True,
                    name="rnn_state")
                return out

            def test():
                data = conf_helps.data_layer(name="word", size=dict_dim)
                embd = conf_helps.embedding_layer(input=data, size=word_dim)
                conf_helps.recurrent_group(name="rnn", step=step, input=embd)

            return str(parse_network(test))

        def parse_new_rnn():
            def new_step(y):
                mem = layer.memory(name="rnn_state", size=hidden_dim)
                out = layer.fc(input=[y, mem],
                               size=hidden_dim,
                               act=activation.Tanh(),
                               bias_attr=True,
                               name="rnn_state")
                return out

            data = layer.data(
                name="word", type=data_type.integer_value(dict_dim))
            embd = layer.embedding(input=data, size=word_dim)
            rnn_layer = layer.recurrent_group(
                name="rnn", step=new_step, input=embd)
            return str(layer.parse_network(rnn_layer))

        diff = difflib.unified_diff(parse_old_rnn().splitlines(1),
                                    parse_new_rnn().splitlines(1))
        print ''.join(diff)

    def test_sequence_rnn_multi_input(self):
        dict_dim = 10
        word_dim = 8
        hidden_dim = 8
        label_dim = 3

        def parse_old_rnn():
            def step(y, wid):
                z = conf_helps.embedding_layer(input=wid, size=word_dim)
                mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
                out = conf_helps.fc_layer(
                    input=[y, z, mem],
                    size=hidden_dim,
                    act=conf_helps.TanhActivation(),
                    bias_attr=True,
                    name="rnn_state")
                return out

            def test():
                data = conf_helps.data_layer(name="word", size=dict_dim)
                label = conf_helps.data_layer(name="label", size=label_dim)
                emb = conf_helps.embedding_layer(input=data, size=word_dim)
                out = conf_helps.recurrent_group(
                    name="rnn", step=step, input=[emb, data])

                rep = conf_helps.last_seq(input=out)
                prob = conf_helps.fc_layer(
                    size=label_dim,
                    input=rep,
                    act=conf_helps.SoftmaxActivation(),
                    bias_attr=True)

                conf_helps.outputs(
                    conf_helps.classification_cost(
                        input=prob, label=label))

            return str(parse_network(test))

        def parse_new_rnn():
Y
Yu Yang 已提交
109 110 111 112 113 114 115 116 117 118 119
            data = layer.data(
                name="word", type=data_type.dense_vector(dict_dim))
            label = layer.data(
                name="label", type=data_type.dense_vector(label_dim))
            emb = layer.embedding(input=data, size=word_dim)

            boot_layer = layer.data(
                name="boot", type=data_type.dense_vector(10))

            boot_layer = layer.fc(name='wtf', input=boot_layer, size=10)

Q
qiaolongfei 已提交
120 121
            def step(y, wid):
                z = layer.embedding(input=wid, size=word_dim)
Y
Yu Yang 已提交
122 123
                mem = layer.memory(
                    name="rnn_state", size=hidden_dim, boot_layer=boot_layer)
Q
qiaolongfei 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
                out = layer.fc(input=[y, z, mem],
                               size=hidden_dim,
                               act=activation.Tanh(),
                               bias_attr=True,
                               name="rnn_state")
                return out

            out = layer.recurrent_group(
                name="rnn", step=step, input=[emb, data])

            rep = layer.last_seq(input=out)
            prob = layer.fc(size=label_dim,
                            input=rep,
                            act=activation.Softmax(),
                            bias_attr=True)

            cost = layer.classification_cost(input=prob, label=label)

            return str(layer.parse_network(cost))

Y
Yu Yang 已提交
144 145 146 147 148
        with open("/Users/baidu/old.out", 'w') as f:
            print >> f, parse_old_rnn()
        with open("/Users/baidu/new.out", "w") as f:
            print >> f, parse_new_rnn()
        # print ''.join(diff)
Q
qiaolongfei 已提交
149 150 151 152


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()