Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
03804075
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
03804075
编写于
6月 17, 2020
作者:
Y
yanghaitao1
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
store get dataset size
上级
1e90e7be
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
12 addition
and
1 deletion
+12
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+12
-1
未找到文件。
mindspore/dataset/engine/datasets.py
浏览文件 @
03804075
...
@@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
...
@@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
self
.
decode
=
decode
self
.
decode
=
decode
self
.
num_shards
=
num_shards
self
.
num_shards
=
num_shards
self
.
shard_id
=
shard_id
self
.
shard_id
=
shard_id
self
.
cur_dataset_size
=
None
def
get_args
(
self
):
def
get_args
(
self
):
args
=
super
().
get_args
()
args
=
super
().
get_args
()
...
@@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
...
@@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Return:
Number, number of batches.
Number, number of batches.
"""
"""
if
self
.
cur_dataset_size
is
not
None
:
return
self
.
cur_dataset_size
if
self
.
num_samples
is
None
:
if
self
.
num_samples
is
None
:
num_samples
=
0
num_samples
=
0
else
:
else
:
...
@@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
...
@@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
if
rows_from_sampler
is
None
:
self
.
cur_dataset_size
=
rows_per_shard
return
rows_per_shard
return
rows_per_shard
return
min
(
rows_from_sampler
,
rows_per_shard
)
self
.
cur_dataset_size
=
min
(
rows_from_sampler
,
rows_per_shard
)
return
self
.
cur_dataset_size
def
num_classes
(
self
):
def
num_classes
(
self
):
"""
"""
...
@@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
...
@@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
self
.
shuffle_option
=
shuffle
self
.
shuffle_option
=
shuffle
self
.
distribution
=
""
self
.
distribution
=
""
self
.
sampler
=
sampler
self
.
sampler
=
sampler
self
.
cur_dataset_size
=
None
if
num_shards
is
None
or
shard_id
is
None
:
if
num_shards
is
None
or
shard_id
is
None
:
self
.
partitions
=
None
self
.
partitions
=
None
...
@@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
...
@@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
Number, number of batches.
Number, number of batches.
"""
"""
if
self
.
_dataset_size
is
None
:
if
self
.
_dataset_size
is
None
:
if
self
.
cur_dataset_size
is
not
None
:
return
self
.
cur_dataset_size
if
self
.
load_dataset
:
if
self
.
load_dataset
:
dataset_file
=
[
self
.
dataset_file
]
dataset_file
=
[
self
.
dataset_file
]
else
:
else
:
...
@@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
...
@@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
raise
RuntimeError
(
raise
RuntimeError
(
"Dataset size plus number of padded samples is not divisible by number of shards."
)
"Dataset size plus number of padded samples is not divisible by number of shards."
)
num_rows
=
num_rows
//
self
.
partitions
[
0
]
+
1
num_rows
=
num_rows
//
self
.
partitions
[
0
]
+
1
self
.
cur_dataset_size
=
num_rows
return
num_rows
return
num_rows
return
self
.
_dataset_size
return
self
.
_dataset_size
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录