Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a9e92661
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a9e92661
编写于
4月 18, 2020
作者:
H
hesham
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Deepcopy problem when pyfunc cannot be pickled
上级
9cab093c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
67 addition
and
1 deletion
+67
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+36
-0
tests/ut/python/dataset/test_iterator.py
tests/ut/python/dataset/test_iterator.py
+31
-1
未找到文件。
mindspore/dataset/engine/datasets.py
浏览文件 @
a9e92661
...
...
@@ -30,7 +30,9 @@ from enum import Enum
from
importlib
import
import_module
import
threading
import
copy
import
numpy
as
np
from
mindspore._c_dataengine
import
DataType
,
TFReaderOp
,
ImageFolderOp
,
CifarOp
,
MnistOp
,
ManifestOp
,
\
MindRecordOp
,
TextFileOp
,
CBatchInfo
from
mindspore._c_expression
import
typing
...
...
@@ -1376,6 +1378,23 @@ class MapDataset(DatasetOp):
"""
return
self
.
input
[
0
].
get_dataset_size
()
def __deepcopy__(self, memodict):
    """Copy this map op for ``copy.deepcopy`` while sharing ``operations``.

    All attributes listed below are duplicated with ``copy.deepcopy``.
    ``operations`` is the exception: it is assigned by reference, because it
    may hold Python callables that cannot be pickled or deep-copied.

    Args:
        memodict (dict): memo dictionary supplied by ``copy.deepcopy``,
            keyed by ``id()`` of the objects that were already copied.

    Returns:
        The copy recorded in ``memodict`` for this object; it is created and
        recorded first when no entry exists yet.
    """
    memo_key = id(self)
    try:
        return memodict[memo_key]
    except KeyError:
        pass

    op_class = self.__class__
    clone = op_class.__new__(op_class)  # __init__ is intentionally not run
    # Record the clone before copying any attribute, so that a reference back
    # to this op met during the nested copies resolves to the clone.
    memodict[memo_key] = clone

    # (attribute written on the clone, attribute read from self), in copy order.
    deep_copied = (
        ("input", "input"),
        ("input_columns", "input_columns"),
        ("output_columns", "output_columns"),
        ("columns_order", "columns_order"),
        ("num_parallel_workers", "num_parallel_workers"),
        ("output", "output"),
        ("input_indexs", "_input_indexs"),
        ("python_multiprocessing", "python_multiprocessing"),
    )
    for target, origin in deep_copied:
        setattr(clone, target, copy.deepcopy(getattr(self, origin), memodict))

    # Shared with the original, never copied.
    clone.operations = self.operations
    return clone
# Iterator bootstrap will be called on iterator construction.
# A deep copy of the Dataset object is created prior to iterator_bootstrap.
# This method will create per iterator process pool and bind pyfunc execution to the pool.
...
...
@@ -2599,6 +2618,23 @@ class GeneratorDataset(SourceDataset):
else
:
raise
ValueError
(
'set dataset_size with negative value {}'
.
format
(
value
))
def __deepcopy__(self, memodict):
    """Copy this generator op for ``copy.deepcopy`` while sharing its source.

    ``input``, ``output``, ``num_parallel_workers``, ``column_types`` and
    ``column_names`` are duplicated with ``copy.deepcopy``. ``source`` and
    ``sampler`` are assigned by reference, so the copy and the original use
    the very same objects.

    Args:
        memodict (dict): memo dictionary supplied by ``copy.deepcopy``,
            keyed by ``id()`` of the objects that were already copied.

    Returns:
        The copy recorded in ``memodict`` for this object; it is created and
        recorded first when no entry exists yet.
    """
    self_id = id(self)
    try:
        return memodict[self_id]
    except KeyError:
        pass

    op_class = self.__class__
    duplicate = op_class.__new__(op_class)  # __init__ is intentionally not run
    # Record the duplicate up front so nested copies that lead back to this
    # op reuse it.
    memodict[self_id] = duplicate

    # Deep-copied attributes, in the order they are copied.
    for attr in ("input", "output", "num_parallel_workers",
                 "column_types", "column_names"):
        setattr(duplicate, attr, copy.deepcopy(getattr(self, attr), memodict))

    # Assigned by reference, never copied.
    duplicate.source = self.source
    duplicate.sampler = self.sampler
    return duplicate
class
TFRecordDataset
(
SourceDataset
):
"""
...
...
tests/ut/python/dataset/test_iterator.py
浏览文件 @
a9e92661
...
...
@@ -14,7 +14,7 @@
# ==============================================================================
import
numpy
as
np
import
pytest
import
copy
import
mindspore.dataset
as
ds
from
mindspore.dataset.engine.iterators
import
ITERATORS_LIST
,
_cleanup
...
...
@@ -81,3 +81,33 @@ def test_iterator_weak_ref():
assert
sum
(
itr
()
is
not
None
for
itr
in
ITERATORS_LIST
)
==
2
_cleanup
()
class MyDict(dict):
    """Callable dict whose items can also be read and written as attributes.

    Attribute reads that normal lookup cannot resolve are forwarded to item
    access, so an unknown name raises ``KeyError`` rather than
    ``AttributeError``. Protocol probes such as
    ``getattr(obj, "__deepcopy__", None)`` therefore raise instead of falling
    back to their default. This is what makes an instance usable as a pyfunc
    that cannot be copied or pickled in the test below.
    """

    def __getattr__(self, key):
        """Return the item stored under ``key``; raise ``KeyError`` if absent."""
        return self.__getitem__(key)

    def __setattr__(self, key, value):
        """Store ``value`` as the item ``key`` instead of as an instance attribute."""
        self.__setitem__(key, value)

    def __call__(self, t):
        """Return ``t`` unchanged."""
        return t
def test_tree_copy():
    """Check that building an iterator copies the tree but shares the pyfunc.

    The map operation is a ``MyDict`` instance, a callable that cannot be
    pickled. The iterator must hold its own copies of the map op and of the
    op's input, while the operation object itself must be the very same
    object in both trees.
    """
    # Testing copying the tree with a pyfunc that cannot be pickled
    data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
    data1 = data.map(operations=[MyDict()])

    itr = data1.create_tuple_iterator()
    try:
        # The dataset nodes reachable from the iterator are copies...
        assert data1 is not itr.dataset
        assert data is not itr.dataset.input[0]
        # ...but the pyfunc is shared by reference, not copied.
        assert data1.operations[0] is itr.dataset.operations[0]
    finally:
        # Release the iterator even when an assertion above fails, so a
        # failure here does not leave it alive for later tests.
        itr.release()
# Run the tree-copy test when this file is executed directly as a script.
if __name__ == '__main__':
    test_tree_copy()
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录