Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ee519aa0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ee519aa0
编写于
4月 08, 2020
作者:
Y
yoonlee666
提交者:
高东海
4月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use TFRecordDataset in bert ci script and add absolute position embedding code in bert model
上级
0ba72a68
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
16 addition
and
2 deletion
+16
-2
mindspore/model_zoo/Bert_NEZHA/bert_model.py
mindspore/model_zoo/Bert_NEZHA/bert_model.py
+14
-0
tests/st/networks/models/bert/bert_tdt_no_lossscale.py
tests/st/networks/models/bert/bert_tdt_no_lossscale.py
+2
-2
未找到文件。
mindspore/model_zoo/Bert_NEZHA/bert_model.py
浏览文件 @
ee519aa0
...
...
@@ -165,6 +165,7 @@ class EmbeddingPostprocessor(nn.Cell):
def
__init__
(
self
,
embedding_size
,
embedding_shape
,
use_relative_positions
=
False
,
use_token_type
=
False
,
token_type_vocab_size
=
16
,
use_one_hot_embeddings
=
False
,
...
...
@@ -192,6 +193,13 @@ class EmbeddingPostprocessor(nn.Cell):
self
.
layernorm
=
nn
.
LayerNorm
(
embedding_size
)
self
.
dropout
=
nn
.
Dropout
(
1
-
dropout_prob
)
self
.
gather
=
P
.
GatherV2
()
self
.
use_relative_positions
=
use_relative_positions
self
.
slice
=
P
.
Slice
()
self
.
full_position_embeddings
=
Parameter
(
initializer
(
TruncatedNormal
(
initializer_range
),
[
max_position_embeddings
,
embedding_size
]),
name
=
'full_position_embeddings'
)
def
construct
(
self
,
token_type_ids
,
word_embeddings
):
output
=
word_embeddings
...
...
@@ -206,6 +214,11 @@ class EmbeddingPostprocessor(nn.Cell):
token_type_embeddings
=
self
.
gather
(
self
.
embedding_table
,
flat_ids
,
0
)
token_type_embeddings
=
self
.
reshape
(
token_type_embeddings
,
self
.
shape
)
output
+=
token_type_embeddings
if
not
self
.
use_relative_positions
:
_
,
seq
,
width
=
self
.
shape
position_embeddings
=
self
.
slice
(
self
.
full_position_embeddings
,
[
0
,
0
],
[
seq
,
width
])
position_embeddings
=
self
.
reshape
(
position_embeddings
,
(
1
,
seq
,
width
))
output
+=
position_embeddings
output
=
self
.
layernorm
(
output
)
output
=
self
.
dropout
(
output
)
return
output
...
...
@@ -853,6 +866,7 @@ class BertModel(nn.Cell):
self
.
bert_embedding_postprocessor
=
EmbeddingPostprocessor
(
embedding_size
=
self
.
embedding_size
,
embedding_shape
=
output_embedding_shape
,
use_relative_positions
=
config
.
use_relative_positions
,
use_token_type
=
True
,
token_type_vocab_size
=
config
.
type_vocab_size
,
use_one_hot_embeddings
=
use_one_hot_embeddings
,
...
...
tests/st/networks/models/bert/bert_tdt_no_lossscale.py
浏览文件 @
ee519aa0
...
...
@@ -103,9 +103,9 @@ def me_de_train_dataset():
"""test me de train dataset"""
# apply repeat operations
repeat_count
=
1
ds
=
de
.
Storage
Dataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"input_ids"
,
"input_mask"
,
"segment_ids"
,
ds
=
de
.
TFRecord
Dataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
[
"input_ids"
,
"input_mask"
,
"segment_ids"
,
"next_sentence_labels"
,
"masked_lm_positions"
,
"masked_lm_ids"
,
"masked_lm_weights"
])
"masked_lm_ids"
,
"masked_lm_weights"
]
,
shuffle
=
False
)
type_cast_op
=
C
.
TypeCast
(
mstype
.
int32
)
ds
=
ds
.
map
(
input_columns
=
"masked_lm_ids"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"masked_lm_positions"
,
operations
=
type_cast_op
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录