Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
edc92ccf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
edc92ccf
编写于
8月 22, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(imperative/data): improve dataloader preformance
GitOrigin-RevId: 7d8d52aaeb47e7ec6c3efa282ff9014a4b7d1f01
上级
896b0193
变更
4
展开全部
隐藏空白更改
内联
并排
Showing
4 changed file
with
524 addition
and
732 deletion
+524
-732
imperative/python/megengine/data/dataloader.py
imperative/python/megengine/data/dataloader.py
+363
-563
imperative/python/megengine/data/sampler.py
imperative/python/megengine/data/sampler.py
+20
-9
imperative/python/test/unit/data/test_dataloader.py
imperative/python/test/unit/data/test_dataloader.py
+106
-73
imperative/python/test/unit/data/test_pre_dataloader.py
imperative/python/test/unit/data/test_pre_dataloader.py
+35
-87
未找到文件。
imperative/python/megengine/data/dataloader.py
浏览文件 @
edc92ccf
此差异已折叠。
点击以展开。
imperative/python/megengine/data/sampler.py
浏览文件 @
edc92ccf
...
...
@@ -2,6 +2,7 @@
import
collections.abc
import
math
from
abc
import
ABC
,
abstractmethod
from
itertools
import
count
from
typing
import
Any
,
Generator
,
Iterator
,
List
,
Union
import
numpy
as
np
...
...
@@ -126,13 +127,15 @@ class MapSampler(Sampler):
if
self
.
world_size
>
1
:
indices
=
self
.
scatter
(
indices
)
step
,
length
=
self
.
batch_size
,
len
(
indices
)
batch_index
=
[
indices
[
i
:
i
+
step
]
for
i
in
range
(
0
,
length
,
step
)]
batch
=
[]
for
idx
in
indices
:
batch
.
append
(
idx
)
if
len
(
batch
)
==
self
.
batch_size
:
yield
batch
batch
=
[]
if
self
.
drop_last
and
len
(
batch_index
[
-
1
])
<
self
.
batch_size
:
batch_index
.
pop
()
return
iter
(
batch_index
)
if
len
(
batch
)
>
0
and
not
self
.
drop_last
:
yield
batch
class
StreamSampler
(
Sampler
):
...
...
@@ -151,10 +154,18 @@ class StreamSampler(Sampler):
self
.
batch_size
=
batch_size
def
__iter__
(
self
):
return
self
return
self
.
batch
()
def
__next__
(
self
):
return
iter
(
range
(
self
.
batch_size
))
def
batch
(
self
):
batch
=
[]
for
idx
in
self
.
sample
():
batch
.
append
(
idx
)
if
len
(
batch
)
==
self
.
batch_size
:
yield
batch
batch
=
[]
def
sample
(
self
):
return
count
(
start
=
0
)
class
SequentialSampler
(
MapSampler
):
...
...
imperative/python/test/unit/data/test_dataloader.py
浏览文件 @
edc92ccf
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
math
import
os
import
platform
import
time
...
...
@@ -7,7 +15,7 @@ import numpy as np
import
pytest
from
megengine.data.collator
import
Collator
from
megengine.data.dataloader
import
DataLoader
from
megengine.data.dataloader
import
DataLoader
,
get_worker_info
from
megengine.data.dataset
import
ArrayDataset
,
StreamDataset
from
megengine.data.sampler
import
RandomSampler
,
SequentialSampler
,
StreamSampler
from
megengine.data.transform
import
(
...
...
@@ -29,14 +37,10 @@ def init_dataset():
def
test_dataloader_init
():
dataset
=
init_dataset
()
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
2
,
divide
=
True
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
timeout
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
0
,
divide
=
True
)
dataloader
=
DataLoader
(
dataset
)
assert
isinstance
(
dataloader
.
sampler
,
SequentialSampler
)
...
...
@@ -54,10 +58,8 @@ def test_dataloader_init():
class
MyStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
b
atch
=
False
,
error_foramt
=
False
,
b
lock
=
False
):
def
__init__
(
self
,
number
,
block
=
False
):
self
.
number
=
number
self
.
batch
=
batch
self
.
error_format
=
error_foramt
self
.
block
=
block
def
__iter__
(
self
):
...
...
@@ -65,22 +67,14 @@ class MyStream(StreamDataset):
if
self
.
block
:
for
_
in
range
(
10
):
time
.
sleep
(
1
)
if
self
.
batch
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
True
,
(
data
,
[
cnt
,
cnt
-
self
.
number
]))
else
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
if
self
.
error_format
:
yield
(
data
,
cnt
)
else
:
yield
(
False
,
(
data
,
cnt
))
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
data
,
cnt
)
raise
StopIteration
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
batch
,
num_workers
):
dataset
=
MyStream
(
100
,
batch
=
batch
)
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
...
...
@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers):
)
check_set
=
set
()
for
step
,
data
in
enumerate
(
dataloader
):
if
step
==
10
:
break
...
...
@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers):
check_set
.
add
(
i
)
def
test_stream_dataloader_error
():
dataset
=
MyStream
(
100
,
error_foramt
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
".*tuple.*"
):
data_iter
=
iter
(
dataloader
)
next
(
data_iter
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader_timeout
(
num_workers
):
dataset
=
MyStream
(
100
,
False
,
block
=
True
)
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
)
...
...
@@ -140,17 +124,6 @@ def test_dataloader_parallel():
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
False
,
)
for
(
data
,
label
)
in
dataloader
:
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
shape
==
(
4
,)
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
True
,
)
for
(
data
,
label
)
in
dataloader
:
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
...
...
@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception():
transform
=
FakeErrorTransform
(),
num_workers
=
2
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
worker.*died
"
):
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
exited unexpectedly
"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
...
...
@@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception():
def
_multi_instances_parallel_dataloader_worker
():
dataset
=
init_dataset
()
for
divide_flag
in
[
True
,
False
]:
train_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
)
val_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
10
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
)
for
idx
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
shape
==
(
4
,)
if
idx
%
5
==
0
:
for
val_data
,
val_label
in
val_dataloader
:
assert
val_data
.
shape
==
(
10
,
1
,
32
,
32
)
assert
val_label
.
shape
==
(
10
,)
train_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
)
val_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
10
,
drop_last
=
False
),
num_workers
=
2
,
)
for
idx
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
assert
data
.
shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
shape
==
(
4
,)
if
idx
%
5
==
0
:
for
val_data
,
val_label
in
val_dataloader
:
assert
val_data
.
shape
==
(
10
,
1
,
32
,
32
)
assert
val_label
.
shape
==
(
10
,)
def
test_dataloader_parallel_multi_instances
():
...
...
@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
assert
p
.
exitcode
==
0
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_timeout_event
(
num_workers
):
def
cb
():
return
(
True
,
(
np
.
zeros
(
shape
=
(
2
,
2
,
2
,
3
)),
np
.
ones
(
shape
=
(
2
,))))
def
partition
(
ls
,
size
):
return
[
ls
[
i
:
i
+
size
]
for
i
in
range
(
0
,
len
(
ls
),
size
)]
dataset
=
MyStream
(
100
,
block
=
True
)
class
MyPreStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
block
=
False
):
self
.
number
=
[
i
for
i
in
range
(
number
)]
self
.
block
=
block
self
.
data
=
[]
for
i
in
range
(
100
):
self
.
data
.
append
(
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
))
def
__iter__
(
self
):
worker_info
=
get_worker_info
()
per_worker
=
int
(
math
.
ceil
((
len
(
self
.
data
))
/
float
(
worker_info
.
worker
)))
pre_data
=
iter
(
partition
(
self
.
data
,
per_worker
)[
worker_info
.
idx
])
pre_cnt
=
partition
(
self
.
number
,
per_worker
)[
worker_info
.
idx
]
for
cnt
in
pre_cnt
:
if
self
.
block
:
for
_
in
range
(
10
):
time
.
sleep
(
1
)
yield
(
next
(
pre_data
),
cnt
)
raise
StopIteration
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
def
test_prestream_dataloader_multiprocessing
():
dataset
=
MyPreStream
(
100
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
Compose
([
Normalize
(
mean
=
(
103
,
116
,
123
),
std
=
(
57
,
57
,
58
)),
ToMode
(
"CHW"
)]),
num_workers
=
2
,
parallel_stream
=
True
,
)
check_set
=
set
()
for
step
,
data
in
enumerate
(
dataloader
):
if
step
==
10
:
break
assert
data
[
0
].
shape
==
(
4
,
3
,
2
,
2
)
assert
data
[
1
].
shape
==
(
4
,)
for
i
in
data
[
1
]:
assert
i
not
in
check_set
check_set
.
add
(
i
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
==
"Windows"
,
reason
=
"dataloader do not support parallel on windows"
,
)
def
test_predataloader_parallel_worker_exception
():
dataset
=
MyPreStream
(
100
)
class
FakeErrorTransform
(
Transform
):
def
__init__
(
self
):
pass
def
apply
(
self
,
input
):
raise
RuntimeError
(
"test raise error"
)
return
input
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
,
timeout_event
=
cb
dataset
,
sampler
=
StreamSampler
(
batch_size
=
4
),
transform
=
FakeErrorTransform
(),
num_workers
=
2
,
parallel_stream
=
True
,
)
for
_
,
data
in
enumerate
(
dataloader
):
np
.
testing
.
assert_equal
(
data
[
0
],
np
.
zeros
(
shape
=
(
4
,
2
,
2
,
3
))
)
np
.
testing
.
assert_equal
(
data
[
1
],
np
.
ones
(
shape
=
(
4
,))
)
break
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"exited unexpectedly"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
print
(
batch_data
.
shape
)
imperative/python/test/unit/data/test_pre_dataloader.py
浏览文件 @
edc92ccf
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
gc
import
math
import
os
import
platform
import
time
...
...
@@ -8,7 +16,7 @@ import numpy as np
import
pytest
from
megengine.data.collator
import
Collator
from
megengine.data.dataloader
import
DataLoader
from
megengine.data.dataloader
import
DataLoader
,
get_worker_info
from
megengine.data.dataset
import
ArrayDataset
,
StreamDataset
from
megengine.data.sampler
import
RandomSampler
,
SequentialSampler
,
StreamSampler
from
megengine.data.transform
import
(
...
...
@@ -30,14 +38,10 @@ def init_dataset():
def
test_dataloader_init
():
dataset
=
init_dataset
()
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
2
,
divide
=
True
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
timeout
=-
1
)
with
pytest
.
raises
(
ValueError
):
dataloader
=
DataLoader
(
dataset
,
num_workers
=
0
,
divide
=
True
)
dataloader
=
DataLoader
(
dataset
,
preload
=
True
)
assert
isinstance
(
dataloader
.
sampler
,
SequentialSampler
)
...
...
@@ -59,10 +63,8 @@ def test_dataloader_init():
class
MyStream
(
StreamDataset
):
def
__init__
(
self
,
number
,
b
atch
=
False
,
error_foramt
=
False
,
b
lock
=
False
):
def
__init__
(
self
,
number
,
block
=
False
):
self
.
number
=
number
self
.
batch
=
batch
self
.
error_format
=
error_foramt
self
.
block
=
block
def
__iter__
(
self
):
...
...
@@ -70,22 +72,14 @@ class MyStream(StreamDataset):
if
self
.
block
:
for
_
in
range
(
10
):
time
.
sleep
(
1
)
if
self
.
batch
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
True
,
(
data
,
[
cnt
,
cnt
-
self
.
number
]))
else
:
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
if
self
.
error_format
:
yield
(
data
,
cnt
)
else
:
yield
(
False
,
(
data
,
cnt
))
data
=
np
.
random
.
randint
(
0
,
256
,
(
2
,
2
,
3
),
dtype
=
"uint8"
)
yield
(
data
,
cnt
)
raise
StopIteration
@
pytest
.
mark
.
parametrize
(
"batch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader
(
batch
,
num_workers
):
dataset
=
MyStream
(
100
,
batch
=
batch
)
def
test_stream_dataloader
(
num_workers
):
dataset
=
MyStream
(
100
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
...
...
@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers):
check_set
.
add
(
i
)
def
test_stream_dataloader_error
():
dataset
=
MyStream
(
100
,
error_foramt
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
preload
=
True
)
with
pytest
.
raises
(
AssertionError
,
match
=
r
".*tuple.*"
):
data_iter
=
iter
(
dataloader
)
next
(
data_iter
)
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_stream_dataloader_timeout
(
num_workers
):
dataset
=
MyStream
(
100
,
False
,
block
=
True
)
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
...
...
@@ -150,18 +135,6 @@ def test_dataloader_parallel():
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
False
,
preload
=
True
,
)
for
(
data
,
label
)
in
dataloader
:
assert
data
.
_tuple_shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
_tuple_shape
==
(
4
,)
dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
True
,
preload
=
True
,
)
for
(
data
,
label
)
in
dataloader
:
...
...
@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception():
num_workers
=
2
,
preload
=
True
,
)
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
worker.*died
"
):
with
pytest
.
raises
(
RuntimeError
,
match
=
r
"
exited unexpectedly
"
):
data_iter
=
iter
(
dataloader
)
batch_data
=
next
(
data_iter
)
...
...
@@ -227,28 +200,25 @@ def test_dataloader_parallel_worker_exception():
def
_multi_instances_parallel_dataloader_worker
():
dataset
=
init_dataset
()
for
divide_flag
in
[
True
,
False
]:
train_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
preload
=
True
,
)
val_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
10
,
drop_last
=
False
),
num_workers
=
2
,
divide
=
divide_flag
,
preload
=
True
,
)
for
idx
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
assert
data
.
_tuple_shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
_tuple_shape
==
(
4
,)
if
idx
%
5
==
0
:
for
val_data
,
val_label
in
val_dataloader
:
assert
val_data
.
_tuple_shape
==
(
10
,
1
,
32
,
32
)
assert
val_label
.
_tuple_shape
==
(
10
,)
train_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
4
,
drop_last
=
False
),
num_workers
=
2
,
preload
=
True
,
)
val_dataloader
=
DataLoader
(
dataset
,
sampler
=
RandomSampler
(
dataset
,
batch_size
=
10
,
drop_last
=
False
),
num_workers
=
2
,
preload
=
True
,
)
for
idx
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
assert
data
.
_tuple_shape
==
(
4
,
1
,
32
,
32
)
assert
label
.
_tuple_shape
==
(
4
,)
if
idx
%
5
==
0
:
for
val_data
,
val_label
in
val_dataloader
:
assert
val_data
.
_tuple_shape
==
(
10
,
1
,
32
,
32
)
assert
val_label
.
_tuple_shape
==
(
10
,)
def
test_dataloader_parallel_multi_instances
():
...
...
@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for
p
in
processes
:
p
.
join
()
assert
p
.
exitcode
==
0
@
pytest
.
mark
.
parametrize
(
"num_workers"
,
[
0
,
2
])
def
test_timeout_event
(
num_workers
):
def
cb
():
return
(
True
,
(
np
.
zeros
(
shape
=
(
2
,
2
,
2
,
3
)),
np
.
ones
(
shape
=
(
2
,))))
dataset
=
MyStream
(
100
,
block
=
True
)
sampler
=
StreamSampler
(
batch_size
=
4
)
dataloader
=
DataLoader
(
dataset
,
sampler
,
num_workers
=
num_workers
,
timeout
=
2
,
timeout_event
=
cb
,
preload
=
True
,
)
for
_
,
data
in
enumerate
(
dataloader
):
np
.
testing
.
assert_equal
(
data
[
0
],
np
.
zeros
(
shape
=
(
4
,
2
,
2
,
3
)))
np
.
testing
.
assert_equal
(
data
[
1
],
np
.
ones
(
shape
=
(
4
,)))
break
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录