Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
01da7d2c
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
01da7d2c
编写于
9月 03, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5713 Align num_samples of CSV with other dataset
Merge pull request !5713 from jiangzhiwen/fix/csv_num_samples
上级
c2ff5e3f
1a1a8893
变更
6
隐藏空白更改
内联
并排
Showing
6 changed files
with
20 additions
and
24 deletions
+20
-24
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+1
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
...ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
+2
-2
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+2
-2
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+3
-3
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+1
-5
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc
+11
-11
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
01da7d2c
...
...
@@ -1200,7 +1200,7 @@ bool CSVDataset::ValidateParams() {
return
false
;
}
if
(
num_samples_
<
-
1
)
{
if
(
num_samples_
<
0
)
{
MS_LOG
(
ERROR
)
<<
"CSVDataset: Invalid number of samples: "
<<
num_samples_
;
return
false
;
}
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
浏览文件 @
01da7d2c
...
...
@@ -27,7 +27,7 @@
namespace
mindspore
{
namespace
dataset
{
CsvOp
::
Builder
::
Builder
()
:
builder_device_id_
(
0
),
builder_num_devices_
(
1
),
builder_num_samples_
(
-
1
),
builder_shuffle_files_
(
false
)
{
:
builder_device_id_
(
0
),
builder_num_devices_
(
1
),
builder_num_samples_
(
0
),
builder_shuffle_files_
(
false
)
{
std
::
shared_ptr
<
ConfigManager
>
config_manager
=
GlobalContext
::
config_manager
();
builder_num_workers_
=
config_manager
->
num_parallel_workers
();
builder_op_connector_size_
=
config_manager
->
op_connector_size
();
...
...
@@ -539,7 +539,7 @@ Status CsvOp::operator()() {
RETURN_IF_NOT_OK
(
jagged_buffer_connector_
->
Pop
(
0
,
&
buffer
));
if
(
buffer
->
eoe
())
{
workers_done
++
;
}
else
if
(
num_samples_
==
-
1
||
rows_read
<
num_samples_
)
{
}
else
if
(
num_samples_
==
0
||
rows_read
<
num_samples_
)
{
if
((
num_samples_
>
0
)
&&
(
rows_read
+
buffer
->
NumRows
()
>
num_samples_
))
{
int64_t
rowsToRemove
=
buffer
->
NumRows
()
-
(
num_samples_
-
rows_read
);
RETURN_IF_NOT_OK
(
buffer
->
SliceOff
(
rowsToRemove
));
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
01da7d2c
...
...
@@ -191,7 +191,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
/// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, infers the
/// column_names from the first row of CSV file.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default =
-1
means all samples.)
/// (Default =
0
means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed.
...
...
@@ -203,7 +203,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
CSVDataset
>
CSV
(
const
std
::
vector
<
std
::
string
>
&
dataset_files
,
char
field_delim
=
','
,
const
std
::
vector
<
std
::
shared_ptr
<
CsvBase
>>
&
column_defaults
=
{},
const
std
::
vector
<
std
::
string
>
&
column_names
=
{},
int64_t
num_samples
=
-
1
,
const
std
::
vector
<
std
::
string
>
&
column_names
=
{},
int64_t
num_samples
=
0
,
ShuffleMode
shuffle
=
ShuffleMode
::
kGlobal
,
int32_t
num_shards
=
1
,
int32_t
shard_id
=
0
);
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
01da7d2c
...
...
@@ -5140,7 +5140,7 @@ class CSVDataset(SourceDataset):
columns as string type.
column_names (list[str], optional): List of column names of the dataset (default=None). If this
is not provided, infers the column_names from the first row of CSV file.
num_samples (int, optional): number of samples(rows) to read (default=
-1
, reads the full dataset).
num_samples (int, optional): number of samples(rows) to read (default=
None
, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
...
...
@@ -5164,7 +5164,7 @@ class CSVDataset(SourceDataset):
"""
@
check_csvdataset
def
__init__
(
self
,
dataset_files
,
field_delim
=
','
,
column_defaults
=
None
,
column_names
=
None
,
num_samples
=
-
1
,
def
__init__
(
self
,
dataset_files
,
field_delim
=
','
,
column_defaults
=
None
,
column_names
=
None
,
num_samples
=
None
,
num_parallel_workers
=
None
,
shuffle
=
Shuffle
.
GLOBAL
,
num_shards
=
None
,
shard_id
=
None
):
super
().
__init__
(
num_parallel_workers
)
self
.
dataset_files
=
self
.
_find_files
(
dataset_files
)
...
...
@@ -5215,7 +5215,7 @@ class CSVDataset(SourceDataset):
if
self
.
dataset_size
is
None
:
num_rows
=
CsvOp
.
get_num_rows
(
self
.
dataset_files
,
self
.
column_names
is
None
)
self
.
dataset_size
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
self
.
num_samples
!=
-
1
and
self
.
num_samples
<
self
.
dataset_size
:
if
self
.
num_samples
is
not
None
and
self
.
num_samples
<
self
.
dataset_size
:
self
.
dataset_size
=
num_rows
return
self
.
dataset_size
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
01da7d2c
...
...
@@ -830,16 +830,12 @@ def check_csvdataset(method):
def
new_method
(
self
,
*
args
,
**
kwargs
):
_
,
param_dict
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
nreq_param_int
=
[
'num_parallel_workers'
,
'num_shards'
,
'shard_id'
]
nreq_param_int
=
[
'num_
samples'
,
'num_
parallel_workers'
,
'num_shards'
,
'shard_id'
]
# check dataset_files; required argument
dataset_files
=
param_dict
.
get
(
'dataset_files'
)
type_check
(
dataset_files
,
(
str
,
list
),
"dataset files"
)
# check num_samples
num_samples
=
param_dict
.
get
(
'num_samples'
)
check_value
(
num_samples
,
[
-
1
,
INT32_MAX
],
"num_samples"
)
# check field_delim
field_delim
=
param_dict
.
get
(
'field_delim'
)
type_check
(
field_delim
,
(
str
,),
'field delim'
)
...
...
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc
浏览文件 @
01da7d2c
...
...
@@ -33,7 +33,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) {
// Create a CSVDataset, with single CSV file
std
::
string
train_file
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
train_file
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFalse
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
train_file
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFalse
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -85,7 +85,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) {
std
::
string
file1
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
string
file2
=
datasets_root_path_
+
"/testCSV/append.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file1
,
file2
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kGlobal
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file1
,
file2
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kGlobal
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -179,7 +179,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDistribution) {
// Create a CSVDataset, with single CSV file
std
::
string
file
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFalse
,
2
,
0
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFalse
,
2
,
0
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -228,7 +228,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetType) {
std
::
make_shared
<
CsvRecord
<
std
::
string
>>
(
CsvType
::
STRING
,
""
),
};
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file
},
','
,
colum_type
,
column_names
,
-
1
,
ShuffleMode
::
kFalse
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file
},
','
,
colum_type
,
column_names
,
0
,
ShuffleMode
::
kFalse
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -343,15 +343,15 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetException) {
EXPECT_EQ
(
ds1
,
nullptr
);
// Test invalid num_samples < 0
std
::
shared_ptr
<
Dataset
>
ds2
=
CSV
({
file
},
','
,
{},
column_names
,
-
2
);
std
::
shared_ptr
<
Dataset
>
ds2
=
CSV
({
file
},
','
,
{},
column_names
,
-
1
);
EXPECT_EQ
(
ds2
,
nullptr
);
// Test invalid num_shards < 1
std
::
shared_ptr
<
Dataset
>
ds3
=
CSV
({
file
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFalse
,
0
);
std
::
shared_ptr
<
Dataset
>
ds3
=
CSV
({
file
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFalse
,
0
);
EXPECT_EQ
(
ds3
,
nullptr
);
// Test invalid shard_id >= num_shards
std
::
shared_ptr
<
Dataset
>
ds4
=
CSV
({
file
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFalse
,
2
,
2
);
std
::
shared_ptr
<
Dataset
>
ds4
=
CSV
({
file
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFalse
,
2
,
2
);
EXPECT_EQ
(
ds4
,
nullptr
);
// Test invalid field_delim
...
...
@@ -373,7 +373,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) {
std
::
string
file1
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
string
file2
=
datasets_root_path_
+
"/testCSV/append.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file1
,
file2
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFiles
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file1
,
file2
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFiles
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -432,7 +432,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) {
std
::
string
file1
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
string
file2
=
datasets_root_path_
+
"/testCSV/append.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file2
,
file1
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFiles
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
file2
,
file1
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFiles
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -492,7 +492,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) {
// Create a CSVFile Dataset, with single CSV file
std
::
string
train_file
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col2"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
train_file
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kGlobal
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
train_file
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kGlobal
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
...
...
@@ -540,7 +540,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDuplicateColumnName) {
// Create a CSVDataset, with single CSV file
std
::
string
train_file
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
vector
<
std
::
string
>
column_names
=
{
"col1"
,
"col1"
,
"col3"
,
"col4"
};
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
train_file
},
','
,
{},
column_names
,
-
1
,
ShuffleMode
::
kFalse
);
std
::
shared_ptr
<
Dataset
>
ds
=
CSV
({
train_file
},
','
,
{},
column_names
,
0
,
ShuffleMode
::
kFalse
);
// Expect failure: duplicate column names
EXPECT_EQ
(
ds
,
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录