提交 3f513630 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!149 Use TFRecordDataset instead of StorageDataset in Bert model integration...

!149 Use TFRecordDataset instead of StorageDataset in the BERT model integration test, and add absolute position embedding code to the BERT model
Merge pull request !149 from yoonlee666/master
......@@ -165,6 +165,7 @@ class EmbeddingPostprocessor(nn.Cell):
def __init__(self,
embedding_size,
embedding_shape,
use_relative_positions=False,
use_token_type=False,
token_type_vocab_size=16,
use_one_hot_embeddings=False,
......@@ -192,6 +193,13 @@ class EmbeddingPostprocessor(nn.Cell):
self.layernorm = nn.LayerNorm(embedding_size)
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
self.slice = P.Slice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
embedding_size]),
name='full_position_embeddings')
def construct(self, token_type_ids, word_embeddings):
output = word_embeddings
......@@ -206,6 +214,11 @@ class EmbeddingPostprocessor(nn.Cell):
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
output += token_type_embeddings
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
output = self.layernorm(output)
output = self.dropout(output)
return output
......@@ -853,6 +866,7 @@ class BertModel(nn.Cell):
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_relative_positions=config.use_relative_positions,
use_token_type=True,
token_type_vocab_size=config.type_vocab_size,
use_one_hot_embeddings=use_one_hot_embeddings,
......
......@@ -103,9 +103,9 @@ def me_de_train_dataset():
"""test me de train dataset"""
# apply repeat operations
repeat_count = 1
ds = de.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
"next_sentence_labels", "masked_lm_positions",
"masked_lm_ids", "masked_lm_weights"])
"masked_lm_ids", "masked_lm_weights"], shuffle=False)
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册