From 870650d8c171bbcd1e6e0c1da5b1057cf066d32b Mon Sep 17 00:00:00 2001
From: "Yang Yang(Tony)"
Date: Wed, 8 Nov 2017 00:50:15 -0800
Subject: [PATCH] Static lstm sanity check (#5365)

* add fill_constant_batch_size_like_op to rnn h_boot
* first commit
* merge develop; fix conflict
* update to main_program
---
 .../fill_constant_batch_size_like_op.cc      |   4 +-
 paddle/operators/lstm_unit_op.cc             |   8 +-
 python/paddle/v2/framework/layers.py         |  72 +++++++++++-
 .../tests/test_understand_sentiment_lstm.py  | 107 ++++++++++++++++++
 4 files changed, 182 insertions(+), 9 deletions(-)
 create mode 100644 python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py

diff --git a/paddle/operators/fill_constant_batch_size_like_op.cc b/paddle/operators/fill_constant_batch_size_like_op.cc
index f86ee3c3d88..85871ebbfcd 100644
--- a/paddle/operators/fill_constant_batch_size_like_op.cc
+++ b/paddle/operators/fill_constant_batch_size_like_op.cc
@@ -75,10 +75,10 @@ class FillConstantBatchSizeLikeOpMaker
               "with the specified value");
     AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
     AddAttr<int>("input_dim_idx",
-                 "(int, default 0) the index of input's batch size dimension")
+                 "(int, default 0) The index of input's batch size dimension")
         .SetDefault(0);
     AddAttr<int>("output_dim_idx",
-                 "(int, default 0) the index of output's batch size dimension")
+                 "(int, default 0) The index of output's batch size dimension")
         .SetDefault(0);
     AddAttr<float>("value", "(float, default 0) The value to be filled")
         .SetDefault(0.0f);
diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc
index f4519ec16f3..18b9cdf2a39 100644
--- a/paddle/operators/lstm_unit_op.cc
+++ b/paddle/operators/lstm_unit_op.cc
@@ -34,10 +34,10 @@ class LstmUnitOp : public framework::OperatorWithKernel {
     auto c_prev_dims = ctx->GetInputDim("C_prev");

     PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
-    PADDLE_ENFORCE(x_dims[0] == c_prev_dims[0],
-                   "Batch size of inputs and states must be equal");
-    PADDLE_ENFORCE(x_dims[1] == c_prev_dims[1] * 4,
-                   "Dimension of FC should equal to prev state * 4");
+    PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0],
+                      "Batch size of inputs and states must be equal");
+    PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4,
+                      "Dimension of FC should equal to prev state * 4");

     int b_size = c_prev_dims[0];  // batch size
     int s_dim = c_prev_dims[1];   // state dim
diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py
index d42af89eaea..f1c09af8edf 100644
--- a/python/paddle/v2/framework/layers.py
+++ b/python/paddle/v2/framework/layers.py
@@ -134,9 +134,7 @@ def _create_op_func_(op_type):
     o_name = not_intermediate_outputs[0].name
     intermediate_output_names = [output.name for output in intermediate_outputs]

-    def func(**kwargs):
-        helper = LayerHelper(op_type, **kwargs)
-        inputs = dict()
+    def infer_and_check_data_type(op_proto, **kwargs):
         dtype = None
         for ipt in op_proto.inputs:
             name = _convert_(ipt.name)
@@ -153,6 +151,20 @@ def _create_op_func_(op_type):
             elif dtype != each.data_type:
                 raise ValueError(
                     "operator {0} must input same dtype".format(op_type))
+
+        return dtype
+
+    def func(**kwargs):
+        helper = LayerHelper(op_type, **kwargs)
+
+        dtype = infer_and_check_data_type(op_proto, **kwargs)
+
+        inputs = dict()
+        for ipt in op_proto.inputs:
+            name = _convert_(ipt.name)
+            val = kwargs.pop(name, [])
+            if not isinstance(val, list) and not isinstance(val, tuple):
+                val = [val]
             inputs[ipt.name] = val

         outputs = dict()
@@ -178,6 +190,20 @@ _create_op_func_('reshape')
 _create_op_func_('elementwise_add')
 _create_op_func_('sigmoid')
 _create_op_func_('scale')
+_create_op_func_('reshape')
+_create_op_func_('transpose')
+
+
+def fill_constant(data_type, shape, value=None, program=None):
+    helper = LayerHelper('fill_constant', **locals())
+    out = helper.create_tmp_variable(dtype=data_type)
+    helper.append_op(
+        type='fill_constant',
+        outputs={'Out': [out]},
+        attrs={'data_type': data_type,
+               'shape': shape,
+               'value': value})
+    return out


 def cast(x, data_type, main_program=None):
@@ -762,6 +788,46 @@ class StaticRNN(object):
         })


+def lstm(x,
+         c_pre_init,
+         hidden_dim,
+         forget_bias=None,
+         main_program=None,
+         startup_program=None):
+    helper = LayerHelper('lstm_unit', **locals())
+    rnn = StaticRNN()
+    with rnn.step():
+        c_pre = rnn.memory(init=c_pre_init)
+        x_t = rnn.step_input(x)
+
+        before_fc = concat(
+            input=[x_t, c_pre],
+            axis=1,
+            main_program=main_program,
+            startup_program=startup_program)
+        after_fc = fc(input=before_fc,
+                      size=hidden_dim * 4,
+                      main_program=main_program,
+                      startup_program=startup_program)
+
+        data_type = x.data_type
+        c = helper.create_tmp_variable(data_type)
+        h = helper.create_tmp_variable(data_type)
+
+        helper.append_op(
+            type='lstm_unit',
+            inputs={"X": after_fc,
+                    "C_prev": c_pre},
+            outputs={"C": c,
+                     "H": h},
+            attrs={"forget_bias": forget_bias})
+
+        rnn.update_memory(c_pre, c)
+        rnn.output(h)
+
+    return rnn()
+
+
 def lod_rank_table(x, level=0, main_program=None):
     helper = LayerHelper("lod_rank_table", **locals())
     table = helper.create_variable(
diff --git a/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py b/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py
new file mode 100644
index 00000000000..26cbd01bc04
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_understand_sentiment_lstm.py
@@ -0,0 +1,107 @@
+import paddle.v2 as paddle
+import paddle.v2.framework.layers as layers
+import paddle.v2.framework.core as core
+import paddle.v2.framework.optimizer as optimizer
+
+from paddle.v2.framework.framework import g_main_program, g_startup_program
+from paddle.v2.framework.executor import Executor
+
+import numpy as np
+
+
+def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50):
+    data = layers.data(
+        name="words",
+        shape=[seq_len * batch_size, 1],
+        append_batch_size=False,
+        data_type="int64")
+    label = layers.data(
+        name="label",
+        shape=[batch_size, 1],
+        append_batch_size=False,
+        data_type="int64")
+
+    emb = layers.embedding(input=data, size=[dict_dim, emb_dim])
+    emb = layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim])
+    emb = layers.transpose(x=emb, axis=[1, 0, 2])
+
+    c_pre_init = layers.fill_constant(
+        data_type=emb.data_type, shape=[batch_size, emb_dim], value=0.0)
+    layer_1_out = layers.lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim)
+    layer_1_out = layers.transpose(x=layer_1_out, axis=[1, 0, 2])
+
+    prediction = layers.fc(input=layer_1_out, size=class_dim, act="softmax")
+    cost = layers.cross_entropy(input=prediction, label=label)
+
+    avg_cost = layers.mean(x=cost)
+    adam_optimizer = optimizer.AdamOptimizer(learning_rate=0.002)
+    opts = adam_optimizer.minimize(avg_cost)
+    acc = layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc
+
+
+def to_lodtensor(data, place):
+    seq_lens = [len(seq) for seq in data]
+    cur_len = 0
+    lod = [cur_len]
+    for l in seq_lens:
+        cur_len += l
+        lod.append(cur_len)
+    flattened_data = np.concatenate(data, axis=0).astype("int64")
+    flattened_data = flattened_data.reshape([len(flattened_data), 1])
+    res = core.LoDTensor()
+    res.set(flattened_data, place)
+    res.set_lod([lod])
+    return res
+
+
+def chop_data(data, chop_len=80, batch_len=50):
+    data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
+
+    return data[:batch_len]
+
+
+def prepare_feed_data(data, place):
+    tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
+
+    label = np.array(map(lambda x: x[1], data)).astype("int64")
+    label = label.reshape([50, 1])
+    tensor_label = core.LoDTensor()
+    tensor_label.set(label, place)
+
+    return tensor_words, tensor_label
+
+
+def main():
+    word_dict = paddle.dataset.imdb.word_dict()
+    cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2)
+
+    batch_size = 100
+    train_data = paddle.batch(
+        paddle.reader.buffered(
+            paddle.dataset.imdb.train(word_dict), size=batch_size * 10),
+        batch_size=batch_size)
+
+    data = chop_data(next(train_data()))
+
+    place = core.CPUPlace()
+    tensor_words, tensor_label = prepare_feed_data(data, place)
+    exe = Executor(place)
+    exe.run(g_startup_program)
+
+    while True:
+        outs = exe.run(g_main_program,
+                       feed={"words": tensor_words,
+                             "label": tensor_label},
+                       fetch_list=[cost, acc])
+        cost_val = np.array(outs[0])
+        acc_val = np.array(outs[1])
+
+        print("cost=" + str(cost_val) + " acc=" + str(acc_val))
+        if acc_val > 0.9:
+            break
+
+
+if __name__ == '__main__':
+    main()
--
GitLab
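
Context on the shape checks above: layers.lstm packs all four LSTM gate pre-activations
into a single FC output of width hidden_dim * 4, which is why LstmUnitOp enforces
x_dims[1] == c_prev_dims[1] * 4. Below is a minimal NumPy sketch of one lstm_unit step,
assuming the common i, f, o, g gate order and forget_bias added to the forget gate before
the sigmoid; the authoritative kernel lives in paddle/operators/lstm_unit_op.h, and
lstm_unit_ref is an illustrative name, not part of this patch.

import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_unit_ref(x, c_prev, forget_bias=0.0):
    # x:      [batch, 4 * state_dim] packed gate pre-activations from the FC
    # c_prev: [batch, state_dim] previous cell state
    b, s = c_prev.shape
    # Mirrors "Dimension of FC should equal to prev state * 4" in LstmUnitOp.
    assert x.shape == (b, 4 * s)
    i, f, o, g = np.split(x, 4, axis=1)  # assumed gate order
    c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return c, h


# Shapes mirror the test: batch_size=50, hidden_dim=emb_dim=32.
c_prev = np.zeros((50, 32), dtype=np.float32)
x = np.random.randn(50, 128).astype(np.float32)
c, h = lstm_unit_ref(x, c_prev)
assert c.shape == h.shape == (50, 32)

Each StaticRNN step in layers.lstm runs this concat -> fc -> lstm_unit chain once per
time step, carrying the cell state forward via rnn.update_memory(c_pre, c); that is why
the test transposes the embedding to [seq_len, batch, emb_dim] before feeding it in.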