test_train_w2v.py 3.6 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7
# coding=utf-8
from __future__ import print_function
from __future__ import division
from __future__ import print_function

import paddle
import paddle.fluid as fluid
W
wuzewu 已提交
8
import paddlehub as hub
Z
Zeyu Chen 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
import unittest
import os

# Width of each word-embedding vector (second dim of the embedding table).
EMBED_SIZE = 64
# Output size of the sigmoid hidden layer in word2vec().
HIDDEN_SIZE = 256
# NOTE(review): N is not referenced anywhere in the visible code; presumably
# the n-gram window size (4 context words + 1 target) -- confirm before use.
N = 5
# Number of samples per mini-batch handed to paddle.batch().
BATCH_SIZE = 1
# Number of passes (epochs) train() makes over the mock data.
PASS_NUM = 100

# Built at import time; only its length is used below, as the vocabulary size
# for the embedding table and the softmax output layer.
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)

# Two hand-written samples of five integers each. train() feeds each sample
# through five feed variables (firstw, secondw, thirdw, fourthw, nextw).
_MOCK_DATA = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]


def mock_data():
    """Reader creator body: yield every sample of ``_MOCK_DATA`` in order."""
    sample_iter = iter(_MOCK_DATA)
    for sample in sample_iter:
        yield sample


# Wrap the sample-level reader so that calling batch_reader() iterates over
# mini-batches of BATCH_SIZE samples; train() calls it once per epoch.
batch_reader = paddle.batch(mock_data, BATCH_SIZE)


def word2vec(words, is_sparse):
    """Build the word2vec prediction network and return its average loss.

    Args:
        words: Indexable collection of input variables. Only ``words[0]``
            through ``words[3]`` (the context words) are read here; the
            label is declared inside this function as a data layer named
            ``'nextw'``.
        is_sparse: Forwarded to every embedding layer's ``is_sparse``.

    Returns:
        The mean of the cross-entropy cost between the softmax prediction
        and the ``'nextw'`` label.
    """
    # All four lookups use the parameter name 'embedding'; they are created
    # in order, one per context word.
    context_embeddings = [
        fluid.layers.embedding(
            input=words[position],
            size=[dict_size, EMBED_SIZE],
            dtype='float32',
            is_sparse=is_sparse,
            param_attr='embedding') for position in range(4)
    ]

    merged = fluid.layers.concat(input=context_embeddings, axis=1)
    hidden = fluid.layers.fc(input=merged, size=HIDDEN_SIZE, act='sigmoid')
    prediction = fluid.layers.fc(input=hidden, size=dict_size, act='softmax')

    # The label layer is intentionally declared after the prediction layer.
    label = fluid.layers.data(name='nextw', shape=[1], dtype='int64')

    per_sample_cost = fluid.layers.cross_entropy(input=prediction, label=label)
    return fluid.layers.mean(per_sample_cost)


def train():
    """Train the mock word2vec network on CPU and save its embedding variable.

    Declares the five int64 input layers, builds the network with
    ``word2vec``, minimises the average cost with SGD for ``PASS_NUM``
    epochs over ``batch_reader``, prints the last mini-batch cost of every
    epoch, and finally saves the variable named ``"embedding"`` into
    ``model_dir``.
    """
    place = fluid.CPUPlace()

    first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
    second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
    third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
    fourth_word = fluid.layers.data(name='fourthw', shape=[1], dtype='int64')
    next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')

    word_list = [first_word, second_word, third_word, fourth_word, next_word]
    avg_cost = word2vec(word_list, is_sparse=True)

    main_program = fluid.default_main_program()
    startup_program = fluid.default_startup_program()

    sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=1e-3)
    sgd_optimizer.minimize(avg_cost)

    exe = fluid.Executor(place)
    exe.run(startup_program)  # initialization

    # The feed variables and the feeder do not change between mini-batches,
    # so look them up once instead of on every iteration.
    global_block = main_program.global_block()
    feed_var_list = [
        global_block.var(name)
        for name in ("firstw", "secondw", "thirdw", "fourthw", "nextw")
    ]
    feeder = fluid.DataFeeder(feed_list=feed_var_list, place=place)

    for epoch in range(0, PASS_NUM):
        cost = None
        for mini_batch in batch_reader():
            cost = exe.run(
                main_program,
                feed=feeder.feed(mini_batch),
                fetch_list=[avg_cost])
        # Report the cost of the epoch's last mini-batch; skip the report if
        # the reader produced no data rather than failing on an unbound name.
        if cost is not None:
            print("Epoch {} Cost = {}".format(epoch, cost[0]))

    model_dir = "./tmp/w2v_model"
    var_list_to_saved = [global_block.var("embedding")]
    print("saving model to %s" % model_dir)
    # Save into the directory that was just announced; previously a different
    # hard-coded path was used here, so the message and the output disagreed.
    fluid.io.save_vars(
        executor=exe, dirname=model_dir, vars=var_list_to_saved)


# Run the training routine only when this file is executed as a script.
if __name__ == "__main__":
    train()