magicwindyyd / mindspore (forked from MindSpore / mindspore, in sync with the upstream project)
Commit dfd85caa
Authored on Sep 03, 2020 by yoonlee666
Parent: 4ec34396

delete enable_fused_layernorm

Showing 10 changed files with 29 additions and 315 deletions (+29 -315)
model_zoo/official/nlp/bert/README.md                                        +0   -1
model_zoo/official/nlp/bert/src/bert_model.py                                +11  -27
model_zoo/official/nlp/bert/src/fused_layer_norm.py                          +0   -122
model_zoo/official/nlp/tinybert/README.md                                    +0   -2
model_zoo/official/nlp/tinybert/src/fused_layer_norm.py                      +0   -122
model_zoo/official/nlp/tinybert/src/gd_config.py                             +2   -4
model_zoo/official/nlp/tinybert/src/td_config.py                             +2   -4
model_zoo/official/nlp/tinybert/src/tinybert_model.py                        +12  -29
tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py    +1   -2
tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py      +1   -2
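In effect, every BertConfig construction site loses the trailing enable_fused_layernorm keyword and now ends at compute_type. A minimal before/after sketch, assuming the model_zoo src layout (the import path and batch size below are illustrative only, not taken from this diff):

    # Hypothetical call site; the import path mirrors model_zoo/official/nlp/bert/src.
    import mindspore.common.dtype as mstype
    from src.bert_model import BertConfig

    # Before commit dfd85caa the config accepted a trailing flag:
    #   cfg = BertConfig(batch_size=32, compute_type=mstype.float16, enable_fused_layernorm=False)
    # After this commit the keyword no longer exists, so callers simply stop at compute_type:
    cfg = BertConfig(batch_size=32, compute_type=mstype.float16)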
model_zoo/official/nlp/bert/README.md

@@ -161,7 +161,6 @@ For example, the schema file of cn-wiki-128 dataset for pretraining shows as fol
     ├─dataset.py                  # data preprocessing
     ├─finetune_eval_config.py     # parameter configuration for finetuning
     ├─finetune_eval_model.py      # backbone code of network
-    ├─fused_layer_norm.py         # Layernormal is optimized for Ascend
     ├─sample_process.py           # sample processing
     ├─utils.py                    # util function
     ├─pretrain_eval.py            # train and eval net
model_zoo/official/nlp/bert/src/bert_model.py

@@ -25,7 +25,6 @@ from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
-from .fused_layer_norm import FusedLayerNorm


 class BertConfig:
@@ -78,8 +77,7 @@ class BertConfig:
                  input_mask_from_dataset=True,
                  token_type_ids_from_dataset=True,
                  dtype=mstype.float32,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.vocab_size = vocab_size
@@ -98,7 +96,6 @@ class BertConfig:
         self.use_relative_positions = use_relative_positions
         self.dtype = dtype
         self.compute_type = compute_type
-        self.enable_fused_layernorm = enable_fused_layernorm


 class EmbeddingLookup(nn.Cell):
@@ -245,19 +242,14 @@ class BertOutput(nn.Cell):
                  out_channels,
                  initializer_range=0.02,
                  dropout_prob=0.1,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertOutput, self).__init__()
         self.dense = nn.Dense(in_channels, out_channels,
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.dropout_prob = dropout_prob
         self.add = P.TensorAdd()
-        if compute_type == mstype.float16:
-            self.layernorm = FusedLayerNorm((out_channels,),
-                                            use_batch_norm=enable_fused_layernorm).to_float(compute_type)
-        else:
-            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+        self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
         self.cast = P.Cast()

     def construct(self, hidden_status, input_tensor):
@@ -615,8 +607,7 @@ class BertSelfAttention(nn.Cell):
                  initializer_range=0.02,
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertSelfAttention, self).__init__()
         if hidden_size % num_attention_heads != 0:
             raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -644,8 +635,7 @@ class BertSelfAttention(nn.Cell):
                                      out_channels=hidden_size,
                                      initializer_range=initializer_range,
                                      dropout_prob=hidden_dropout_prob,
-                                     compute_type=compute_type,
-                                     enable_fused_layernorm=enable_fused_layernorm)
+                                     compute_type=compute_type)
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
@@ -687,8 +677,7 @@ class BertEncoderCell(nn.Cell):
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
                  hidden_act="gelu",
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertEncoderCell, self).__init__()
         self.attention = BertSelfAttention(
             batch_size=batch_size,
@@ -700,8 +689,7 @@ class BertEncoderCell(nn.Cell):
             initializer_range=initializer_range,
             hidden_dropout_prob=hidden_dropout_prob,
             use_relative_positions=use_relative_positions,
-            compute_type=compute_type,
-            enable_fused_layernorm=enable_fused_layernorm)
+            compute_type=compute_type)
         self.intermediate = nn.Dense(in_channels=hidden_size,
                                      out_channels=intermediate_size,
                                      activation=hidden_act,
@@ -710,8 +698,7 @@ class BertEncoderCell(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type,
-                                 enable_fused_layernorm=enable_fused_layernorm)
+                                 compute_type=compute_type)

     def construct(self, hidden_states, attention_mask):
         # self-attention
@@ -758,8 +745,7 @@ class BertTransformer(nn.Cell):
                  use_relative_positions=False,
                  hidden_act="gelu",
                  compute_type=mstype.float32,
-                 return_all_encoders=False,
-                 enable_fused_layernorm=False):
+                 return_all_encoders=False):
         super(BertTransformer, self).__init__()
         self.return_all_encoders = return_all_encoders
@@ -776,8 +762,7 @@ class BertTransformer(nn.Cell):
                 hidden_dropout_prob=hidden_dropout_prob,
                 use_relative_positions=use_relative_positions,
                 hidden_act=hidden_act,
-                compute_type=compute_type,
-                enable_fused_layernorm=enable_fused_layernorm)
+                compute_type=compute_type)
             layers.append(layer)

         self.layers = nn.CellList(layers)
@@ -904,8 +889,7 @@ class BertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True,
-            enable_fused_layernorm=config.enable_fused_layernorm)
+            return_all_encoders=True)
         self.cast = P.Cast()
         self.dtype = config.dtype
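The substantive change in bert_model.py is in BertOutput: the float16 branch that instantiated FusedLayerNorm collapses into an unconditional nn.LayerNorm. A standalone sketch of the resulting construction (out_channels and compute_type are illustrative values, not taken from the diff):

    # Post-commit layernorm setup in BertOutput, reduced to its essentials.
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype

    out_channels = 768
    compute_type = mstype.float16

    # One code path for every compute type; the Ascend-specific fused variant is gone.
    layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)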
model_zoo/official/nlp/bert/src/fused_layer_norm.py (deleted, 100644 → 0; former content below)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell

__all__ = ['FusedLayerNorm']


@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
    print("input_shape: ", x_shape)
    norm_shape = x_shape[begin_norm_axis:]
    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
    print("output_shape: ", output_shape)
    return output_shape


class FusedLayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies
    normalization over a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
    normalization, layer normalization performs exactly the same computation at training and
    testing times. It can be described using the following formula. It is applied across all
    channels and pixels but only one batch size.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
            `begin_norm_axis ... R - 1`.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        use_batch_norm (bool): Whether to use batchnorm to process.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape[1:]
        >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """

    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 use_batch_norm=False):
        super(FusedLayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
                            .format(normalized_shape, type(normalized_shape)))
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.gamma = Parameter(initializer(gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(beta_init, normalized_shape), name="beta")
        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
                                      begin_params_axis=self.begin_params_axis)

        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
        self.use_batch_norm = use_batch_norm

    def construct(self, input_x):
        """Applies Layer Normalization over a mini-batch of inputs"""
        if self.use_batch_norm and self.training:
            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
            shape_x = F.shape(input_x)
            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
            input_x = F.reshape(input_x, norm_shape)
            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
            output = F.reshape(output, shape_x)
            y = output * self.gamma + self.beta
        else:
            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s
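The deleted cell emulated layer normalization with a BatchNorm kernel: it reshaped the input to (1, -1, 1, prod(norm_shape)) so that each sample became one "channel", then normalized each channel over its own elements. A NumPy-only sketch (not MindSpore; shapes, epsilon, and tolerance chosen purely for illustration) of why the two paths agree numerically:

    import numpy as np

    x = np.random.randn(20, 5, 10, 10).astype(np.float32)
    begin_norm_axis = 1
    norm_shape = x.shape[begin_norm_axis:]
    eps = 1e-5

    # Direct layer norm: normalize each sample over the trailing axes.
    axes = tuple(range(begin_norm_axis, x.ndim))
    ln = (x - x.mean(axis=axes, keepdims=True)) / np.sqrt(x.var(axis=axes, keepdims=True) + eps)

    # BatchNorm-style path: one "channel" per sample, normalized over N*H*W = its own elements.
    r = x.reshape(1, -1, 1, int(np.prod(norm_shape)))
    bn = (r - r.mean(axis=(0, 2, 3), keepdims=True)) / np.sqrt(r.var(axis=(0, 2, 3), keepdims=True) + eps)
    bn = bn.reshape(x.shape)

    print(np.allclose(ln, bn, atol=1e-5))  # True: the reshape + BatchNorm trick matches layer norm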
model_zoo/official/nlp/tinybert/README.md

@@ -113,7 +113,6 @@ For example, the dataset is cn-wiki-128, the schema file for general distill pha
     ├─__init__.py
     ├─assessment_method.py          # assessment method for evaluation
     ├─dataset.py                    # data processing
-    ├─fused_layer_norm.py           # Layernormal is optimized for Ascend
     ├─gd_config.py                  # parameter configuration for general distill phase
     ├─td_config.py                  # parameter configuration for task distill phase
     ├─tinybert_for_gd_td.py         # backbone code of network
@@ -229,7 +228,6 @@ Parameters for bert network:
     token_type_ids_from_dataset    use the token type ids loaded from dataset or not: True | False, default is True
     dtype                          data type of input: mstype.float16 | mstype.float32, default is mstype.float32
     compute_type                   compute type in BertTransformer: mstype.float16 | mstype.float32, default is mstype.float16
-    enable_fused_layernorm         use batchnorm instead of layernorm to improve performance, default is False
 ```

 ## [Training Process](#contents)
 ### Training
model_zoo/official/nlp/tinybert/src/fused_layer_norm.py (deleted, 100644 → 0; former content below)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""fused layernorm"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
import mindspore.common.dtype as mstype
from mindspore.nn.cell import Cell

__all__ = ['FusedLayerNorm']


@constexpr
def get_shape_for_norm(x_shape, begin_norm_axis):
    print("input_shape: ", x_shape)
    norm_shape = x_shape[begin_norm_axis:]
    output_shape = (1, -1, 1, int(np.prod(norm_shape)))
    print("output_shape: ", output_shape)
    return output_shape


class FusedLayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer normalization is widely used in recurrent neural networks. It applies
    normalization over a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike batch
    normalization, layer normalization performs exactly the same computation at training and
    testing times. It can be described using the following formula. It is applied across all
    channels and pixels but only one batch size.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
            `begin_norm_axis ... R - 1`.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        use_batch_norm (bool): Whether to use batchnorm to process.

    Inputs:
        - **input_x** (Tensor) - The shape of 'input_x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input_x`.

    Examples:
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape[1:]
        >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> m(x)
    """

    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 use_batch_norm=False):
        super(FusedLayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError("The type of 'normalized_shape' should be tuple[int] or list[int], but '{}' type is {}."
                            .format(normalized_shape, type(normalized_shape)))
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.gamma = Parameter(initializer(gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(beta_init, normalized_shape), name="beta")
        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
                                      begin_params_axis=self.begin_params_axis)

        self.batch_norm = P.BatchNorm(is_training=True, epsilon=1e-5)
        self.use_batch_norm = use_batch_norm

    def construct(self, input_x):
        """fusedlayernorm"""
        if self.use_batch_norm and self.training:
            ones = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 1.0)
            zeros = P.Fill()(mstype.float32, F.shape(input_x)[:self.begin_norm_axis], 0.0)
            shape_x = F.shape(input_x)
            norm_shape = get_shape_for_norm(shape_x, self.begin_norm_axis)
            input_x = F.reshape(input_x, norm_shape)
            output, _, _, _, _, _ = self.batch_norm(input_x, ones, zeros, None, None)
            output = F.reshape(output, shape_x)
            y = output * self.gamma + self.beta
        else:
            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        s = 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
        return s
model_zoo/official/nlp/tinybert/src/gd_config.py

@@ -55,8 +55,7 @@ bert_teacher_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
 bert_student_net_cfg = BertConfig(
     batch_size=32,
@@ -76,6 +75,5 @@ bert_student_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
model_zoo/official/nlp/tinybert/src/td_config.py

@@ -74,8 +74,7 @@ td_teacher_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
 td_student_net_cfg = BertConfig(
     batch_size=32,
@@ -95,6 +94,5 @@ td_student_net_cfg = BertConfig(
     input_mask_from_dataset=True,
     token_type_ids_from_dataset=True,
     dtype=mstype.float32,
-    compute_type=mstype.float16,
-    enable_fused_layernorm=False
+    compute_type=mstype.float16
 )
model_zoo/official/nlp/tinybert/src/tinybert_model.py

@@ -25,7 +25,6 @@ from mindspore.ops import composite as C
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore import context
-from .fused_layer_norm import FusedLayerNorm


 class BertConfig:
@@ -78,8 +77,7 @@ class BertConfig:
                  input_mask_from_dataset=True,
                  token_type_ids_from_dataset=True,
                  dtype=mstype.float32,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         self.batch_size = batch_size
         self.seq_length = seq_length
         self.vocab_size = vocab_size
@@ -98,7 +96,6 @@ class BertConfig:
         self.use_relative_positions = use_relative_positions
         self.dtype = dtype
         self.compute_type = compute_type
-        self.enable_fused_layernorm = enable_fused_layernorm


 class EmbeddingLookup(nn.Cell):
@@ -244,8 +241,7 @@ class BertOutput(nn.Cell):
                  out_channels,
                  initializer_range=0.02,
                  dropout_prob=0.1,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertOutput, self).__init__()
         self.dense = nn.Dense(in_channels, out_channels,
                               weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
@@ -256,11 +252,7 @@ class BertOutput(nn.Cell):
             self.layernorm = nn.LayerNorm((out_channels,)).to_float(mstype.float32)
             self.compute_type = compute_type
         else:
-            if compute_type == mstype.float16:
-                self.layernorm = FusedLayerNorm((out_channels,),
-                                                use_batch_norm=enable_fused_layernorm).to_float(compute_type)
-            else:
-                self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)
+            self.layernorm = nn.LayerNorm((out_channels,)).to_float(compute_type)

         self.cast = P.Cast()
@@ -602,8 +594,7 @@ class BertSelfAttention(nn.Cell):
                  initializer_range=0.02,
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertSelfAttention, self).__init__()
         if hidden_size % num_attention_heads != 0:
             raise ValueError("The hidden size (%d) is not a multiple of the number "
@@ -628,8 +619,7 @@ class BertSelfAttention(nn.Cell):
                                      out_channels=hidden_size,
                                      initializer_range=initializer_range,
                                      dropout_prob=hidden_dropout_prob,
-                                     compute_type=compute_type,
-                                     enable_fused_layernorm=enable_fused_layernorm)
+                                     compute_type=compute_type)
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
@@ -672,8 +662,7 @@ class BertEncoderCell(nn.Cell):
                  hidden_dropout_prob=0.1,
                  use_relative_positions=False,
                  hidden_act="gelu",
-                 compute_type=mstype.float32,
-                 enable_fused_layernorm=False):
+                 compute_type=mstype.float32):
         super(BertEncoderCell, self).__init__()
         self.attention = BertSelfAttention(
             batch_size=batch_size,
@@ -685,8 +674,7 @@ class BertEncoderCell(nn.Cell):
             initializer_range=initializer_range,
             hidden_dropout_prob=hidden_dropout_prob,
             use_relative_positions=use_relative_positions,
-            compute_type=compute_type,
-            enable_fused_layernorm=enable_fused_layernorm)
+            compute_type=compute_type)
         self.intermediate = nn.Dense(in_channels=hidden_size,
                                      out_channels=intermediate_size,
                                      activation=hidden_act,
@@ -695,8 +683,7 @@ class BertEncoderCell(nn.Cell):
                                  out_channels=hidden_size,
                                  initializer_range=initializer_range,
                                  dropout_prob=hidden_dropout_prob,
-                                 compute_type=compute_type,
-                                 enable_fused_layernorm=enable_fused_layernorm)
+                                 compute_type=compute_type)

     def construct(self, hidden_states, attention_mask):
         """bert encoder cell"""
         # self-attention
@@ -743,8 +730,7 @@ class BertTransformer(nn.Cell):
                  use_relative_positions=False,
                  hidden_act="gelu",
                  compute_type=mstype.float32,
-                 return_all_encoders=False,
-                 enable_fused_layernorm=False):
+                 return_all_encoders=False):
         super(BertTransformer, self).__init__()
         self.return_all_encoders = return_all_encoders
         layers = []
@@ -760,8 +746,7 @@ class BertTransformer(nn.Cell):
                 hidden_dropout_prob=hidden_dropout_prob,
                 use_relative_positions=use_relative_positions,
                 hidden_act=hidden_act,
-                compute_type=compute_type,
-                enable_fused_layernorm=enable_fused_layernorm)
+                compute_type=compute_type)
             layers.append(layer)
         self.layers = nn.CellList(layers)
         self.reshape = P.Reshape()
@@ -877,8 +862,7 @@ class BertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True,
-            enable_fused_layernorm=config.enable_fused_layernorm)
+            return_all_encoders=True)
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
@@ -981,8 +965,7 @@ class TinyBertModel(nn.Cell):
             use_relative_positions=config.use_relative_positions,
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
-            return_all_encoders=True,
-            enable_fused_layernorm=config.enable_fused_layernorm)
+            return_all_encoders=True)
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py

@@ -82,8 +82,7 @@ def get_config(version='base', batch_size=1):
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
-            compute_type=mstype.float16,
-            enable_fused_layernorm=False)
+            compute_type=mstype.float16)
     else:
         bert_config = BertConfig(batch_size=batch_size)
     return bert_config
tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py

@@ -82,8 +82,7 @@ def get_config(version='base', batch_size=1):
             input_mask_from_dataset=True,
             token_type_ids_from_dataset=True,
             dtype=mstype.float32,
-            compute_type=mstype.float16,
-            enable_fused_layernorm=False)
+            compute_type=mstype.float16)
     else:
         bert_config = BertConfig(batch_size=batch_size)
     return bert_config