magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit a3d9c9a8
Authored May 25, 2020 by mindspore-ci-bot; committed via Gitee on May 25, 2020

!1409 clean pylint warnings of parallel test cases
Merge pull request !1409 from yihuaijie/master

Parents: 0f2fc082, 8cfc05e4
Showing 14 changed files with 53 additions and 65 deletions (+53, -65)

tests/st/auto_parallel/onehot_model_parallel.py                +1  -1
tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py    +1  -1
tests/st/auto_parallel/test_resnet50_expand_loss_2p.py         +4  -2
tests/ut/python/parallel/test_auto_parallel_parameter_cast.py  +0  -11
tests/ut/python/parallel/test_auto_parallel_reshape.py         +10 -10
tests/ut/python/parallel/test_auto_parallel_resnet.py          +18 -22
tests/ut/python/parallel/test_auto_parallel_zig_zag.py         +1  -1
tests/ut/python/parallel/test_dataset_interface.py             +2  -2
tests/ut/python/parallel/test_gather_v2.py                     +3  -1
tests/ut/python/parallel/test_gather_v2_primitive.py           +2  -2
tests/ut/python/parallel/test_one_hot_net.py                   +0  -1
tests/ut/python/parallel/test_onehot.py                        +2  -2
tests/ut/python/parallel/test_prelu.py                         +5  -5
tests/ut/python/parallel/test_reshape.py                       +4  -4

tests/st/auto_parallel/onehot_model_parallel.py

@@ -72,7 +72,7 @@ class DataGenerator():
         i = 0
         for stra in strategy:
             temp = []
-            while len(blocks) > 0:
+            while blocks:
                 block = blocks.pop(0)
                 temp.extend(np.split(block, stra, axis=i))
             blocks.extend(temp)

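The change above (repeated in the next two files) swaps an explicit length comparison for Python's truthiness test on sequences, which is what pylint's len-as-condition check (C1801) asks for: an empty list is falsy, so `while blocks:` behaves exactly like `while len(blocks) > 0:`. A minimal standalone sketch of the pattern, with illustrative names not taken from the repository:

import numpy as np

def split_blocks(blocks, strategy):
    # Mirrors the DataGenerator loop: split every block along axis i
    # into `stra` pieces, then feed the pieces back for the next axis.
    for i, stra in enumerate(strategy):
        temp = []
        while blocks:  # preferred over `while len(blocks) > 0:` (C1801)
            block = blocks.pop(0)
            temp.extend(np.split(block, stra, axis=i))
        blocks.extend(temp)
    return blocks

# One 4x4 array split 2 ways per axis yields four 2x2 blocks.
print(len(split_blocks([np.zeros((4, 4))], strategy=(2, 2))))  # -> 4
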
tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py

@@ -63,7 +63,7 @@ class DataGenerator():
         i = 0
         for stra in strategy:
             temp = []
-            while len(blocks) > 0:
+            while blocks:
                 block = blocks.pop(0)
                 temp.extend(np.split(block, stra, axis=i))
             blocks.extend(temp)

tests/st/auto_parallel/test_resnet50_expand_loss_2p.py

@@ -172,10 +172,12 @@ class ResNet(nn.Cell):
                  layer_nums,
                  in_channels,
                  out_channels,
-                 strides=[1, 2, 2, 2],
+                 strides=None,
                  num_classes=100):
         super(ResNet, self).__init__()
+        if strides is None:
+            strides = [1, 2, 2, 2]
         if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
             raise ValueError("the length of "
                              "layer_num, inchannel, outchannel list must be 4!")

@@ -300,7 +302,7 @@ class DataGenerator():
         i = 0
         for stra in strategy:
             temp = []
-            while len(blocks) > 0:
+            while blocks:
                 block = blocks.pop(0)
                 temp.extend(np.split(block, stra, axis=i))
             blocks.extend(temp)

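The ResNet constructor change here (applied again in test_auto_parallel_resnet.py and test_gather_v2.py below) is the standard cure for pylint's dangerous-default-value warning (W0102): a mutable default such as `strides=[1, 2, 2, 2]` is built once at function-definition time and shared by every call, so it is replaced with a `None` sentinel plus an in-body fallback. A small self-contained demonstration with made-up names:

def buggy_append(item, acc=[]):    # W0102: one list object shared across calls
    acc.append(item)
    return acc

def fixed_append(item, acc=None):  # the pattern adopted in this commit
    if acc is None:
        acc = []                   # fresh list on every call
    acc.append(item)
    return acc

print(buggy_append(1), buggy_append(2))  # [1, 2] [1, 2] -- surprising sharing
print(fixed_append(1), fixed_append(2))  # [1] [2]
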
tests/ut/python/parallel/test_auto_parallel_parameter_cast.py

@@ -38,17 +38,6 @@ class NetWithLoss(nn.Cell):
         return self.loss(predict)


-class GradWrap(nn.Cell):
-    def __init__(self, network):
-        super(GradWrap, self).__init__()
-        self.network = network
-
-    def construct(self, x, y, z, w):
-        return C.grad_all(self.network)(x, y, z, w)
-
-
 # model_parallel test
 def test_common_parameter():
     class Net(nn.Cell):
         def __init__(self):

tests/ut/python/parallel/test_auto_parallel_reshape.py

@@ -174,9 +174,9 @@ def test_reshape_auto_4():
 def test_reshape_auto_5():
-    class NetWithLoss(nn.Cell):
+    class NetWithLoss5(nn.Cell):
         def __init__(self, network):
-            super(NetWithLoss, self).__init__()
+            super(NetWithLoss5, self).__init__()
             self.loss = VirtualLoss()
             self.network = network

@@ -184,9 +184,9 @@ def test_reshape_auto_5():
             predict = self.network(x, y)
             return self.loss(predict)

-    class GradWrap(nn.Cell):
+    class GradWrap5(nn.Cell):
         def __init__(self, network):
-            super(GradWrap, self).__init__()
+            super(GradWrap5, self).__init__()
             self.network = network

         def construct(self, x, y):

@@ -217,16 +217,16 @@ def test_reshape_auto_5():
     x = Tensor(np.ones([4, 1024 * size, 1]), dtype=ms.float32)
     y = Tensor(np.ones([4, 1024 * size,]), dtype=ms.float32)
-    net = GradWrap(NetWithLoss(Net()))
+    net = GradWrap5(NetWithLoss5(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x, y)


 def test_reshape_auto_6():
-    class NetWithLoss(nn.Cell):
+    class NetWithLoss6(nn.Cell):
         def __init__(self, network):
-            super(NetWithLoss, self).__init__()
+            super(NetWithLoss6, self).__init__()
             self.loss = VirtualLoss()
             self.network = network

@@ -234,9 +234,9 @@ def test_reshape_auto_6():
             predict = self.network(x, y)
             return self.loss(predict)

-    class GradWrap(nn.Cell):
+    class GradWrap6(nn.Cell):
         def __init__(self, network):
-            super(GradWrap, self).__init__()
+            super(GradWrap6, self).__init__()
             self.network = network

         def construct(self, x, y):

@@ -265,7 +265,7 @@ def test_reshape_auto_6():
     x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
     y = Tensor(np.ones([4, 1024,]), dtype=ms.float32)
-    net = GradWrap(NetWithLoss(Net()))
+    net = GradWrap6(NetWithLoss6(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x, y)

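The suffixed class names (NetWithLoss5/GradWrap5, NetWithLoss6/GradWrap6, and the matching NetWithLoss3/GradWrap3 renames in test_prelu.py below) give each test's nested helper a name distinct from the module-level NetWithLoss and GradWrap helpers, which pylint would otherwise flag as redefined-outer-name (W0621). A condensed, hypothetical sketch of the shadowing problem being avoided:

class NetWithLoss:                 # module-level helper used by other tests
    pass

def test_reshape_auto_5():
    # Naming this nested class `NetWithLoss` would shadow the module-level
    # class and draw W0621; the numeric suffix keeps both names visible.
    class NetWithLoss5(NetWithLoss):
        pass
    assert NetWithLoss5 is not NetWithLoss

test_reshape_auto_5()
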
tests/ut/python/parallel/test_auto_parallel_resnet.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numpy as np
 import re
+import numpy as np
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn

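Moving `import numpy as np` below `import re` places the standard-library import first, as pylint's wrong-import-order check (C0411) requires. The expected grouping, shown as a sketch rather than the file's actual import block:

# C0411 wants three groups, in this order, each separated by a blank line:
import os                    # 1. standard library
import re

import numpy as np           # 2. third-party packages

import mindspore.nn as nn    # 3. first-party / application packages
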
@@ -36,35 +36,33 @@ context.set_context(device_id=0)
 init()


-def weight_variable(shape, factor=0.1):
+def weight_variable():
     return TruncatedNormal(0.02)


 def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
     """Get a conv2d layer with 3x3 kernel size."""
-    init_value = weight_variable((out_channels, in_channels, 3, 3))
+    init_value = weight_variable()
     return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


 def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
     """Get a conv2d layer with 1x1 kernel size."""
-    init_value = weight_variable((out_channels, in_channels, 1, 1))
+    init_value = weight_variable()
     return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


 def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
     """Get a conv2d layer with 7x7 kernel size."""
-    init_value = weight_variable((out_channels, in_channels, 7, 7))
+    init_value = weight_variable()
     return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)


 def _fused_bn(channels, momentum=0.9):
     """Get a fused batchnorm"""
-    init_weight = weight_variable((channels,))
-    init_bias = weight_variable((channels,))
     return nn.BatchNorm2d(channels, momentum=momentum)

@@ -132,10 +130,11 @@ class ResNet(nn.Cell):
                  layer_nums,
                  in_channels,
                  out_channels,
-                 strides=[1, 2, 2, 2],
+                 strides=None,
                  num_classes=100):
         super(ResNet, self).__init__()
+        if strides is None:
+            strides = [1, 2, 2, 2]
         if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
             raise ValueError("the length of "
                              "layer_num, inchannel, outchannel list must be 4!")

@@ -168,16 +167,13 @@ class ResNet(nn.Cell):
         self.mean = P.ReduceMean(keep_dims=True)
         self.end_point = nn.Dense(2048, num_classes, has_bias=True,
-                                  weight_init=weight_variable((num_classes, 2048)),
-                                  bias_init=weight_variable((num_classes,))).add_flags_recursive(fp16=True)
+                                  weight_init=weight_variable(),
+                                  bias_init=weight_variable()).add_flags_recursive(fp16=True)
         self.squeeze = P.Squeeze()
         self.cast = P.Cast()

     def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
         layers = []
         down_sample = False
         if stride != 1 or in_channel != out_channel:
             down_sample = True
         resblk = block(in_channel, out_channel, stride=1)
         layers.append(resblk)

@@ -279,7 +275,7 @@ class DatasetLenet():
         return 1


-def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
+def test_train_32k_8p(batch_size=32, num_classes=32768):
     dev_num = 8
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
     set_algo_parameters(elementwise_op_strategy_follow=True)

@@ -309,12 +305,12 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
     return allreduce_fusion_dict


-def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
+def train_32k_8p_fusion1(batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
     cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
-    allreduce_fusion_dict = test_train_32k_8p(epoch_size, batch_size, num_classes)
+    allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
     expect_dict = {'end_point.bias': 2,
                    'end_point.weight': 2,
                    'layer4.2.bn3.beta': 2,

@@ -477,17 +473,17 @@ def train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768):
                    'bn1.gamma': 1,
                    'conv1.weight': 1}

-    assert (allreduce_fusion_dict == expect_dict)
+    assert allreduce_fusion_dict == expect_dict
     cost_model_context.reset_cost_model_context()


-def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
+def train_32k_8p_fusion2(batch_size=32, num_classes=32768):  # 1048576 #131072 #32768 #8192
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.1)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.05)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.000001)
     cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.0000015)
-    allreduce_fusion_dict = test_train_32k_8p(epoch_size, batch_size, num_classes)
+    allreduce_fusion_dict = test_train_32k_8p(batch_size, num_classes)
     expect_dict = {'end_point.bias': 2,
                    'end_point.weight': 2,
                    'layer4.2.bn3.beta': 2,

@@ -650,11 +646,11 @@ def train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768):
                    'bn1.gamma': 1,
                    'conv1.weight': 1}

-    assert (allreduce_fusion_dict == expect_dict)
+    assert allreduce_fusion_dict == expect_dict
     cost_model_context.reset_cost_model_context()


-def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536):  # 1048576 #131072 #32768 #8192
+def test_train_64k_8p(batch_size=32, num_classes=65536):  # 1048576 #131072 #32768 #8192
     dev_num = 8
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
     cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)

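Two pylint fixes recur throughout this file: the epoch_size parameter that no function body ever read is dropped (unused-argument, W0613), and the redundant parentheses around assert expressions are removed (superfluous-parens, C0325), since assert is a statement rather than a function call. A small hypothetical illustration of both:

def run_case(batch_size=32, num_classes=32768):  # `epoch_size` dropped: never read (W0613)
    return {'end_point.bias': 2 if num_classes > batch_size else 1}

expect_dict = {'end_point.bias': 2}
# Before: `assert (run_case() == expect_dict)` -- C0325 superfluous-parens.
assert run_case() == expect_dict
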
tests/ut/python/parallel/test_auto_parallel_zig_zag.py

@@ -58,7 +58,7 @@ def test_zig_zag_graph():
         def construct(self, x, y, z, w, a):
             m1_result = self.matmul1(x, y)
             m2_result = self.matmul2(z, w)
-            m3_result = self.matmul3(m2_result, m1_result)
+            _ = self.matmul3(m2_result, m1_result)
             out = self.matmul4(m2_result, a)

             return out

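Binding the third matmul's result to `_` keeps the operation in the compiled graph while signalling that the value is deliberately discarded, which silences pylint's unused-variable warning (W0612). The same idiom in ordinary Python, with invented names:

def compute_with_side_effect():
    print("computed")   # stands in for an op that must stay in the graph
    return 42

# `result = compute_with_side_effect()` with `result` never read is W0612;
# `_` is the conventional name for a value kept only for its computation.
_ = compute_with_side_effect()
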
tests/ut/python/parallel/test_dataset_interface.py

@@ -101,7 +101,7 @@ def fixme_test_dataset_interface_sens_scalar():
     class TrainOneStepCell(nn.Cell):
-        def __init__(self, network, optimizer, sens=1.0):
+        def __init__(self, network, optimizer):
             super(TrainOneStepCell, self).__init__(auto_prefix=False)
             self.network = network
             self.network.add_flags(defer_inline=True)

@@ -135,7 +135,7 @@ def test_dataset_interface_sens_shape_not_equal_loss():
     sens = Tensor(np.ones([256, 1024]), dtype=ms.float32)
     try:
         loss_scale_manager_sens(strategy1, sens)
-    except:
+    except BaseException:
         pass

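The `except:` to `except BaseException:` rewrite, repeated in test_onehot.py and test_reshape.py below, names the caught type explicitly and so satisfies pylint's bare-except warning (W0702) without changing behavior: a bare except already catches BaseException. A standalone sketch:

def might_fail():
    raise ValueError("compile rejected an invalid strategy")

try:
    might_fail()
except BaseException:   # explicit equivalent of a bare `except:` (W0702)
    pass                # these tests only check that the failure is swallowed

In fresh code a narrower type (here, ValueError) would usually be better still; the commit's choice keeps the tests' behavior identical.
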
tests/ut/python/parallel/test_gather_v2.py

@@ -45,8 +45,10 @@ class GradWrap(nn.Cell):
 class Net(nn.Cell):
-    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=[64, 64]):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
         super().__init__()
+        if shape is None:
+            shape = [64, 64]
         self.gatherv2 = P.GatherV2().set_strategy(strategy1)
         self.mul = P.Mul().set_strategy(strategy2)
         self.index = Tensor(np.ones(shape), dtype=ms.int32)

tests/ut/python/parallel/test_gather_v2_primitive.py

@@ -221,14 +221,14 @@ def test_axis1_auto_batch_parallel():
 def test_axis1_batch_parallel():
-    gather_v2_strategy = ((device_number, 1), (1, ))
+    gather_v2_strategy = ((device_number, 1), (1,))
     criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
     rank = 2
     net_trains(criterion, rank)


 def test_axis1_strategy1():
-    gather_v2_strategy = ((16, 2), (1, ))
+    gather_v2_strategy = ((16, 2), (1,))
     rank = 17
     criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
     net_trains(criterion, rank)

tests/ut/python/parallel/test_one_hot_net.py

@@ -265,7 +265,6 @@ class BNReshapeDenseBNNet(nn.Cell):
 def test_bn_reshape_dense_bn_train_loss():
     batch_size = 16
     device_num = 16
     context.set_auto_parallel_context(device_num=device_num, global_rank=0)
     input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
     label = Tensor(np.ones([batch_size]), dtype=ms.int32)

(This hunk removes exactly one line, but the captured page does not preserve which of the lines above carried the deletion marker, so all are shown as context.)

tests/ut/python/parallel/test_onehot.py

@@ -104,7 +104,7 @@ def test_onehot_batch_parallel_invalid_strategy():
     strategy4 = ((16, 1), (16, 1))
     try:
         compile_graph(strategy1, strategy2, strategy3, strategy4)
-    except:
+    except BaseException:
         pass

@@ -144,7 +144,7 @@ def test_onehot_batch_parallel_invalid_strategy_axis0():
     strategy4 = ((16, 1), (16, 1))
     try:
         compile_graph(strategy1, strategy2, strategy3, strategy4, onthot_axis=0)
-    except:
+    except BaseException:
         pass

tests/ut/python/parallel/test_prelu.py

@@ -124,9 +124,9 @@ def test_prelu_parallel_success2():
 def test_prelu_parallel_success3():
-    class NetWithLoss(nn.Cell):
+    class NetWithLoss3(nn.Cell):
         def __init__(self, network):
-            super(NetWithLoss, self).__init__()
+            super(NetWithLoss3, self).__init__()
             self.loss = VirtualLoss()
             self.network = network

@@ -134,9 +134,9 @@ def test_prelu_parallel_success3():
             predict = self.network(x, y, w)
             return self.loss(predict)

-    class GradWrap(nn.Cell):
+    class GradWrap3(nn.Cell):
         def __init__(self, network):
-            super(GradWrap, self).__init__()
+            super(GradWrap3, self).__init__()
             self.network = network

         def construct(self, x, y, w):

@@ -161,7 +161,7 @@ def test_prelu_parallel_success3():
     x = Tensor(np.random.rand(128, 64), dtype=ms.float32)
     y = Tensor(np.random.rand(64, 16), dtype=ms.float32)
     w = Tensor(np.random.rand(16), dtype=ms.float32)
-    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    net = GradWrap3(NetWithLoss3(Net(strategy1, strategy2)))
     net.set_auto_parallel()
     _executor.compile(net, x, y, w)

tests/ut/python/parallel/test_reshape.py

@@ -114,7 +114,7 @@ def test_reshape1_strategy_1():
     strategy_loss = ((8, 1), (8, 1))
     try:
         reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
-    except:
+    except BaseException:
         pass

@@ -125,7 +125,7 @@ def test_reshape1_strategy_2():
     strategy_loss = ((8, 1), (8, 1))
     try:
         reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
-    except:
+    except BaseException:
         pass

@@ -347,14 +347,14 @@ def test_reshape_net3_2():
 def test_reshape_net4_1():
     try:
         reshape_net2(ReshapeNet4(((1, 8), (8, 1))))
-    except:
+    except BaseException:
         pass


 def test_reshape_net4_2():
     try:
         reshape_net2(ReshapeNet4(((1, 8), (8, 2))))
-    except:
+    except BaseException:
         pass