diff --git a/python/paddle/v2/tests/test_layer.py b/python/paddle/v2/tests/test_layer.py index b600e8cf765122ab6cfe8530465391c92be0590f..73d769a3582441a8e9d831820c2c36804ce453c9 100644 --- a/python/paddle/v2/tests/test_layer.py +++ b/python/paddle/v2/tests/test_layer.py @@ -51,12 +51,57 @@ class CostLayerTest(unittest.TestCase): cost10 = layer.sum_cost(input=inference) cost11 = layer.huber_cost(input=score, label=label) - print dir(layer) - layer.parse_network(cost1, cost2) - print dir(layer) - #print layer.parse_network(cost3, cost4) - #print layer.parse_network(cost5, cost6) - #print layer.parse_network(cost7, cost8, cost9, cost10, cost11) + print layer.parse_network(cost1, cost2) + print layer.parse_network(cost3, cost4) + print layer.parse_network(cost5, cost6) + print layer.parse_network(cost7, cost8, cost9, cost10, cost11) + + +class RNNTest(unittest.TestCase): + def test_simple_rnn(self): + dict_dim = 10 + word_dim = 8 + hidden_dim = 8 + + def test_old_rnn(): + def step(y): + mem = conf_helps.memory(name="rnn_state", size=hidden_dim) + out = conf_helps.fc_layer( + input=[y, mem], + size=hidden_dim, + act=activation.Tanh(), + bias_attr=True, + name="rnn_state") + return out + + def test(): + data1 = conf_helps.data_layer(name="word", size=dict_dim) + embd = conf_helps.embedding_layer(input=data1, size=word_dim) + conf_helps.recurrent_group(name="rnn", step=step, input=embd) + + return str(parse_network(test)) + + def test_new_rnn(): + def new_step(y): + mem = layer.memory(name="rnn_state", size=hidden_dim) + out = layer.fc(input=[mem], + step_input=y, + size=hidden_dim, + act=activation.Tanh(), + bias_attr=True, + name="rnn_state") + return out.to_proto(dict()) + + data1 = layer.data( + name="word", type=data_type.integer_value(dict_dim)) + embd = layer.embedding(input=data1, size=word_dim) + rnn_layer = layer.recurrent_group( + name="rnn", step=new_step, input=embd) + return str(layer.parse_network(rnn_layer)) + + diff = difflib.unified_diff(test_old_rnn().splitlines(1), + test_new_rnn().splitlines(1)) + print ''.join(diff) if __name__ == '__main__':