diff --git a/word_embedding/network_conf.py b/word_embedding/network_conf.py
index e8c7b5cb153467e3de14c48e74c29246211833a8..6443ec62abc5ecac99eed3172871ca86a87d9350 100644
--- a/word_embedding/network_conf.py
+++ b/word_embedding/network_conf.py
@@ -5,7 +5,7 @@
 import math
 import paddle.v2 as paddle
 
-def network_conf(hidden_size, embed_size, dict_size):
+def network_conf(is_train, hidden_size, embed_size, dict_size):
     def word_embed(in_layer):
         ''' word embedding layer '''
         word_embed = paddle.layer.table_projection(
@@ -44,20 +44,20 @@ def network_conf(hidden_size, embed_size, dict_size):
         param_attr=paddle.attr.Param(
             initial_std=1. / math.sqrt(embed_size * 8), learning_rate=1))
 
-    cost = paddle.layer.hsigmoid(
-        input=hidden_layer,
-        label=target_word,
-        num_classes=dict_size,
-        param_attr=paddle.attr.Param(name='sigmoid_w'),
-        bias_attr=paddle.attr.Param(name='sigmoid_b'))
-
-    with paddle.layer.mixed(
-            size=dict_size - 1,
-            act=paddle.activation.Sigmoid(),
-            bias_attr=paddle.attr.Param(name='sigmoid_b')) as prediction:
-        prediction += paddle.layer.trans_full_matrix_projection(
-            input=hidden_layer, param_attr=paddle.attr.Param(name='sigmoid_w'))
-
-    input_data_lst = ['firstw', 'secondw', 'thirdw', 'fourthw', 'fifthw']
-
-    return input_data_lst, cost, prediction
+    if is_train:
+        cost = paddle.layer.hsigmoid(
+            input=hidden_layer,
+            label=target_word,
+            num_classes=dict_size,
+            param_attr=paddle.attr.Param(name='sigmoid_w'),
+            bias_attr=paddle.attr.Param(name='sigmoid_b'))
+        return cost
+    else:
+        with paddle.layer.mixed(
+                size=dict_size - 1,
+                act=paddle.activation.Sigmoid(),
+                bias_attr=paddle.attr.Param(name='sigmoid_b')) as prediction:
+            prediction += paddle.layer.trans_full_matrix_projection(
+                input=hidden_layer,
+                param_attr=paddle.attr.Param(name='sigmoid_w'))
+        return prediction
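The training branch's paddle.layer.hsigmoid cost replaces a dict_size-way softmax with binary decisions along a root-to-leaf path in a complete binary tree, which is why a single parameter pair ('sigmoid_w', 'sigmoid_b') of size dict_size - 1 covers every word. Below is a toy numpy sketch of that cost, assuming the usual heap layout (children of node i at 2*i + 1 and 2*i + 2); it illustrates the idea only and is not PaddlePaddle's implementation:

    import numpy as np


    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))


    def hsigmoid_cost(hidden, sigmoid_w, sigmoid_b, label, dict_size):
        # Negative log-likelihood of `label`: walk from the label's leaf
        # back to the root, accumulating the log-probability of each
        # branch taken. (Illustrative sketch, not PaddlePaddle code.)
        idx = label + dict_size - 1  # heap index of the label's leaf
        cost = 0.0
        while idx > 0:
            parent = (idx - 1) // 2
            went_right = (idx == parent * 2 + 2)
            p_right = sigmoid(sigmoid_w[parent].dot(hidden) + sigmoid_b[parent])
            cost -= np.log(p_right if went_right else 1.0 - p_right)
            idx = parent
        return cost

Only about log2(dict_size) of the dict_size - 1 node classifiers are evaluated per example, which is the speedup over a full softmax.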
diff --git a/word_embedding/predict_v2.py b/word_embedding/predict_v2.py
index 5dfa928c9f9b0014790a60b687c25c3a39fa2dc4..d523b89a4bb60b9a7dc481e5275d28abbfe84d69 100644
--- a/word_embedding/predict_v2.py
+++ b/word_embedding/predict_v2.py
@@ -7,6 +7,16 @@ import gzip
 
 
 def decode_res(infer_res, dict_size):
+    """
+    Inferred probabilities are organized as a complete binary tree.
+    The actual labels are the leaves (indices counted from the class number).
+    This function traverses the path decoded from each inference result:
+    if the probability is > 0.5, go to the right child; otherwise, go left.
+
+    param infer_res: inference results
+    param dict_size: number of classes
+    return predict_lbls: the actual class labels
+    """
     predict_lbls = []
     infer_res = infer_res > 0.5
     for i, probs in enumerate(infer_res):
@@ -20,47 +30,47 @@ def decode_res(infer_res, dict_size):
                 idx = idx * 2 + 2  # right child
             else:
                 idx = idx * 2 + 1  # left child
+
         predict_lbl = result - dict_size
         predict_lbls.append(predict_lbl)
     return predict_lbls
 
 
 def main():
-    paddle.init(use_gpu=False, trainer_count=4)
-    word_dict = paddle.dataset.imikolov.build_dict()
+    paddle.init(use_gpu=False, trainer_count=1)
+    word_dict = paddle.dataset.imikolov.build_dict(typo_freq=2)
     dict_size = len(word_dict)
-    _, _, prediction = network_conf(
-        hidden_size=256, embed_size=32, dict_size=dict_size)
+    prediction = network_conf(
+        is_train=False, hidden_size=256, embed_size=32, dict_size=dict_size)
 
     print('Load model ....')
     with gzip.open('./models/model_pass_00000.tar.gz') as f:
         parameters = paddle.parameters.Parameters.from_tar(f)
 
-    ins_num = 10
-    ins_lst = []
-    ins_lbls = []
+    ins_num = 10  # 10 instances in total for prediction
+    ins_lst = []  # input data
 
-    ins_buffer = paddle.reader.shuffle(
-        lambda: paddle.dataset.imikolov.train(word_dict, 5)(),
-        buf_size=1000)
+    ins_iter = paddle.dataset.imikolov.test(word_dict, 5)
 
-    for ins in ins_buffer():
-        ins_lst.append(ins[:-1])
-        ins_lbls.append(ins[-1])
-        if len(ins_lst) >= ins_num: break
+    for ins in ins_iter():
+        ins_lst.append(ins[:-1])
+        if len(ins_lst) >= ins_num: break
 
-    infer_res = paddle.infer(
-        output_layer=prediction, parameters=parameters, input=ins_lst)
+    infer_res = paddle.infer(
+        output_layer=prediction, parameters=parameters, input=ins_lst)
 
-    idx_word_dict = dict((v, k) for k, v in word_dict.items())
+    idx_word_dict = dict((v, k) for k, v in word_dict.items())
 
-    predict_lbls = decode_res(infer_res, dict_size)
-    predict_words = [idx_word_dict[lbl] for lbl in predict_lbls]
-    gt_words = [idx_word_dict[lbl] for lbl in ins_lbls]
+    predict_lbls = decode_res(infer_res, dict_size)
+    predict_words = [idx_word_dict[lbl] for lbl in predict_lbls]  # map to word
 
-    for i, ins in enumerate(ins_lst):
-        print idx_word_dict[ins[0]] + ' ' + idx_word_dict[ins[1]] + \
-            ' -> ' + predict_words[i] + ' ( ' + gt_words[i] + ' )'
+    # Output format: word1 word2 word3 word4 -> predicted word
+    for i, ins in enumerate(ins_lst):
+        print idx_word_dict[ins[0]] + ' ' + \
+            idx_word_dict[ins[1]] + ' ' + \
+            idx_word_dict[ins[2]] + ' ' + \
+            idx_word_dict[ins[3]] + \
+            ' -> ' + predict_words[i]
 
 
 if __name__ == '__main__':
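The traversal rule documented in decode_res can be exercised standalone. A minimal sketch of the same decoding (the helper name decode_one and the toy numbers are mine, not the repo's):

    import numpy as np


    def decode_one(probs, dict_size):
        # probs holds the dict_size - 1 internal-node sigmoids in heap
        # order; follow one root-to-leaf path, branching right whenever
        # the node's probability exceeds 0.5.
        idx = 0  # root
        while idx < len(probs):
            if probs[idx] > 0.5:
                idx = idx * 2 + 2  # right child
            else:
                idx = idx * 2 + 1  # left child
        # leaves occupy heap slots dict_size - 1 .. 2 * dict_size - 2
        return idx - (dict_size - 1)


    print(decode_one(np.array([0.9, 0.2, 0.7, 0.1, 0.6, 0.3, 0.8]), 8))  # -> 7

This is equivalent to the bit accumulation in decode_res: result collects a leading 1 followed by the path bits, so result - dict_size equals idx - (dict_size - 1).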
diff --git a/word_embedding/train_v2.py b/word_embedding/train_v2.py
index 4cb028b22aabc540304b584b0f9a6541ad2bc9ec..529aa965da4cf4d30cb3cab1d3e5a27297164051 100644
--- a/word_embedding/train_v2.py
+++ b/word_embedding/train_v2.py
@@ -8,10 +8,10 @@ import gzip
 
 def main():
     paddle.init(use_gpu=False, trainer_count=1)
-    word_dict = paddle.dataset.imikolov.build_dict()
+    word_dict = paddle.dataset.imikolov.build_dict(typo_freq=2)
     dict_size = len(word_dict)
-    input_data_lst, cost, prediction = network_conf(
-        hidden_size=256, embed_size=32, dict_size=dict_size)
+    cost = network_conf(
+        is_train=True, hidden_size=256, embed_size=32, dict_size=dict_size)
 
     def event_handler(event):
         if isinstance(event, paddle.event.EndPass):
@@ -28,8 +28,15 @@ def main():
             print "Pass %d, Batch %d, Cost %f" % (
                 event.pass_id, event.batch_id, event.cost)
 
-    feeding = dict(zip(input_data_lst, xrange(len(input_data_lst))))
-    parameters = paddle.parameters.create([cost, prediction])
+    feeding = {
+        'firstw': 0,
+        'secondw': 1,
+        'thirdw': 2,
+        'fourthw': 3,
+        'fifthw': 4
+    }
+
+    parameters = paddle.parameters.create(cost)
 
     adam_optimizer = paddle.optimizer.Adam(
         learning_rate=3e-3, regularization=paddle.optimizer.L2Regularization(8e-4))
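At inference time the mixed layer emits, per instance, dict_size - 1 sigmoid activations computed from the same 'sigmoid_w' / 'sigmoid_b' parameters the cost was trained with, which is why paddle.parameters.create(cost) alone now suffices. A toy numpy sketch of that forward step, reusing decode_one from the sketch above (the shapes and the transpose convention are assumptions, not PaddlePaddle internals):

    import numpy as np

    hidden_size, dict_size = 4, 8
    rng = np.random.RandomState(0)
    hidden = rng.randn(hidden_size)                    # one hidden_layer activation
    sigmoid_w = rng.randn(dict_size - 1, hidden_size)  # one node classifier per row
    sigmoid_b = rng.randn(dict_size - 1)

    # prediction layer: elementwise sigmoid over all internal-node logits
    node_probs = 1.0 / (1.0 + np.exp(-(sigmoid_w.dot(hidden) + sigmoid_b)))
    print(decode_one(node_probs, dict_size))  # decoded word id for this instance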