magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 219bc184
Author: jinyaohui
Date:   April 25, 2020
Parent: 96d39886

    clean pylint

Showing 9 changed files with 37 additions and 33 deletions (+37 -33).
example/yolov3_coco2017/dataset.py                                        +1  -1
mindspore/_akg/__init__.py                                                +3  -3
mindspore/nn/optim/ftrl.py                                                +7  -3
mindspore/ops/operations/nn_ops.py                                        +3  -3
mindspore/train/amp.py                                                    +3  -0
mindspore/train/callback.py                                               +2  -2
mindspore/train/model.py                                                  +14 -14
tests/mindspore_test_framework/apps/test_bert_parts.py                    +3  -6
tests/mindspore_test_framework/components/executor/check_exceptions.py    +1  -1
example/yolov3_coco2017/dataset.py
@@ -18,8 +18,8 @@ from __future__ import division
 import os
 import numpy as np
-from PIL import Image
 from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
+from PIL import Image
 import mindspore.dataset as de
 from mindspore.mindrecord import FileWriter
 import mindspore.dataset.transforms.vision.c_transforms as C
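Note: the hunk above only moves `from PIL import Image` below the matplotlib import. As a reminder of the layout pylint's import checks push toward, a generic sketch (illustration only, not code from this commit):

    import os                       # 1. standard library

    import numpy as np              # 2. third-party packages
    from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
    from PIL import Image

    import mindspore.dataset as de  # 3. first-party / project imports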
mindspore/_akg/__init__.py
@@ -16,6 +16,9 @@
 from __future__ import absolute_import as _abs
 import sys
 import os
+from .op_build import op_build
+from .message import compilewithjson


 def AKGAddPath():
     """_akg add path."""
@@ -58,6 +61,3 @@ class AKGMetaPathLoader:
 sys.meta_path.insert(0, AKGMetaPathFinder())
-
-from .op_build import op_build
-from .message import compilewithjson
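Note: `sys.meta_path.insert(0, AKGMetaPathFinder())` registers a custom import hook so that `_akg` submodules can be resolved outside the normal search path; the commit only moves the two `.op_build` / `.message` imports to the top of the file. A minimal sketch of the meta-path-finder pattern in plain Python (hypothetical names and directory, not MindSpore's actual finder):

    import importlib.abc
    import importlib.machinery
    import sys

    class _DemoPathFinder(importlib.abc.MetaPathFinder):
        """Redirect imports of 'demo_pkg.*' to a custom directory (illustration only)."""

        def find_spec(self, fullname, path, target=None):
            if fullname.startswith("demo_pkg"):
                # Delegate to the standard file finder, but search a custom location.
                return importlib.machinery.PathFinder.find_spec(fullname, ["/opt/demo"])
            return None  # let the next finder on sys.meta_path handle everything else

    sys.meta_path.insert(0, _DemoPathFinder())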
mindspore/nn/optim/ftrl.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """FTRL"""
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore.common import Tensor
 import mindspore.common.dtype as mstype
@@ -23,6 +22,8 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, apply_decay, grad_scale

 ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")


 @ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
     """Apply ftrl optimizer to the weight parameter."""
@@ -30,8 +31,10 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
     success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
     return success


 def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
                  prim_name=None):
+    """Check param."""
     validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
     validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
@@ -104,7 +107,7 @@ class FTRL(Optimizer):
         self.lr_power = lr_power
         self.reciprocal_scale = 1.0 / loss_scale
         self.weight_decay = weight_decay
         self.decay_tf = tuple((lambda: True)() for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
         self.one = Tensor(1, mstype.int32)
@@ -118,5 +121,6 @@ class FTRL(Optimizer):
         if self.reciprocal_scale != 1.0:
             grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
         lr = self.learning_rate
         success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power), linear, grads, params, moments)
         return success
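Note: the parameters checked by `_check_param` are the ones a user supplies when constructing the optimizer. A rough usage sketch, assuming `FTRL` is exported from `mindspore.nn` like the other optimizers and keeps the keyword names shown above (`net` and `loss` are placeholders, not part of this commit):

    import mindspore.nn as nn

    net = nn.Dense(16, 10)
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = nn.FTRL(params=net.trainable_params(),
                  initial_accum=0.1,   # must be a float >= 0.0 per the validator calls above
                  learning_rate=0.001,
                  lr_power=-0.5,
                  l1=0.0,
                  l2=0.0,
                  use_locking=False,
                  loss_scale=1.0,      # construct() applies reciprocal_scale = 1.0 / loss_scale
                  weight_decay=0.0)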
mindspore/ops/operations/nn_ops.py
@@ -2063,7 +2063,7 @@ class LSTM(PrimitiveWithInfer):
         return (y_shape, h_shape, c_shape, reserved_shape, state_shape)

     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
         args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
         return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
@@ -2691,8 +2691,8 @@ class ConfusionMulGrad(PrimitiveWithInfer):
     """

     @prim_attr_register
     def __init__(self, axis=(), keep_dims=False):
         self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
         self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
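Note: `LSTM.infer_dtype` above requires `x`, `h`, `c` and `w` to share a single dtype drawn from {float32, float16}. A plain-Python stand-in for that consistency check, purely for illustration (`check_tensor_type_same` is MindSpore's internal validator, not this function):

    def check_same_dtype(args, valid_dtypes, prim_name):
        """Require every entry in `args` to share one dtype from `valid_dtypes`."""
        dtypes = set(args.values())
        if len(dtypes) != 1:
            raise TypeError(f"For '{prim_name}', all inputs must share one dtype, got {args}")
        (dtype,) = dtypes
        if dtype not in valid_dtypes:
            raise TypeError(f"For '{prim_name}', dtype {dtype} is not in {valid_dtypes}")
        return dtype

    # Passes: every input uses float16, which is in the allowed set.
    check_same_dtype({'x': 'float16', 'h': 'float16', 'c': 'float16', 'w': 'float16'},
                     ('float32', 'float16'), 'LSTM')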
mindspore/train/amp.py
@@ -41,6 +41,7 @@ class OutputTo16(nn.Cell):
 def _do_keep_batchnorm_fp32(network):
+    """Do keep batchnorm fp32."""
     cells = network.name_cells()
     change = False
     for name in cells:
@@ -68,6 +69,7 @@ _config_level = {
 def _check_kwargs(key_words):
+    """Check kwargs."""
     for arg in key_words:
         if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
             raise ValueError(f"Unsupported arg '{arg}'")
@@ -84,6 +86,7 @@ def _check_kwargs(key_words):
 def _add_loss_network(network, loss_fn, cast_model_type):
+    """Add loss network."""
     class WithLossCell(nn.Cell):
         "Wrap loss for amp. Cast network output back to float32"
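Note: `_check_kwargs` above rejects anything other than `cast_model_type`, `keep_batchnorm_fp32` and `loss_scale_manager`. A sketch of how `build_train_network` is called with those keywords, mirroring the calls in mindspore/train/model.py below; the import paths, the "O2" level and the small network are assumptions, not part of this commit:

    import mindspore.nn as nn
    from mindspore.train import amp
    from mindspore.train.loss_scale_manager import FixedLossScaleManager

    net = nn.Dense(16, 10)
    loss = nn.SoftmaxCrossEntropyWithLogits()
    opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

    # Only keyword names accepted by _check_kwargs are passed through.
    train_net = amp.build_train_network(net, opt, loss,
                                        level="O2",
                                        loss_scale_manager=FixedLossScaleManager(),
                                        keep_batchnorm_fp32=True)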
mindspore/train/callback.py
@@ -683,13 +683,14 @@ class LossMonitor(Callback):
 class TimeMonitor(Callback):
+    """Time Monitor."""
     def __init__(self, data_size):
         super(TimeMonitor, self).__init__()
         self.data_size = data_size

     def epoch_begin(self, run_context):
         self.epoch_time = time.time()

     def epoch_end(self, run_context):
         epoch_mseconds = (time.time() - self.epoch_time) * 1000
         per_step_mseconds = epoch_mseconds / self.data_size
@@ -701,4 +702,3 @@ class TimeMonitor(Callback):
     def step_end(self, run_context):
         step_mseconds = (time.time() - self.step_time) * 1000
         print('step time', step_mseconds, flush=True)
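Note: `TimeMonitor` reports epoch and step wall time; `data_size` should be the number of steps per epoch so that `per_step_mseconds = epoch_mseconds / data_size` in `epoch_end` comes out per step. A usage sketch (import paths are assumptions; `net`, `loss`, `opt` and `train_ds` are placeholders):

    from mindspore.train.model import Model
    from mindspore.train.callback import TimeMonitor

    model = Model(net, loss_fn=loss, optimizer=opt)
    time_cb = TimeMonitor(data_size=train_ds.get_dataset_size())  # steps per epoch
    model.train(10, train_ds, callbacks=[time_cb])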
mindspore/train/model.py
@@ -122,7 +122,7 @@ class Model:
     def _check_kwargs(self, kwargs):
         for arg in kwargs:
             if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                 raise ValueError(f"Unsupport arg '{arg}'")

     def _build_train_network(self):
         """Build train network"""
@@ -130,17 +130,17 @@ class Model:
         if self._optimizer:
             if self._loss_scale_manager_set:
                 network = amp.build_train_network(network,
                                                   self._optimizer,
                                                   self._loss_fn,
                                                   level=self._amp_level,
                                                   loss_scale_manager=self._loss_scale_manager,
                                                   keep_batchnorm_fp32=self._keep_bn_fp32)
             else:
                 network = amp.build_train_network(network,
                                                   self._optimizer,
                                                   self._loss_fn,
                                                   level=self._amp_level,
                                                   keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
@@ -273,14 +273,14 @@ class Model:
         # remove later to deal with loop sink
         need_wrap = False
         if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
                 and not context.get_context("enable_ge"):
             need_wrap = True

         dataset_helper = DatasetHelper(train_dataset)
         # remove later to deal with loop sink
         if need_wrap:
             self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
                                                  train_dataset.__ME_INITED__)
             cb_params.train_network = self._train_network
         self._train_network.set_train()
@@ -440,7 +440,7 @@ class Model:
         # remove later to deal with loop sink
         need_wrap = False
         if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
                 and not context.get_context("enable_ge"):
             need_wrap = True

         valid_dataset.__loop_size__ = 1
@@ -449,7 +449,7 @@ class Model:
         # remove later to deal with loop sink
         if need_wrap:
             self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
                                                 valid_dataset.__ME_INITED__)
             self._eval_network.set_train(mode=False)
             self._eval_network.phase = 'eval'
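Note: because of `_check_kwargs` above, `Model(...)` accepts only `loss_scale_manager` and `keep_batchnorm_fp32` as extra keyword arguments; anything else raises the "Unsupport arg" ValueError. A construction sketch (import paths and the placeholders `net`, `loss`, `opt` are assumptions, not part of this commit):

    from mindspore.train.model import Model
    from mindspore.train.loss_scale_manager import DynamicLossScaleManager

    model = Model(net,
                  loss_fn=loss,
                  optimizer=opt,
                  amp_level="O2",
                  loss_scale_manager=DynamicLossScaleManager(),
                  keep_batchnorm_fp32=True)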
tests/mindspore_test_framework/apps/test_bert_parts.py
@@ -174,8 +174,7 @@ test_sets = [
                               embedding_shape=[1, 128, 768],
                               use_one_hot_embeddings=True,
                               initializer_range=0.02), 1, 1), {
-            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)
-        }),
+            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)}),
         'desc_inputs': [input_ids],
         'desc_bprop': [[128]]}),
     ('EmbeddingLookup_multi_outputs_init_param', {
@@ -184,8 +183,7 @@ test_sets = [
                               embedding_shape=[1, 128, 768],
                               use_one_hot_embeddings=False,
                               initializer_range=0.02), {
-            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)
-        }),
+            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)}),
         'desc_inputs': [input_ids],
         'desc_bprop': [[1, 128, 768], [128]]}),
     ('EmbeddingLookup_multi_outputs_grad_with_no_sens', {
@@ -194,8 +192,7 @@ test_sets = [
                               embedding_shape=[1, 128, 768],
                               use_one_hot_embeddings=False,
                               initializer_range=0.02), {
-            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)
-        }),
+            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)}),
         'desc_inputs': [input_ids]}),
     ('GetMaskedLMOutput_grad_with_no_sens', {
         'block': GetMaskedLMOutput(BertConfig(batch_size=1)),
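Note: the reformatted test entries keep the same `init_param_with` hook; it simply fills a parameter of the given shape with float32 ones:

    import numpy as np

    init_param_with = lambda shp: np.ones(shp).astype(np.float32)
    arr = init_param_with((1, 128, 768))
    print(arr.dtype, arr.shape)  # float32 (1, 128, 768)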
tests/mindspore_test_framework/components/executor/check_exceptions.py
@@ -44,4 +44,4 @@ class CheckExceptionsEC(IExectorComponent):
             raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}")
         if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
             raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
                 str(exec_info.value), error_kws))
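Note: the keyword check above only raises when at least one expected keyword is missing from the exception message. A standalone illustration of that `any(... not in ...)` guard:

    error_kws = ["initial_accum", "float"]
    message = "For 'FTRL', the type of initial_accum must be float."

    if error_kws and any(keyword not in message for keyword in error_kws):
        raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(message, error_kws))
    print("all expected keywords found")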