提交 96db04b4 编写于 作者: Y yuchaojie

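Fix the decoder loop for the Transformer model: build the decoder layers in a loop and store them in an nn.CellList, instead of six hard-coded DecoderCell attributes (layer0 to layer5)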

上级 6e7a38ac
......@@ -781,95 +781,22 @@ class TransformerDecoder(nn.Cell):
super(TransformerDecoder, self).__init__()
self.num_hidden_layers = num_hidden_layers
# wait to be supported
# layers = []
# for _ in range(num_hidden_layers):
# layer = DecoderCell(batch_size=batch_size,
# hidden_size=hidden_size,
# seq_length=seq_length,
# enc_seq_length=enc_seq_length,
# num_attention_heads=num_attention_heads,
# intermediate_size=intermediate_size,
# attention_probs_dropout_prob=attention_probs_dropout_prob,
# use_one_hot_embeddings=use_one_hot_embeddings,
# initializer_range=initializer_range,
# hidden_dropout_prob=hidden_dropout_prob,
# hidden_act=hidden_act,
# compute_type=compute_type)
# layers.append(layer)
# self.layers = nn.CellList(layers)
self.layer0 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer1 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer2 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer3 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer4 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
self.layer5 = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
layers = []
for _ in range(num_hidden_layers):
layer = DecoderCell(batch_size=batch_size,
hidden_size=hidden_size,
seq_length=seq_length,
enc_seq_length=enc_seq_length,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
attention_probs_dropout_prob=attention_probs_dropout_prob,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=initializer_range,
hidden_dropout_prob=hidden_dropout_prob,
hidden_act=hidden_act,
compute_type=compute_type)
layers.append(layer)
self.layers = nn.CellList(layers)
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size)
......@@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell):
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
prev_output = self.reshape(input_tensor, self.shape)
# wait to be supported
# for layer_module in self.layers:
# layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
# prev_output = layer_output
prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask)
for layer_module in self.layers:
layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
prev_output = layer_output
prev_output = self.layer_preprocess(prev_output)
output = self.reshape(prev_output, self.out_shape)
......
......@@ -16,6 +16,7 @@
import time
import argparse
import random
import numpy as np
import mindspore.common.dtype as mstype
......@@ -26,6 +27,7 @@ from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
import mindspore.communication.management as D
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
......@@ -36,6 +38,10 @@ from src.config import cfg, transformer_net_cfg
from src.dataset import create_transformer_dataset
from src.lr_schedule import create_dynamic_lr
# Seed every random number generator this script touches with one fixed value.
# These statements sit at module level, so they run as soon as the file is
# imported or executed, not only under the `__main__` guard.
random_seed = 1
# Python's built-in `random` module.
random.seed(random_seed)
# NumPy's global random state.
np.random.seed(random_seed)
# MindSpore dataset engine (`mindspore.dataset.engine`, imported as `de`).
# NOTE(review): presumably this makes dataset shuffling/sampling reproducible;
# confirm against the MindSpore documentation.
de.config.set_seed(random_seed)
def get_ms_timestamp():
t = time.time()
......@@ -161,7 +167,4 @@ def run_transformer_train():
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
if __name__ == '__main__':
random_seed = 1
np.random.seed(random_seed)
run_transformer_train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册