提交 e60d94b3 编写于 作者: Q qiaolongfei

correct data_type

上级 495861f5
...@@ -4,7 +4,8 @@ import paddle.v2 as paddle ...@@ -4,7 +4,8 @@ import paddle.v2 as paddle
from seqToseq_net_v2 import seqToseq_net_v2 from seqToseq_net_v2 import seqToseq_net_v2
### Data Definition # Data Definition.
# TODO: This code should be merged into the dataset package.
data_dir = "./data/pre-wmt14" data_dir = "./data/pre-wmt14"
src_lang_dict = os.path.join(data_dir, 'src.dict') src_lang_dict = os.path.join(data_dir, 'src.dict')
trg_lang_dict = os.path.join(data_dir, 'trg.dict') trg_lang_dict = os.path.join(data_dir, 'trg.dict')
...@@ -68,15 +69,14 @@ def train_reader(file_name): ...@@ -68,15 +69,14 @@ def train_reader(file_name):
def main(): def main():
paddle.init(use_gpu=False, trainer_count=1) paddle.init(use_gpu=False, trainer_count=1)
# reader = train_reader("data/pre-wmt14/train/train")
# define network topology # define network topology
cost = seqToseq_net_v2(source_dict_dim, target_dict_dim) cost = seqToseq_net_v2(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Adam(batch_size=50, learning_rate=5e-4) optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0: if event.batch_id % 10 == 0:
print "Pass %d, Batch %d, Cost %f, %s" % ( print "Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics)
...@@ -93,12 +93,12 @@ def main(): ...@@ -93,12 +93,12 @@ def main():
trn_reader = paddle.reader.batched( trn_reader = paddle.reader.batched(
paddle.reader.shuffle( paddle.reader.shuffle(
train_reader("data/pre-wmt14/train/train"), buf_size=8192), train_reader("data/pre-wmt14/train/train"), buf_size=8192),
batch_size=10) batch_size=10000)
trainer.train( trainer.train(
reader=trn_reader, reader=trn_reader,
event_handler=event_handler, event_handler=event_handler,
num_passes=10000, num_passes=10,
reader_dict=reader_dict) reader_dict=reader_dict)
......
...@@ -14,7 +14,7 @@ def seqToseq_net_v2(source_dict_dim, target_dict_dim): ...@@ -14,7 +14,7 @@ def seqToseq_net_v2(source_dict_dim, target_dict_dim):
#### Encoder #### Encoder
src_word_id = layer.data( src_word_id = layer.data(
name='source_language_word', name='source_language_word',
type=data_type.dense_vector(source_dict_dim)) type=data_type.integer_value_sequence(source_dict_dim))
src_embedding = layer.embedding( src_embedding = layer.embedding(
input=src_word_id, input=src_word_id,
size=word_vector_dim, size=word_vector_dim,
...@@ -67,7 +67,7 @@ def seqToseq_net_v2(source_dict_dim, target_dict_dim): ...@@ -67,7 +67,7 @@ def seqToseq_net_v2(source_dict_dim, target_dict_dim):
trg_embedding = layer.embedding( trg_embedding = layer.embedding(
input=layer.data( input=layer.data(
name='target_language_word', name='target_language_word',
type=data_type.dense_vector(target_dict_dim)), type=data_type.integer_value_sequence(target_dict_dim)),
size=word_vector_dim, size=word_vector_dim,
param_attr=attr.ParamAttr(name='_target_language_embedding')) param_attr=attr.ParamAttr(name='_target_language_embedding'))
group_inputs.append(trg_embedding) group_inputs.append(trg_embedding)
...@@ -84,7 +84,7 @@ def seqToseq_net_v2(source_dict_dim, target_dict_dim): ...@@ -84,7 +84,7 @@ def seqToseq_net_v2(source_dict_dim, target_dict_dim):
lbl = layer.data( lbl = layer.data(
name='target_language_next_word', name='target_language_next_word',
type=data_type.dense_vector(target_dict_dim)) type=data_type.integer_value_sequence(target_dict_dim))
cost = layer.classification_cost(input=decoder, label=lbl) cost = layer.classification_cost(input=decoder, label=lbl)
return cost return cost
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册