Commit 53e7c177 authored by ying

refine comments.

Parent 78e29521
# Transformer
Set the model and training configurations in `config.py`, and execute `python train.py` to train.
More details to be added.
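For example, switching to GPU training or changing the batch size only requires editing a few class attributes in `config.py` (a minimal, hypothetical sketch; the field names follow the `TrainTaskConfig` class shown later in this diff):

```python
# config.py -- illustrative values only; edit to taste before running train.py.
class TrainTaskConfig(object):
    use_gpu = True        # run on the first CUDA device instead of the CPU
    pass_num = 5          # number of training epochs
    batch_size = 32       # sequences per mini-batch
    learning_rate = 0.001
```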
# Represent the dict sizes of the source and target language. The dicts used by
# the dataset here include the <bos>, <eos> and <unk> tokens but exclude the
# <pad> token. Add 1 to include the padding token when this is used as the size
# of the lookup table.
src_vocab_size = 10000
trg_vocab_size = 10000
# Represent the id of <pad> token in source language.
src_pad_idx = src_vocab_size
# Represent the id of <pad> token in target language.
trg_pad_idx = trg_vocab_size
# Represent the position value corresponding to the <pad> token.
pos_pad_idx = 0
# Represent the max length of sequences. Add 1 to include the position padding
# token for position encoding.
max_length = 50
# Represent the epoch number to train.
pass_num = 2
# Represent the number of sequences contained in a mini-batch.
batch_size = 64
# Represent the hyperparameters for the Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
# Represent the dimension of embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# Represent the size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
# Represent the dimension keys are projected to for dot-product attention.
d_key = 64
# Represent the dimension values are projected to for dot-product attention.
d_value = 64
# Represent the number of heads used in multi-head attention.
n_head = 8
# Represent the number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# Represent the dropout rate used by all dropout layers.
dropout = 0.1
# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = ("src_pos_enc_table", "trg_pos_enc_table")
# Names of all data layers listed in order.
input_data_names = ("src_word", "src_pos", "trg_word", "trg_pos",
"src_slf_attn_bias", "trg_slf_attn_bias",
"trg_src_attn_bias", "lbl_word")
# Attention is All You Need: A Paddle Fluid implementation
This is a Paddle Fluid implementation of the Transformer model in [Attention is All You Need](https://arxiv.org/abs/1706.03762) (Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, arXiv, 2017).
If you use the dataset/code in your research, please cite the paper:
```text
@inproceedings{vaswani2017attention,
title={Attention is all you need},
author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {\L}ukasz and Polosukhin, Illia},
booktitle={Advances in Neural Information Processing Systems},
pages={6000--6010},
year={2017}
}
```
### TODO
This project is still under active development.
class TrainTaskConfig(object):
use_gpu = False
# the epoch number to train.
pass_num = 2
# number of sequences contained in a mini-batch.
batch_size = 64
# the hyperparameters for the Adam optimizer.
learning_rate = 0.001
beta1 = 0.9
beta2 = 0.98
eps = 1e-9
class ModelHyperParams(object):
# Dictionary size for the source and target language. This model directly uses
# paddle.dataset.wmt16, in which the <bos>, <eos> and <unk> tokens have
# already been added, but the <pad> token has not. The Transformer requires
# sequences in a mini-batch to be padded to the same length, so a <pad> token
# is added to the original dictionary of paddle.dataset.wmt16.
# size of source word dictionary.
src_vocab_size = 10000
# index for <pad> token in source language.
src_pad_idx = src_vocab_size
# size of target word dictionary.
trg_vocab_size = 10000
# index for <pad> token in target language.
trg_pad_idx = trg_vocab_size
# position value corresponding to the <pad> token.
pos_pad_idx = 0
# max length of sequences. Add 1 to include the position padding token for
# position encoding.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of heads used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rate used by all dropout layers.
dropout = 0.1
# Names of the position encoding tables, which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# Names of all data layers listed in order.
input_data_names = (
"src_word",
"src_pos",
"trg_word",
"trg_pos",
"src_slf_attn_bias",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"lbl_word", )
from functools import partial
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
# TODO: Remove the batch_size from the model.
from config import batch_size, input_data_names, pos_enc_param_names
from config import TrainTaskConfig, input_data_names, pos_enc_param_names
# FIXME(guosheng): Remove the batch_size from the model.
batch_size = TrainTaskConfig.batch_size
def position_encoding_init(n_position, d_pos_vec):
"""
"""
Generate the initial values for the sinusoid position encoding table.
"""
position_enc = np.array([[
@@ -30,15 +34,16 @@ def multi_head_attention(queries,
num_heads=1,
dropout_rate=0.):
"""
Multi-Head Attention. Note that attn_bias will be added to the logit to
affect the attention weights.
Multi-Head Attention. Note that attn_bias is added to the logit before
computing the softmax activation, to mask certain selected positions so
that they will not be considered in the attention weights.
"""
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs quries, keys and values should all be 3-D tensors.")
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, num_heads, d_key, d_value):
"""
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
@@ -66,7 +71,7 @@ def multi_head_attention(queries,
return x
hidden_size = x.shape[-1]
# TODO: Decouple the program desc from batch_size.
# FIXME(guosheng): Decouple the program desc from batch_size.
reshaped = layers.reshape(
x=x, shape=[batch_size, -1, num_heads, hidden_size // num_heads])
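# Added note: this reshape splits the last dimension into num_heads slices of
# size hidden_size // num_heads, turning [batch_size, seq_len, hidden_size]
# into [batch_size, seq_len, num_heads, hidden_size // num_heads].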
@@ -84,7 +89,7 @@ def multi_head_attention(queries,
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# TODO: Decouple the program desc from batch_size.
# FIXME(guosheng): Decouple the program desc from batch_size.
return layers.reshape(
x=trans_x,
shape=map(int,
@@ -95,12 +100,15 @@ def multi_head_attention(queries,
Scaled Dot-Product Attention
"""
# TODO: Optimize the shape in reshape_op or softmax_op.
# The softmax_op only supports 2D tensors currently and cannot be used
# here. Additionally, the reshape_op cannot be used here, since the
# shape of product inferred at compile time is not the actual shape at
# run time and cannot be used to set the attribute of reshape_op. Thus,
# define the softmax temporarily.
# FIXME(guosheng): Optimize the shape in reshape_op or softmax_op.
# The current implementation of softmax_op only supports 2D tensors,
# so it cannot be used directly here. Besides, the shape of product
# inferred at compile time is not the actual run-time shape and cannot
# be used to set the shape attribute of reshape_op.
# Therefore, the softmax is defined here as a temporary solution.
def __softmax(x, eps=1e-9):
exp_out = layers.exp(x=x)
sum_out = layers.reduce_sum(exp_out, dim=-1, keep_dim=False)
@@ -136,9 +144,9 @@ def multi_head_attention(queries,
def positionwise_feed_forward(x, d_inner_hid, d_hid):
"""
Position-wise Feed-Forward Networks.
This consists of two linear transformations with a ReLU activation in
between, which is applied to each position separately and identically.
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
@@ -150,17 +158,18 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid):
def pre_post_process_layer(prev_out, out, process_cmd, dropout=0.):
"""
Add residual connection, layer normalization and dropout on the out tensor
Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a":
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n":
elif cmd == "n": # add layer normalization
out = layers.layer_norm(out, begin_norm_axis=len(out.shape) - 1)
elif cmd == "d":
elif cmd == "d": # add dropout
if dropout:
out = layers.dropout(out, dropout_prob=dropout, is_test=False)
return out
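# Added note: the commands are applied in string order, so process_cmd "dan"
# means dropout, then the residual addition, then layer normalization; "dan"
# is the combination passed to post_process_layer throughout the encoder and
# decoder below.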
@@ -179,10 +188,11 @@ def prepare_encoder(src_word,
dropout=0.,
pos_pad_idx=0,
pos_enc_param_name=None):
"""
Add word embeddings and position encodings and output a tensor with shape
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This is used at the bottom of the encoder stacks.
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
src_word, size=[src_vocab_size, src_emb_dim], padding_idx=src_pad_idx)
@@ -193,7 +203,8 @@ def prepare_encoder(src_word,
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
enc_input = src_word_emb + src_pos_enc
# TODO: Decouple the program desc from batch_size.
# FIXME(guosheng): Decouple the program desc from batch_size.
enc_input = layers.reshape(x=enc_input, shape=[batch_size, -1, src_emb_dim])
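# Added note: src_word and src_pos are fed flattened as
# [batch_size * max_src_length_in_a_batch, 1] (see the data layers in
# transformer()), so the summed embeddings are reshaped back to
# [batch_size, max_src_length_in_a_batch, src_emb_dim] here.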
return layers.dropout(
enc_input, dropout_prob=dropout,
@@ -206,142 +217,247 @@ prepare_decoder = partial(
prepare_encoder, pos_enc_param_name=pos_enc_param_names[1])
def encoder_layer(enc_input, attn_bias, n_head, d_key, d_value, d_model,
d_inner_hid, dropout):
"""
The layer to be stacked in the encoder.
This consists of multi-head (self) attention followed by position-wise
feed-forward networks, and both components are accompanied by the
post_process_layer to add residual connection, layer normalization and
dropout.
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
"""The encoder layers that can be stacked to form a deep encoder.
This module consists of a multi-head (self) attention sub-layer followed by
position-wise feed-forward networks, and both components are accompanied
by the post_process_layer to add residual connection, layer normalization
and dropout.
"""
attn_output = multi_head_attention(enc_input, enc_input, enc_input,
attn_bias, d_key, d_value, d_model,
n_head, dropout)
attn_output = post_process_layer(enc_input, attn_output, "dan", dropout)
n_head, dropout_rate)
attn_output = post_process_layer(enc_input, attn_output, "dan",
dropout_rate)
ffd_output = positionwise_feed_forward(attn_output, d_inner_hid, d_model)
output = post_process_layer(attn_output, ffd_output, "dan", dropout)
return output
def encoder(enc_input, attn_bias, n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, dropout):
return post_process_layer(attn_output, ffd_output, "dan", dropout_rate)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
"""
The encoder is composed of a stack of identical encoder_layer layers.
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(enc_input, attn_bias, n_head, d_key, d_value,
d_model, d_inner_hid, dropout)
d_model, d_inner_hid, dropout_rate)
enc_input = enc_output
return enc_output
def decoder_layer(dec_input, enc_output, slf_attn_bias, dec_enc_attn_bias,
n_head, d_key, d_value, d_model, d_inner_hid, dropout):
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
The layer to be stacked in the decoder. The structure of this is similar to
the encoder_layer but another multi-head attention is added to implement
encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(dec_input, dec_input, dec_input,
slf_attn_bias, d_key, d_value,
d_model, n_head, dropout)
slf_attn_output = post_process_layer(dec_input, slf_attn_output, "dan",
dropout)
enc_attn_output = multi_head_attention(slf_attn_output, enc_output,
enc_output, dec_enc_attn_bias, d_key,
d_value, d_model, n_head, dropout)
enc_attn_output = post_process_layer(slf_attn_output, enc_attn_output,
"dan", dropout)
ffd_output = positionwise_feed_forward(enc_attn_output, d_inner_hid,
d_model)
dec_output = post_process_layer(enc_attn_output, ffd_output, "dan", dropout)
slf_attn_output = multi_head_attention(
dec_input,
dec_input,
dec_input,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
dropout_rate, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
enc_attn_output = multi_head_attention(
slf_attn_output,
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
dropout_rate, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
ffd_output = positionwise_feed_forward(
enc_attn_output,
d_inner_hid,
d_model, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
"dan", # residual connection + dropout + layer normalization
dropout_rate, )
return dec_output
def decoder(dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias,
n_layer, n_head, d_key, d_value, d_model, d_inner_hid, dropout):
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate=0.):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(dec_input, enc_output, dec_slf_attn_bias,
dec_enc_attn_bias, n_head, d_key, d_value,
d_model, d_inner_hid, dropout)
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
dec_input = dec_output
return dec_output
def transformer(src_vocab_size, trg_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, dropout, src_pad_idx,
trg_pad_idx, pos_pad_idx):
# The shapes here only act as placeholders and are set to guarantee the
# success of infer-shape at compile time.
# The actual shape of src_word is:
# [batch_size * max_src_length_in_batch, 1].
def transformer(
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate,
src_pad_idx,
trg_pad_idx,
pos_pad_idx, ):
# The shapes here only act as placeholders; they are set to pass the
# infer-shape check at compile time. The actual shape of src_word at run
# time is:
# [batch_size * max_src_length_in_a_batch, 1].
src_word = layers.data(
name=input_data_names[0],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of src_pos is:
# [batch_size * max_src_length_in_batch, 1].
# The actual shape of src_pos at run time is:
# [batch_size * max_src_length_in_a_batch, 1].
src_pos = layers.data(
name=input_data_names[1],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of trg_word is:
# [batch_size * max_trg_length_in_batch, 1].
# The actual shape of trg_word at run time is:
# [batch_size * max_trg_length_in_a_batch, 1].
trg_word = layers.data(
name=input_data_names[2],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of trg_pos is:
# [batch_size * max_trg_length_in_batch, 1].
# The actual shape of trg_pos at run time is:
# [batch_size * max_trg_length_in_a_batch, 1].
trg_pos = layers.data(
name=input_data_names[3],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
# The actual shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_length_in_batch, max_src_length_in_batch].
# This is used to avoid attention on paddings.
# The actual shape of src_slf_attn_bias at run time is:
# [batch_size, n_head, max_src_length_in_a_batch, max_src_length_in_a_batch].
# This input is used to remove attention weights on paddings.
src_slf_attn_bias = layers.data(
name=input_data_names[4],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
# The actual shape of trg_slf_attn_bias is:
# The actual shape of trg_slf_attn_bias at run time is:
# [batch_size, n_head, max_trg_length_in_batch, max_trg_length_in_batch].
# This is used to avoid attention on paddings and subsequent words.
# This is used to remove attention weights on paddings and subsequent words.
trg_slf_attn_bias = layers.data(
name=input_data_names[5],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
# The actual shape of trg_src_attn_bias is:
# The actual shape of trg_src_attn_bias at run time is:
# [batch_size, n_head, max_trg_length_in_batch, max_src_length_in_batch].
# This is used to avoid attention on paddings.
# This is used to remove attention weights on paddings.
trg_src_attn_bias = layers.data(
name=input_data_names[6],
shape=[batch_size, n_head, max_length, max_length],
dtype="float32",
append_batch_size=False)
enc_input = prepare_encoder(src_word, src_pos, src_vocab_size, d_model,
src_pad_idx, max_length, dropout)
enc_output = encoder(enc_input, src_slf_attn_bias, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, dropout)
dec_input = prepare_decoder(trg_word, trg_pos, trg_vocab_size, d_model,
trg_pad_idx, max_length, dropout)
dec_output = decoder(dec_input, enc_output, trg_slf_attn_bias,
trg_src_attn_bias, n_layer, n_head, d_key, d_value,
d_model, d_inner_hid, dropout)
# TODO: Share the same weight matrix between the two embedding layers and
enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
src_pad_idx,
max_length,
dropout_rate, )
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
trg_pad_idx,
max_length,
dropout_rate, )
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
dropout_rate, )
# TODO(guosheng): Share the weight matrix between the embedding layers and
# the pre-softmax linear transformation.
predict = layers.reshape(
x=layers.fc(input=dec_output,
@@ -350,13 +466,12 @@ def transformer(src_vocab_size, trg_vocab_size, max_length, n_layer, n_head,
num_flatten_dims=2),
shape=[-1, trg_vocab_size],
act="softmax")
# The actual shape of gold is:
# [batch_size * max_trg_length_in_batch, 1].
# The actual shape of gold at run time is:
# [batch_size * max_trg_length_in_a_batch, 1].
gold = layers.data(
name=input_data_names[7],
shape=[batch_size * max_length, 1],
dtype="int64",
append_batch_size=False)
cost = layers.cross_entropy(input=predict, label=gold)
avg_cost = layers.mean(x=cost)
return avg_cost
return layers.mean(x=cost)
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
from model import transformer, position_encoding_init
from config import *
from config import TrainTaskConfig, ModelHyperParams, \
pos_enc_param_names, input_data_names
def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
@@ -14,12 +17,12 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
"""
input_dict = {}
def pad_batch_data(insts,
pad_idx,
is_target=False,
return_pos=True,
return_attn_bias=True,
return_max_len=True):
def __pad_batch_data(insts,
pad_idx,
is_target=False,
return_pos=True,
return_attn_bias=True,
return_max_len=True):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
@@ -66,14 +69,14 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
tensor.set(data_list[i], place)
input_dict[name_list[i]] = tensor
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, is_target=False)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, is_target=True)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
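# Added note: the slice ::src_max_len keeps a single row of the source
# self-attention bias per head, and tiling it trg_max_len times gives every
# target position the same mask over padded source positions for the
# encoder-decoder attention.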
lbl_word = pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
False, False, False)
lbl_word = __pad_batch_data([inst[2] for inst in insts], trg_pad_idx, False,
False, False, False)
data_to_tensor([
src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
@@ -84,22 +87,30 @@ def prepare_batch_input(insts, input_data_names, src_pad_idx, trg_pad_idx,
def main():
avg_cost = transformer(src_vocab_size + 1, trg_vocab_size + 1,
max_length + 1, n_layer, n_head, d_key, d_value,
d_model, d_inner_hid, dropout, src_pad_idx,
trg_pad_idx, pos_pad_idx)
avg_cost = transformer(
ModelHyperParams.src_vocab_size + 1,
ModelHyperParams.trg_vocab_size + 1, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.dropout, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate, beta1=beta1, beta2=beta2, epsilon=eps)
learning_rate=TrainTaskConfig.learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
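# Added note: beta1=0.9, beta2=0.98 and eps=1e-9 are the Adam settings reported
# in the paper; this snippet appears to use a fixed learning_rate (0.001 in
# TrainTaskConfig) rather than the paper's warmup-based schedule.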
optimizer.minimize(avg_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.wmt16.train(src_vocab_size, trg_vocab_size),
buf_size=1000),
batch_size=batch_size)
paddle.dataset.wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
buf_size=51200),
batch_size=TrainTaskConfig.batch_size)
place = fluid.CPUPlace()
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Initialize the parameters.
@@ -108,21 +119,21 @@ def main():
pos_enc_param = fluid.global_scope().find_var(
pos_enc_param_name).get_tensor()
pos_enc_param.set(
position_encoding_init(max_length + 1, d_model), place)
batch_id = 0
for pass_id in xrange(pass_num):
for data in train_data():
data_input = prepare_batch_input(data, input_data_names,
src_pad_idx, trg_pad_idx,
max_length, n_head, place)
position_encoding_init(ModelHyperParams.max_length + 1,
ModelHyperParams.d_model), place)
for pass_id in xrange(TrainTaskConfig.pass_num):
for batch_id, data in enumerate(train_data()):
data_input = prepare_batch_input(
data, input_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head, place)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[avg_cost])
avg_cost_val = np.array(outs[0])
print("pass_id=" + str(pass_id) + " batch=" + str(batch_id) +
" avg_cost=" + str(avg_cost_val))
batch_id += 1
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" avg_cost = " + str(avg_cost_val))
if __name__ == "__main__":