未验证 提交 65036c6c 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #650 from junjun315/04-stuff

update to low level api--04 word2vec,test=develop
......@@ -202,40 +202,32 @@ dream that one day <e>
首先,加载所需要的包:
```python
import paddle
import paddle as paddle
import paddle.fluid as fluid
import six
import numpy
from functools import partial
import math
import os
import six
import sys
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
然后,定义参数:
```python
EMBED_SIZE = 32 # word vector dimension
HIDDEN_SIZE = 256 # hidden layer dimension
N = 5 # train 5-gram
BATCH_SIZE = 32 # batch size
EMBED_SIZE = 32
HIDDEN_SIZE = 256
N = 5
BATCH_SIZE = 100
PASS_NUM = 100
# can use CPU or GPU
use_cuda = os.getenv('WITH_GPU', '0') != '0'
use_cuda = False # set to True if training with GPU
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
```
更大的`BATCH_SIZE`将使得训练更快收敛,但也会消耗更多内存。由于词向量计算规模较大,如果环境允许,请开启使用GPU进行训练,能更快得到结果。
不同于之前的PaddlePaddle v2版本,在新的Fluid版本里,我们不必再手动计算词向量。PaddlePaddle提供了一个内置的方法`fluid.layers.embedding`,我们就可以直接用它来构造 N-gram 神经网络。
- 我们来定义我们的 N-gram 神经网络结构。这个结构在训练和预测中都会使用到。因为词向量比较稀疏,我们传入参数 `is_sparse == True`, 可以加速稀疏矩阵的更新。
......
......@@ -244,40 +244,32 @@ dream that one day <e>
首先,加载所需要的包:
```python
import paddle
import paddle as paddle
import paddle.fluid as fluid
import six
import numpy
from functools import partial
import math
import os
import six
import sys
from __future__ import print_function
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
```
然后,定义参数:
```python
EMBED_SIZE = 32 # word vector dimension
HIDDEN_SIZE = 256 # hidden layer dimension
N = 5 # train 5-gram
BATCH_SIZE = 32 # batch size
EMBED_SIZE = 32
HIDDEN_SIZE = 256
N = 5
BATCH_SIZE = 100
PASS_NUM = 100
# can use CPU or GPU
use_cuda = os.getenv('WITH_GPU', '0') != '0'
use_cuda = False # set to True if training with GPU
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
```
更大的`BATCH_SIZE`将使得训练更快收敛,但也会消耗更多内存。由于词向量计算规模较大,如果环境允许,请开启使用GPU进行训练,能更快得到结果。
不同于之前的PaddlePaddle v2版本,在新的Fluid版本里,我们不必再手动计算词向量。PaddlePaddle提供了一个内置的方法`fluid.layers.embedding`,我们就可以直接用它来构造 N-gram 神经网络。
- 我们来定义我们的 N-gram 神经网络结构。这个结构在训练和预测中都会使用到。因为词向量比较稀疏,我们传入参数 `is_sparse == True`, 可以加速稀疏矩阵的更新。
......
......@@ -15,28 +15,15 @@ from __future__ import print_function
import paddle as paddle
import paddle.fluid as fluid
import six
import sys
try:
from paddle.fluid.contrib.trainer import *
from paddle.fluid.contrib.inferencer import *
except ImportError:
print(
"In the fluid 1.0, the trainer and inferencer are moving to paddle.fluid.contrib",
file=sys.stderr)
from paddle.fluid.trainer import *
from paddle.fluid.inferencer import *
import numpy
import sys
from functools import partial
import math
import os
EMBED_SIZE = 32
HIDDEN_SIZE = 256
N = 5
BATCH_SIZE = 100
BATCH_SIZE = 32
PASS_NUM = 100
use_cuda = False # set to True if training with GPU
......@@ -44,32 +31,28 @@ word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
def inference_program(is_sparse):
first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
fourth_word = fluid.layers.data(name='fourthw', shape=[1], dtype='int64')
def inference_program(words, is_sparse):
embed_first = fluid.layers.embedding(
input=first_word,
input=words[0],
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_second = fluid.layers.embedding(
input=second_word,
input=words[1],
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_third = fluid.layers.embedding(
input=third_word,
input=words[2],
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_fourth = fluid.layers.embedding(
input=fourth_word,
input=words[3],
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
......@@ -83,11 +66,10 @@ def inference_program(is_sparse):
return predict_word
def train_program(is_sparse):
def train_program(predict_word):
# The declaration of 'next_word' must be after the invoking of inference_program,
# or the data input order of train program would be [next_word, firstw, secondw,
# thirdw, fourthw], which is not correct.
predict_word = inference_program(is_sparse)
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
avg_cost = fluid.layers.mean(cost)
......@@ -100,86 +82,152 @@ def optimizer_func():
regularization=fluid.regularizer.L2DecayRegularizer(8e-4))
def train(use_cuda, train_program, params_dirname):
def train(if_use_cuda, params_dirname, is_sparse=True):
place = fluid.CUDAPlace(0) if if_use_cuda else fluid.CPUPlace()
train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
test_reader = paddle.batch(
paddle.dataset.imikolov.test(word_dict, N), BATCH_SIZE)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
def event_handler(event):
if isinstance(event, EndStepEvent):
outs = trainer.test(
reader=test_reader,
feed_order=['firstw', 'secondw', 'thirdw', 'fourthw', 'nextw'])
avg_cost = outs[0]
if event.step % 10 == 0:
print("Step %d: Average Cost %f" % (event.step, avg_cost))
# If average cost is lower than 5.8, we consider the model good enough to stop.
# Note 5.8 is a relatively high value. In order to get a better model, one should
# aim for avg_cost lower than 3.5. But the training could take longer time.
if avg_cost < 5.8:
trainer.save_params(params_dirname)
trainer.stop()
if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.")
trainer = Trainer(
train_func=train_program,
# optimizer=fluid.optimizer.SGD(learning_rate=0.001),
optimizer_func=optimizer_func,
place=place)
trainer.train(
reader=train_reader,
num_epochs=1,
event_handler=event_handler,
feed_order=['firstw', 'secondw', 'thirdw', 'fourthw', 'nextw'])
first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
forth_word = fluid.layers.data(name='fourthw', shape=[1], dtype='int64')
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
def infer(use_cuda, inference_program, params_dirname=None):
word_list = [first_word, second_word, third_word, forth_word, next_word]
feed_order = ['firstw', 'secondw', 'thirdw', 'fourthw', 'nextw']
main_program = fluid.default_main_program()
star_program = fluid.default_startup_program()
predict_word = inference_program(word_list, is_sparse)
avg_cost = train_program(predict_word)
test_program = main_program.clone(for_test=True)
sgd_optimizer = optimizer_func()
sgd_optimizer.minimize(avg_cost)
exe = fluid.Executor(place)
def train_test(program, reader):
count = 0
feed_var_list = [
program.global_block().var(var_name) for var_name in feed_order
]
feeder_test = fluid.DataFeeder(feed_list=feed_var_list, place=place)
test_exe = fluid.Executor(place)
accumulated = len([avg_cost]) * [0]
for test_data in reader():
avg_cost_np = test_exe.run(
program=program,
feed=feeder_test.feed(test_data),
fetch_list=[avg_cost])
accumulated = [
x[0] + x[1][0] for x in zip(accumulated, avg_cost_np)
]
count += 1
return [x / count for x in accumulated]
def train_loop():
step = 0
feed_var_list_loop = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
exe.run(star_program)
for pass_id in range(PASS_NUM):
for data in train_reader():
avg_cost_np = exe.run(
main_program, feed=feeder.feed(data), fetch_list=[avg_cost])
if step % 10 == 0:
#outs = train_test(test_program, test_reader)
# print("Step %d: Average Cost %f" % (step, avg_cost_np[0]))
print("Step %d: Average Cost %f" % (step, avg_cost_np[0]))
# it will take a few hours.
# If average cost is lower than 5.8, we consider the model good enough to stop.
# Note 5.8 is a relatively high value. In order to get a better model, one should
# aim for avg_cost lower than 3.5. But the training could take longer time.
if avg_cost_np[0] < 5.8:
if params_dirname is not None:
fluid.io.save_inference_model(params_dirname, [
'firstw', 'secondw', 'thirdw', 'fourthw'
], [predict_word], exe)
return
step += 1
if math.isnan(float(avg_cost_np[0])):
sys.exit("got NaN loss, training failed.")
raise AssertionError("Cost is too large {0:2.2}".format(avg_cost_np[0]))
train_loop()
def infer(use_cuda, params_dirname=None):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = Inferencer(
infer_func=inference_program, param_path=params_dirname, place=place)
# Setup inputs by creating 4 LoDTensors representing 4 words. Here each word
# is simply an index to look up for the corresponding word vector and hence
# the shape of word (base_shape) should be [1]. The length-based level of
# detail (lod) info of each LoDtensor should be [[1]] meaning there is only
# one lod_level and there is only one sequence of one word on this level.
# Note that lod info should be a list of lists.
data1 = [[211]] # 'among'
data2 = [[6]] # 'a'
data3 = [[96]] # 'group'
data4 = [[4]] # 'of'
lod = [[1]]
first_word = fluid.create_lod_tensor(data1, lod, place)
second_word = fluid.create_lod_tensor(data2, lod, place)
third_word = fluid.create_lod_tensor(data3, lod, place)
fourth_word = fluid.create_lod_tensor(data4, lod, place)
result = inferencer.infer(
{
'firstw': first_word,
'secondw': second_word,
'thirdw': third_word,
'fourthw': fourth_word
},
return_numpy=False)
print(numpy.array(result[0]))
most_possible_word_index = numpy.argmax(result[0])
print(most_possible_word_index)
print([
key for key, value in six.iteritems(word_dict)
if value == most_possible_word_index
][0])
exe = fluid.Executor(place)
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be feeded
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inferencer, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(params_dirname, exe)
# Setup inputs by creating 4 LoDTensors representing 4 words. Here each word
# is simply an index to look up for the corresponding word vector and hence
# the shape of word (base_shape) should be [1]. The recursive_sequence_lengths,
# which is length-based level of detail (lod) of each LoDTensor, should be [[1]]
# meaning there is only one level of detail and there is only one sequence of
# one word on this level.
# Note that recursive_sequence_lengths should be a list of lists.
data1 = [[211]] # 'among'
data2 = [[6]] # 'a'
data3 = [[96]] # 'group'
data4 = [[4]] # 'of'
lod = [[1]]
first_word = fluid.create_lod_tensor(data1, lod, place)
second_word = fluid.create_lod_tensor(data2, lod, place)
third_word = fluid.create_lod_tensor(data3, lod, place)
fourth_word = fluid.create_lod_tensor(data4, lod, place)
assert feed_target_names[0] == 'firstw'
assert feed_target_names[1] == 'secondw'
assert feed_target_names[2] == 'thirdw'
assert feed_target_names[3] == 'fourthw'
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
results = exe.run(
inferencer,
feed={
feed_target_names[0]: first_word,
feed_target_names[1]: second_word,
feed_target_names[2]: third_word,
feed_target_names[3]: fourth_word
},
fetch_list=fetch_targets,
return_numpy=False)
print(numpy.array(results[0]))
most_possible_word_index = numpy.argmax(results[0])
print(most_possible_word_index)
print([
key for key, value in six.iteritems(word_dict)
if value == most_possible_word_index
][0])
print(results[0].recursive_sequence_lengths())
np_data = numpy.array(results[0])
print("Inference Shape: ", np_data.shape)
def main(use_cuda, is_sparse):
......@@ -189,14 +237,11 @@ def main(use_cuda, is_sparse):
params_dirname = "word2vec.inference.model"
train(
use_cuda=use_cuda,
train_program=partial(train_program, is_sparse),
params_dirname=params_dirname)
infer(
use_cuda=use_cuda,
inference_program=partial(inference_program, is_sparse),
params_dirname=params_dirname)
if_use_cuda=use_cuda,
params_dirname=params_dirname,
is_sparse=is_sparse)
infer(use_cuda=use_cuda, params_dirname=params_dirname)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册