Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
415afe09
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
415afe09
编写于
5月 28, 2020
作者:
Y
yanghaitao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add get_dataset_size to celebadataset
上级
0f221403
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
0 deletion
+29
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+24
-0
tests/ut/python/dataset/test_datasets_celeba.py
tests/ut/python/dataset/test_datasets_celeba.py
+5
-0
未找到文件。
mindspore/dataset/engine/datasets.py
浏览文件 @
415afe09
...
...
@@ -4021,6 +4021,30 @@ class CelebADataset(MappableDataset):
return
self
.
sampler
.
is_sharded
()
def
get_dataset_size
(
self
):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if
self
.
_dataset_size
is
None
:
dir
=
os
.
path
.
realpath
(
self
.
dataset_dir
)
attr_file
=
os
.
path
.
join
(
dir
,
"list_attr_celeba.txt"
)
num_rows
=
''
try
:
with
open
(
attr_file
,
'r'
)
as
f
:
num_rows
=
int
(
f
.
readline
())
except
Exception
:
raise
RuntimeError
(
"Get dataset size failed from attribution file."
)
rows_per_shard
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
self
.
num_samples
is
not
None
:
rows_per_shard
=
min
(
self
.
num_samples
,
rows_per_shard
)
rows_from_sampler
=
self
.
_get_sampler_dataset_size
()
if
rows_from_sampler
is
None
:
return
rows_per_shard
return
min
(
rows_from_sampler
,
rows_per_shard
)
return
self
.
_dataset_size
class
TextFileDataset
(
SourceDataset
):
"""
...
...
tests/ut/python/dataset/test_datasets_celeba.py
浏览文件 @
415afe09
...
...
@@ -85,9 +85,14 @@ def test_celeba_dataset_distribute():
count
=
count
+
1
assert
(
count
==
1
)
def
test_celeba_get_dataset_size
():
data
=
ds
.
CelebADataset
(
DATA_DIR
,
decode
=
True
,
shuffle
=
False
)
size
=
data
.
get_dataset_size
()
assert
size
==
2
if
__name__
==
'__main__'
:
test_celeba_dataset_label
()
test_celeba_dataset_op
()
test_celeba_dataset_ext
()
test_celeba_dataset_distribute
()
test_celeba_get_dataset_size
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录