Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b091f74c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b091f74c
编写于
7月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3016 Add CSV dataset loader
Merge pull request !3016 from jiangzhiwen/dataset/csv
上级
57252dee
2f506b79
变更
24
展开全部
隐藏空白更改
内联
并排
Showing
24 changed file
with
1858 addition
and
11 deletion
+1858
-11
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
+82
-0
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h
+3
-0
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
+15
-2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
.../minddata/dataset/engine/datasetops/source/CMakeLists.txt
+2
-1
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
...ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
+757
-0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
.../ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
+451
-0
mindspore/dataset/__init__.py
mindspore/dataset/__init__.py
+2
-2
mindspore/dataset/engine/__init__.py
mindspore/dataset/engine/__init__.py
+1
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+107
-5
mindspore/dataset/engine/iterators.py
mindspore/dataset/engine/iterators.py
+2
-0
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+43
-0
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/csv_op_test.cc
tests/ut/cpp/dataset/csv_op_test.cc
+122
-0
tests/ut/data/dataset/testCSV/1.csv
tests/ut/data/dataset/testCSV/1.csv
+3
-0
tests/ut/data/dataset/testCSV/2.csv
tests/ut/data/dataset/testCSV/2.csv
+8
-0
tests/ut/data/dataset/testCSV/chinese.csv
tests/ut/data/dataset/testCSV/chinese.csv
+1
-0
tests/ut/data/dataset/testCSV/embedded.csv
tests/ut/data/dataset/testCSV/embedded.csv
+2
-0
tests/ut/data/dataset/testCSV/exception.csv
tests/ut/data/dataset/testCSV/exception.csv
+3
-0
tests/ut/data/dataset/testCSV/header.csv
tests/ut/data/dataset/testCSV/header.csv
+2
-0
tests/ut/data/dataset/testCSV/number.csv
tests/ut/data/dataset/testCSV/number.csv
+1
-0
tests/ut/data/dataset/testCSV/quoted.csv
tests/ut/data/dataset/testCSV/quoted.csv
+1
-0
tests/ut/data/dataset/testCSV/separated.csv
tests/ut/data/dataset/testCSV/separated.csv
+1
-0
tests/ut/data/dataset/testCSV/size.csv
tests/ut/data/dataset/testCSV/size.csv
+10
-0
tests/ut/python/dataset/test_datasets_csv.py
tests/ut/python/dataset/test_datasets_csv.py
+238
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
浏览文件 @
b091f74c
...
...
@@ -31,6 +31,7 @@
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
...
...
@@ -88,6 +89,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{
kBuildVocab
,
&
DEPipeline
::
ParseBuildVocabOp
},
{
kClue
,
&
DEPipeline
::
ParseClueOp
},
{
kEpochCtrl
,
&
DEPipeline
::
ParseEpochCtrlOp
},
{
kCsv
,
&
DEPipeline
::
ParseCsvOp
},
{
kSentencePieceVocab
,
&
DEPipeline
::
ParseBuildSentencePieceVocabOp
}};
DEPipeline
::
DEPipeline
()
:
iterator_
(
nullptr
)
{
...
...
@@ -1848,6 +1850,86 @@ Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num
return
Status
::
OK
();
}
Status
DEPipeline
::
ParseCsvOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
top
,
std
::
shared_ptr
<
DatasetOp
>
*
bottom
)
{
std
::
vector
<
std
::
string
>
files_list
;
std
::
shared_ptr
<
CsvOp
::
Builder
>
builder
=
std
::
make_shared
<
CsvOp
::
Builder
>
();
if
(
!
args
[
"dataset_files"
].
is_none
())
{
files_list
=
ToStringVector
(
args
[
"dataset_files"
]);
(
void
)
builder
->
SetCsvFilesList
(
files_list
);
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Error: dataset_files is missing"
);
}
// Optional arguments
bool
shuffle_required
=
false
;
int64_t
num_devices
=
0
;
std
::
vector
<
std
::
string
>
col_names
;
for
(
auto
arg
:
args
)
{
std
::
string
key
=
py
::
str
(
arg
.
first
);
py
::
handle
value
=
arg
.
second
;
if
(
!
value
.
is_none
())
{
if
(
key
==
"num_parallel_workers"
)
{
(
void
)
builder
->
SetNumWorkers
(
ToInt
(
value
));
}
else
if
(
key
==
"shuffle_files"
)
{
(
void
)
builder
->
SetShuffleFiles
(
ToBool
(
value
));
}
else
if
(
key
==
"shuffle_global"
)
{
shuffle_required
=
ToBool
(
value
);
}
else
if
(
key
==
"num_samples"
)
{
(
void
)
builder
->
SetNumSamples
(
ToInt
(
value
));
}
else
if
(
key
==
"num_shards"
)
{
num_devices
=
ToInt
(
value
);
(
void
)
builder
->
SetNumDevices
(
num_devices
);
}
else
if
(
key
==
"shard_id"
)
{
(
void
)
builder
->
SetDeviceId
(
ToInt
(
value
));
}
else
if
(
key
==
"field_delim"
)
{
(
void
)
builder
->
SetFieldDelim
(
ToString
(
value
)[
0
]);
}
else
if
(
key
==
"column_defaults"
)
{
py
::
list
py_object_list
=
py
::
reinterpret_borrow
<
py
::
list
>
(
value
);
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
column_default_list
;
for
(
auto
l
:
py_object_list
)
{
std
::
string
type_s
=
(
std
::
string
)
py
::
str
(
l
.
get_type
().
attr
(
"__name__"
));
if
(
type_s
==
"int"
)
{
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
int
>>
(
CsvOp
::
INT
,
ToInt
(
l
)));
}
else
if
(
type_s
==
"float"
)
{
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
float
>>
(
CsvOp
::
FLOAT
,
ToFloat
(
l
)));
}
else
if
(
type_s
==
"str"
)
{
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
std
::
string
>>
(
CsvOp
::
STRING
,
ToString
(
l
)));
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Record type is not allowed"
);
}
}
(
void
)
builder
->
SetColumDefault
(
column_default_list
);
}
else
if
(
key
==
"column_names"
)
{
col_names
=
ToStringVector
(
value
);
(
void
)
builder
->
SetColumName
(
col_names
);
}
}
}
std
::
shared_ptr
<
CsvOp
>
csv_op
;
RETURN_IF_NOT_OK
(
builder
->
Build
(
&
csv_op
));
RETURN_IF_NOT_OK
(
tree_
->
AssociateNode
(
csv_op
));
*
top
=
csv_op
;
if
(
shuffle_required
)
{
std
::
shared_ptr
<
DatasetOp
>
shuffle_op
=
nullptr
;
int64_t
shuffle_size
=
0
;
int64_t
num_rows
=
0
;
// First, get the number of rows in the dataset and then compute the shuffle size
RETURN_IF_NOT_OK
(
CsvOp
::
CountAllFileRows
(
files_list
,
col_names
.
empty
(),
&
num_rows
));
RETURN_IF_NOT_OK
(
ComputeShuffleSize
(
files_list
.
size
(),
num_devices
,
num_rows
,
0
,
&
shuffle_size
));
// Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
RETURN_IF_NOT_OK
(
AddShuffleOp
(
shuffle_size
,
csv_op
,
&
shuffle_op
));
*
top
=
shuffle_op
;
*
bottom
=
csv_op
;
}
return
Status
::
OK
();
}
// Helper function to inject a shuffle operator over top of the current operation being built.
Status
DEPipeline
::
AddShuffleOp
(
int64_t
shuffle_size
,
std
::
shared_ptr
<
DatasetOp
>
input_op
,
std
::
shared_ptr
<
DatasetOp
>
*
shuffle_op
)
{
...
...
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h
浏览文件 @
b091f74c
...
...
@@ -73,6 +73,7 @@ enum OpName {
kClue
,
kEpochCtrl
,
kSentencePieceVocab
,
kCsv
};
// The C++ binder class that we expose to the python script.
...
...
@@ -201,6 +202,8 @@ class DEPipeline {
Status
ParseClueOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
top
,
std
::
shared_ptr
<
DatasetOp
>
*
bottom
);
Status
ParseCsvOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
top
,
std
::
shared_ptr
<
DatasetOp
>
*
bottom
);
private:
// Execution tree that links the dataset operators.
std
::
shared_ptr
<
ExecutionTree
>
tree_
;
...
...
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
浏览文件 @
b091f74c
...
...
@@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
...
...
@@ -277,6 +278,17 @@ void bindDatasetOps(py::module *m) {
return
count
;
});
(
void
)
py
::
class_
<
CsvOp
,
DatasetOp
,
std
::
shared_ptr
<
CsvOp
>>
(
*
m
,
"CsvOp"
)
.
def_static
(
"get_num_rows"
,
[](
const
py
::
list
&
files
,
bool
csv_header
)
{
int64_t
count
=
0
;
std
::
vector
<
std
::
string
>
filenames
;
for
(
auto
file
:
files
)
{
file
.
is_none
()
?
(
void
)
filenames
.
emplace_back
(
""
)
:
filenames
.
push_back
(
py
::
str
(
file
));
}
THROW_IF_ERROR
(
CsvOp
::
CountAllFileRows
(
filenames
,
csv_header
,
&
count
));
return
count
;
});
(
void
)
py
::
class_
<
VOCOp
,
DatasetOp
,
std
::
shared_ptr
<
VOCOp
>>
(
*
m
,
"VOCOp"
)
.
def_static
(
"get_num_rows"
,
[](
const
std
::
string
&
dir
,
const
std
::
string
&
task_type
,
const
std
::
string
&
task_mode
,
...
...
@@ -1039,8 +1051,9 @@ PYBIND11_MODULE(_c_dataengine, m) {
.
value
(
"SENTENCEPIECEVOCAB"
,
OpName
::
kSentencePieceVocab
)
.
value
(
"CELEBA"
,
OpName
::
kCelebA
)
.
value
(
"TEXTFILE"
,
OpName
::
kTextFile
)
.
value
(
"CLUE"
,
OpName
::
kClue
)
.
value
(
"EPOCHCTRL"
,
OpName
::
kEpochCtrl
);
.
value
(
"EPOCHCTRL"
,
OpName
::
kEpochCtrl
)
.
value
(
"CSV"
,
OpName
::
kCsv
)
.
value
(
"CLUE"
,
OpName
::
kClue
);
(
void
)
py
::
enum_
<
JiebaMode
>
(
m
,
"JiebaMode"
,
py
::
arithmetic
())
.
value
(
"DE_JIEBA_MIX"
,
JiebaMode
::
kMix
)
...
...
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
浏览文件 @
b091f74c
...
...
@@ -12,6 +12,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
celeba_op.cc
text_file_op.cc
clue_op.cc
csv_op.cc
)
set
(
DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
...
...
@@ -29,4 +30,4 @@ if (ENABLE_PYTHON)
)
endif
()
add_library
(
engine-datasetops-source OBJECT
${
DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
}
)
\ No newline at end of file
add_library
(
engine-datasetops-source OBJECT
${
DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
}
)
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
0 → 100644
浏览文件 @
b091f74c
此差异已折叠。
点击以展开。
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.h
0 → 100644
浏览文件 @
b091f74c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <utility>
#include <limits>
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
namespace
mindspore
{
namespace
dataset
{
const
size_t
CSV_BUFFER_SIZE
=
4096
;
using
StringIndex
=
AutoIndexObj
<
std
::
string
>
;
class
JaggedConnector
;
class
CsvOp
:
public
ParallelOp
{
public:
enum
RecordType
:
uint8_t
{
INT
=
0
,
FLOAT
,
STRING
};
struct
BaseRecord
{
public:
BaseRecord
()
=
default
;
explicit
BaseRecord
(
RecordType
t
)
:
type
(
t
)
{}
virtual
~
BaseRecord
()
{}
RecordType
type
;
};
template
<
typename
T
>
class
Record
:
public
BaseRecord
{
public:
Record
()
=
default
;
Record
(
RecordType
t
,
T
v
)
:
BaseRecord
(
t
),
value
(
v
)
{}
~
Record
()
{}
T
value
;
};
// CsvParser is a class that parsing CSV file.
// We design a state machine to implement CSV syntactic analysis. It contains two state diagram,'sd' and 'sdl'.
// The 'sd' is used for parsing CSV syntactic, it's complete and complicate.
// The 'sdl' is used for counting the record rows, it's concise and it runs fast.
struct
CsvParser
{
public:
CsvParser
()
=
delete
;
CsvParser
(
int32_t
worker_id
,
std
::
shared_ptr
<
JaggedConnector
>
connector
,
int64_t
rows_per_buffer
,
char
field_delim
,
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
column_default
)
:
worker_id_
(
worker_id
),
buffer_connector_
(
connector
),
csv_rows_per_buffer_
(
rows_per_buffer
),
csv_field_delim_
(
field_delim
),
column_default_
(
column_default
),
cur_state_
(
START_OF_FILE
),
pos_
(
0
),
cur_row_
(
0
),
cur_col_
(
0
),
total_rows_
(
0
),
start_offset_
(
0
),
end_offset_
(
std
::
numeric_limits
<
int64_t
>::
max
())
{
cur_buffer_
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
BufferFlags
::
kDeBFlagNone
);
initCsvParser
();
}
~
CsvParser
()
=
default
;
void
Reset
()
{
cur_state_
=
START_OF_FILE
;
pos_
=
0
;
cur_row_
=
0
;
cur_col_
=
0
;
}
void
setStartOffset
(
int64_t
start_offset
)
{
start_offset_
=
start_offset
;
}
void
setEndOffset
(
int64_t
end_offset
)
{
end_offset_
=
end_offset
;
}
int
processMessage
(
char
c
)
{
Message
m
=
getMessage
(
c
);
StateDiagram
::
iterator
it
=
sd
.
find
({
cur_state_
,
m
});
if
(
it
==
sd
.
end
())
{
return
-
1
;
}
cur_state_
=
it
->
second
.
first
;
return
it
->
second
.
second
(
*
this
,
c
);
}
int
countRows
(
char
c
);
Status
initCsvParser
();
enum
State
:
uint8_t
{
START_OF_FILE
=
0
,
UNQUOTE
,
DELIM
,
QUOTE
,
SECOND_QUOTE
,
END_OF_LINE
,
END_OF_FILE
,
EXCEPTION
};
enum
Message
:
uint8_t
{
MS_NORMAL
=
0
,
MS_DELIM
,
MS_QUOTE
,
MS_END_OF_LINE
,
MS_END_OF_FILE
,
};
typedef
std
::
pair
<
State
,
Message
>
StateMessagePair
;
typedef
std
::
pair
<
State
,
std
::
function
<
int
(
CsvParser
&
,
char
)
>>
StateActionPair
;
typedef
std
::
map
<
StateMessagePair
,
StateActionPair
>
StateDiagram
;
Message
getMessage
(
char
c
)
{
if
(
c
==
csv_field_delim_
)
{
return
Message
::
MS_DELIM
;
}
else
if
(
c
==
'"'
)
{
return
Message
::
MS_QUOTE
;
}
else
if
(
c
==
'\r'
||
c
==
'\n'
)
{
return
Message
::
MS_END_OF_LINE
;
}
else
if
(
c
==
std
::
char_traits
<
char
>::
eof
())
{
return
Message
::
MS_END_OF_FILE
;
}
else
{
return
Message
::
MS_NORMAL
;
}
}
int
null_func
(
char
c
)
{
return
0
;
}
int
put_char
(
char
c
)
{
if
(
pos_
>=
str_buf_
.
size
())
{
str_buf_
.
resize
(
str_buf_
.
size
()
*
2
);
}
str_buf_
[
pos_
]
=
c
;
pos_
++
;
return
0
;
}
int
put_record
(
char
c
);
int
put_row
(
char
c
);
int
end_file
(
char
c
);
int
add_row
(
char
c
)
{
total_rows_
++
;
return
0
;
}
int
catch_exception
(
char
c
)
{
MS_LOG
(
ERROR
)
<<
"Invalid syntax!"
;
return
-
1
;
}
int32_t
worker_id_
;
std
::
shared_ptr
<
JaggedConnector
>
buffer_connector_
;
int64_t
csv_rows_per_buffer_
;
const
char
csv_field_delim_
;
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
column_default_
;
State
cur_state_
;
size_t
pos_
;
int
cur_row_
;
int
cur_col_
;
int64_t
total_rows_
;
int64_t
start_offset_
;
int64_t
end_offset_
;
StateDiagram
sd
;
StateDiagram
sdl
;
std
::
vector
<
char
>
str_buf_
;
std
::
unique_ptr
<
TensorQTable
>
tensor_table_
;
std
::
unique_ptr
<
DataBuffer
>
cur_buffer_
;
};
class
Builder
{
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder
();
// Default destructor
~
Builder
()
=
default
;
// Checks if the inputs of the builder is valid.
// @return Status - the error code returned.
Status
ValidateInputs
()
const
;
// Create the final object.
// @param op - dataset op.
// @return - the error code return.
Status
Build
(
std
::
shared_ptr
<
CsvOp
>
*
op
);
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetNumWorkers
(
int32_t
num_workers
)
{
builder_num_workers_
=
num_workers
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetOpConnectorSize
(
int32_t
op_connector_size
)
{
builder_op_connector_size_
=
op_connector_size
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetRowsPerBuffer
(
int64_t
rows_per_buffer
)
{
builder_rows_per_buffer_
=
rows_per_buffer
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetNumDevices
(
int64_t
num_dev
)
{
builder_num_devices_
=
num_dev
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetDeviceId
(
int64_t
dev_id
)
{
builder_device_id_
=
dev_id
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetCsvFilesList
(
const
std
::
vector
<
std
::
string
>
&
files_list
)
{
builder_csv_files_list_
=
files_list
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetShuffleFiles
(
bool
shuffle_files
)
{
builder_shuffle_files_
=
shuffle_files
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetNumSamples
(
int64_t
num_samples
)
{
builder_num_samples_
=
num_samples
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetFieldDelim
(
char
field_delim
)
{
builder_field_delim_
=
field_delim
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetColumDefault
(
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
record_list
)
{
builder_column_default_list_
=
record_list
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetColumName
(
std
::
vector
<
std
::
string
>
col_name_list
)
{
builder_column_name_list_
=
col_name_list
;
return
*
this
;
}
private:
int32_t
builder_device_id_
;
int32_t
builder_num_devices_
;
int32_t
builder_num_workers_
;
int32_t
builder_op_connector_size_
;
int64_t
builder_rows_per_buffer_
;
int64_t
builder_num_samples_
;
int32_t
builder_worker_connector_size_
;
std
::
vector
<
std
::
string
>
builder_csv_files_list_
;
bool
builder_shuffle_files_
;
char
builder_field_delim_
;
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
builder_column_default_list_
;
std
::
vector
<
std
::
string
>
builder_column_name_list_
;
};
// Constructor of CsvOp
CsvOp
()
=
delete
;
CsvOp
(
const
std
::
vector
<
std
::
string
>
&
csv_files_list
,
char
field_delim
,
const
std
::
vector
<
std
::
shared_ptr
<
BaseRecord
>>
&
column_default
,
const
std
::
vector
<
std
::
string
>
&
column_name
,
int32_t
num_workers
,
int64_t
rows_per_buffer
,
int64_t
num_samples
,
int32_t
worker_connector_size
,
int32_t
op_connector_size
,
bool
shuffle_files
,
int32_t
num_devices
,
int32_t
device_id
);
// Default destructor
~
CsvOp
()
=
default
;
// A print method typically used for debugging
// @param out - The output stream to write output to
// @param show_all - A bool to control if you want to show all info or just a summary
void
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
override
;
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status
Init
();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status
operator
()()
override
;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status
Reset
()
override
;
// Get total rows in files.
// @param files - all csv files.
// @param csv_header - a bool that indicates csv file include header line
// @param count - number of rows.
// @return Status - the error coed returned.
static
Status
CountAllFileRows
(
const
std
::
vector
<
std
::
string
>
&
files
,
bool
csv_header
,
int64_t
*
count
);
// File names getter
// @return Vector of the input file names
std
::
vector
<
std
::
string
>
FileNames
()
{
return
csv_files_list_
;
}
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status
WorkerEntry
(
int32_t
worker_id
)
override
;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table.
// @return Status - the error code returned.
Status
LoadTensor
(
const
std
::
string
&
line
,
std
::
unique_ptr
<
TensorQTable
>
*
tensor_table
,
int64_t
row
);
// Reads a csv file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status
LoadFile
(
const
std
::
string
&
file
,
const
int64_t
start_offset
,
const
int64_t
end_offset
,
const
int32_t
worker_id
);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status
PopIoBlockQueue
(
int32_t
index
,
std
::
unique_ptr
<
FilenameBlock
>
*
out_block
);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status
PushIoBlockQueue
(
int32_t
index
,
std
::
unique_ptr
<
FilenameBlock
>
&&
io_block
);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status
WaitToFillIOBlockQueue
();
// Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned.
Status
FillIOBlockQueue
(
const
std
::
vector
<
int64_t
>
&
i_keys
);
// Notifies the thread which called FillIoBlockQueue to resume execution
void
NotifyToFillIOBlockQueue
();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool
NeedPushFileToBlockQueue
(
const
std
::
string
&
file_name
,
int64_t
*
start_offset
,
int64_t
*
end_offset
,
const
int64_t
&
pre_count
);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status
PostEndOfEpoch
(
int32_t
queue_index
);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status
CalculateNumRowsPerShard
();
// Count number of rows in each file.
// @param filename - csv file name.
// @return int64_t - the total number of rows in file.
int64_t
CountTotalRows
(
const
std
::
string
&
file
);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status
PostEndOfData
();
// Private function for computing the assignment of the column name map.
// @return - Status
Status
ComputeColMap
()
override
;
// Split string based on a character delimiter
// @return - the a string vector
std
::
vector
<
std
::
string
>
split
(
const
std
::
string
&
s
,
char
delim
);
int32_t
device_id_
;
bool
shuffle_files_
;
bool
finished_reading_dataset_
;
int32_t
num_devices_
;
int64_t
rows_per_buffer_
;
bool
load_io_block_queue_
;
int64_t
num_rows_per_shard_
;
int64_t
all_num_rows_
;
int64_t
num_samples_
;
std
::
map
<
std
::
string
,
int64_t
>
filename_numrows_
;
std
::
unique_ptr
<
StringIndex
>
filename_index_
;
std
::
vector
<
std
::
string
>
csv_files_list_
;
WaitPost
io_block_queue_wait_post_
;
std
::
shared_ptr
<
JaggedConnector
>
jagged_buffer_connector_
;
QueueList
<
std
::
unique_ptr
<
FilenameBlock
>>
io_block_queues_
;
bool
load_jagged_connector_
;
char
field_delim_
;
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
column_default_list_
;
std
::
vector
<
std
::
string
>
column_name_list_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CSV_OP_H_
mindspore/dataset/__init__.py
浏览文件 @
b091f74c
...
...
@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
from
.core
import
config
from
.engine.datasets
import
TFRecordDataset
,
ImageFolderDatasetV2
,
MnistDataset
,
MindDataset
,
NumpySlicesDataset
,
\
GeneratorDataset
,
ManifestDataset
,
Cifar10Dataset
,
Cifar100Dataset
,
VOCDataset
,
CocoDataset
,
CelebADataset
,
\
TextFileDataset
,
CLUEDataset
,
Schema
,
Shuffle
,
zip
,
RandomDataset
TextFileDataset
,
CLUEDataset
,
CSVDataset
,
Schema
,
Shuffle
,
zip
,
RandomDataset
from
.engine.samplers
import
DistributedSampler
,
PKSampler
,
RandomSampler
,
SequentialSampler
,
SubsetRandomSampler
,
\
WeightedRandomSampler
,
Sampler
from
.engine.cache_client
import
DatasetCache
...
...
@@ -31,5 +31,5 @@ from .engine.graphdata import GraphData
__all__
=
[
"config"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"NumpySlicesDataset"
,
"VOCDataset"
,
"CocoDataset"
,
"TextFileDataset"
,
"CLUEDataset"
,
"Schema"
,
"DistributedSampler"
,
"PKSampler"
,
"CocoDataset"
,
"TextFileDataset"
,
"CLUEDataset"
,
"
CSVDataset"
,
"
Schema"
,
"DistributedSampler"
,
"PKSampler"
,
"RandomSampler"
,
"SequentialSampler"
,
"SubsetRandomSampler"
,
"WeightedRandomSampler"
,
"zip"
,
"GraphData"
]
mindspore/dataset/engine/__init__.py
浏览文件 @
b091f74c
...
...
@@ -29,7 +29,7 @@ from .samplers import *
from
..core
import
config
__all__
=
[
"config"
,
"zip"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"CLUEDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"CLUEDataset"
,
"CSVDataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"VOCDataset"
,
"CocoDataset"
,
"TextFileDataset"
,
"Schema"
,
"DistributedSampler"
,
"PKSampler"
,
"RandomSampler"
,
"SequentialSampler"
,
"SubsetRandomSampler"
,
"WeightedRandomSampler"
]
mindspore/dataset/engine/datasets.py
浏览文件 @
b091f74c
...
...
@@ -33,7 +33,7 @@ import copy
import
numpy
as
np
from
mindspore._c_dataengine
import
DataType
,
TFReaderOp
,
ImageFolderOp
,
CifarOp
,
MnistOp
,
ManifestOp
,
\
MindRecordOp
,
TextFileOp
,
ClueOp
,
VOCOp
,
CocoOp
,
CBatchInfo
MindRecordOp
,
TextFileOp
,
ClueOp
,
CsvOp
,
VOCOp
,
CocoOp
,
CBatchInfo
from
mindspore._c_expression
import
typing
from
mindspore
import
log
as
logger
...
...
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take
,
check_project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_tfrecorddataset
,
check_vocdataset
,
check_cocodataset
,
check_celebadataset
,
check_minddataset
,
\
check_generatordataset
,
check_sync_wait
,
check_zip_dataset
,
check_add_column
,
check_textfiledataset
,
check_concat
,
\
check_random_dataset
,
check_split
,
check_bucket_batch_by_length
,
check_cluedataset
,
check_save
check_random_dataset
,
check_split
,
check_bucket_batch_by_length
,
check_cluedataset
,
check_save
,
check_csvdataset
from
..core.datatypes
import
mstype_to_detype
,
mstypelist_to_detypelist
from
..text.utils
import
DE_C_INTER_SENTENCEPIECE_MODE
...
...
@@ -1012,7 +1012,7 @@ class Dataset:
if
isinstance
(
sampler
,
samplers
.
DistributedSampler
):
dev_id
=
sampler
.
shard_id
return
""
,
dev_id
if
isinstance
(
output_dataset
,
(
TFRecordDataset
,
TextFileDataset
,
CLUEDataset
)):
if
isinstance
(
output_dataset
,
(
TFRecordDataset
,
TextFileDataset
,
CLUEDataset
,
CSVDataset
)):
if
output_dataset
.
shard_id
is
not
None
:
dev_id
=
output_dataset
.
shard_id
return
""
,
dev_id
...
...
@@ -4652,8 +4652,8 @@ class CLUEDataset(SourceDataset):
}
Args:
dataset_files (str or
list[str]): String or list of files to be read or glob strings to search for a pattern of
files. The list will be sorted in a lexicographical order.
dataset_files (str or
a list of strings): String or list of files to be read or glob strings to search for
a pattern of
files. The list will be sorted in a lexicographical order.
task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'.
(default=AFQMC).
usage (str, optional): Need train, test or eval data (default="train").
...
...
@@ -4860,6 +4860,108 @@ class CLUEDataset(SourceDataset):
return
False
class
CSVDataset
(
SourceDataset
):
"""
A source dataset that reads and parses CSV datasets.
Args:
dataset_files (str or a list of strings): String or list of files to be read or glob strings to search
for a pattern of files. The list will be sorted in a lexicographical order.
field_delim (str, optional): A string that indicates the char delimiter to separate fields (default=',').
column_defaults (list, optional): List of default values for the CSV field (default=None). Each item
in the list is either a valid type (float, int, or string). If this is not provided, treats all
columns as string type.
column_names (list of string, optional): List of column names of the dataset (default=None). If this
is not provided, infers the column_names from the first row of CSV file.
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Examples:
>>> import mindspore.dataset as ds
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
>>> dataset = ds.CSVDataset(dataset_files=dataset_files, column_names=['col1', 'col2', 'col3', 'col4'])
"""
@
check_csvdataset
def
__init__
(
self
,
dataset_files
,
field_delim
=
','
,
column_defaults
=
None
,
column_names
=
None
,
num_samples
=
None
,
num_parallel_workers
=
None
,
shuffle
=
Shuffle
.
GLOBAL
,
num_shards
=
None
,
shard_id
=
None
):
super
().
__init__
(
num_parallel_workers
)
self
.
dataset_files
=
self
.
_find_files
(
dataset_files
)
self
.
dataset_files
.
sort
()
self
.
field_delim
=
field_delim
self
.
column_defaults
=
column_defaults
self
.
column_names
=
column_names
self
.
num_samples
=
num_samples
if
not
isinstance
(
shuffle
,
(
bool
,
Shuffle
)):
raise
TypeError
(
"shuffle should be of boolean or enum 'Shuffle'."
)
if
not
isinstance
(
shuffle
,
Shuffle
):
if
shuffle
:
self
.
shuffle_level
=
Shuffle
.
GLOBAL
self
.
shuffle_files
=
True
else
:
self
.
shuffle_level
=
None
self
.
shuffle_files
=
False
else
:
self
.
shuffle_level
=
shuffle
self
.
shuffle_files
=
True
self
.
num_shards
=
num_shards
self
.
shard_id
=
shard_id
def
get_args
(
self
):
args
=
super
().
get_args
()
args
[
"dataset_files"
]
=
self
.
dataset_files
args
[
'field_delim'
]
=
self
.
field_delim
args
[
'column_defaults'
]
=
self
.
column_defaults
args
[
'column_names'
]
=
self
.
column_names
args
[
"num_samples"
]
=
self
.
num_samples
if
self
.
shuffle_files
is
not
None
:
args
[
"shuffle_files"
]
=
self
.
shuffle_files
args
[
"shuffle_global"
]
=
(
self
.
shuffle_level
==
Shuffle
.
GLOBAL
)
args
[
"shuffle"
]
=
self
.
shuffle_level
args
[
"num_shards"
]
=
self
.
num_shards
args
[
"shard_id"
]
=
self
.
shard_id
return
args
def
get_dataset_size
(
self
):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if
self
.
_dataset_size
is
None
:
num_rows
=
CsvOp
.
get_num_rows
(
self
.
dataset_files
,
self
.
column_names
is
None
)
num_rows
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
self
.
num_samples
is
None
:
return
num_rows
return
min
(
self
.
num_samples
,
num_rows
)
return
self
.
_dataset_size
def
is_shuffled
(
self
):
return
self
.
shuffle_files
def
is_sharded
(
self
):
if
self
.
num_shards
is
not
None
:
return
self
.
num_shards
>
1
return
False
class
TextFileDataset
(
SourceDataset
):
"""
A source dataset that reads and parses datasets stored on disk in text format.
...
...
mindspore/dataset/engine/iterators.py
浏览文件 @
b091f74c
...
...
@@ -185,6 +185,8 @@ class Iterator:
op_type
=
OpName
.
SENTENCEPIECEVOCAB
elif
isinstance
(
dataset
,
de
.
CLUEDataset
):
op_type
=
OpName
.
CLUE
elif
isinstance
(
dataset
,
de
.
CSVDataset
):
op_type
=
OpName
.
CSV
else
:
raise
ValueError
(
"Unsupported DatasetOp"
)
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
b091f74c
...
...
@@ -787,6 +787,49 @@ def check_cluedataset(method):
return
new_method
def
check_csvdataset
(
method
):
"""A wrapper that wrap a parameter checker to the original Dataset(CSVDataset)."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
_
,
param_dict
=
parse_user_args
(
method
,
*
args
,
**
kwargs
)
nreq_param_int
=
[
'num_samples'
,
'num_parallel_workers'
,
'num_shards'
,
'shard_id'
]
# check dataset_files; required argument
dataset_files
=
param_dict
.
get
(
'dataset_files'
)
type_check
(
dataset_files
,
(
str
,
list
),
"dataset files"
)
# check field_delim
field_delim
=
param_dict
.
get
(
'field_delim'
)
type_check
(
field_delim
,
(
str
,),
'field delim'
)
if
field_delim
in
[
'"'
,
'
\r
'
,
'
\n
'
]
or
len
(
field_delim
)
>
1
:
raise
ValueError
(
"field_delim is not legal."
)
# check column_defaults
column_defaults
=
param_dict
.
get
(
'column_defaults'
)
if
column_defaults
is
not
None
:
if
not
isinstance
(
column_defaults
,
list
):
raise
TypeError
(
"column_defaults should be type of list."
)
for
item
in
column_defaults
:
if
not
isinstance
(
item
,
(
str
,
int
,
float
)):
raise
TypeError
(
"column type is not legal."
)
# check column_names: must be list of string.
column_names
=
param_dict
.
get
(
"column_names"
)
if
column_names
is
not
None
:
all_string
=
all
(
isinstance
(
item
,
str
)
for
item
in
column_names
)
if
not
all_string
:
raise
TypeError
(
"column_names should be a list of str."
)
validate_dataset_param_value
(
nreq_param_int
,
param_dict
,
int
)
check_sampler_shuffle_shard_options
(
param_dict
)
return
method
(
self
,
*
args
,
**
kwargs
)
return
new_method
def
check_textfiledataset
(
method
):
"""A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
b091f74c
...
...
@@ -77,6 +77,7 @@ SET(DE_UT_SRCS
celeba_op_test.cc
take_op_test.cc
clue_op_test.cc
csv_op_test.cc
text_file_op_test.cc
filter_op_test.cc
concat_op_test.cc
...
...
tests/ut/cpp/dataset/csv_op_test.cc
0 → 100644
浏览文件 @
b091f74c
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/engine/datasetops/source/csv_op.h"
#include "minddata/dataset/util/status.h"
namespace
common
=
mindspore
::
common
;
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
class
MindDataTestCSVOp
:
public
UT
::
DatasetOpTesting
{
};
TEST_F
(
MindDataTestCSVOp
,
TestCSVBasic
)
{
// Start with an empty execution tree
auto
tree
=
std
::
make_shared
<
ExecutionTree
>
();
std
::
string
dataset_path
;
dataset_path
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
vector
<
std
::
shared_ptr
<
CsvOp
::
BaseRecord
>>
column_default_list
;
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
int
>>
(
CsvOp
::
INT
,
0
));
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
int
>>
(
CsvOp
::
INT
,
0
));
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
int
>>
(
CsvOp
::
INT
,
0
));
column_default_list
.
push_back
(
std
::
make_shared
<
CsvOp
::
Record
<
int
>>
(
CsvOp
::
INT
,
0
));
std
::
shared_ptr
<
CsvOp
>
op
;
CsvOp
::
Builder
builder
;
builder
.
SetCsvFilesList
({
dataset_path
})
.
SetRowsPerBuffer
(
16
)
.
SetNumWorkers
(
16
)
.
SetShuffleFiles
(
false
)
.
SetOpConnectorSize
(
2
)
.
SetFieldDelim
(
','
)
.
SetColumDefault
(
column_default_list
)
.
SetColumName
({
"col1"
,
"col2"
,
"col3"
,
"col4"
});
Status
rc
=
builder
.
Build
(
&
op
);
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
tree
->
AssociateNode
(
op
);
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
tree
->
AssignRoot
(
op
);
ASSERT_TRUE
(
rc
.
IsOk
());
MS_LOG
(
INFO
)
<<
"Launching tree and begin iteration."
;
rc
=
tree
->
Prepare
();
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
tree
->
Launch
();
ASSERT_TRUE
(
rc
.
IsOk
());
// Start the loop of reading tensors from our pipeline
DatasetIterator
di
(
tree
);
TensorRow
tensor_list
;
rc
=
di
.
FetchNextTensorRow
(
&
tensor_list
);
ASSERT_TRUE
(
rc
.
IsOk
());
int
row_count
=
0
;
while
(
!
tensor_list
.
empty
())
{
// Display the tensor by calling the printer on it
for
(
int
i
=
0
;
i
<
tensor_list
.
size
();
i
++
)
{
std
::
ostringstream
ss
;
ss
<<
"("
<<
tensor_list
[
i
]
<<
"): "
<<
*
tensor_list
[
i
]
<<
std
::
endl
;
MS_LOG
(
INFO
)
<<
"Tensor print: "
<<
ss
.
str
()
<<
"."
;
}
rc
=
di
.
FetchNextTensorRow
(
&
tensor_list
);
ASSERT_TRUE
(
rc
.
IsOk
());
row_count
++
;
}
ASSERT_EQ
(
row_count
,
3
);
}
TEST_F
(
MindDataTestCSVOp
,
TestTotalRows
)
{
std
::
string
csv_file1
=
datasets_root_path_
+
"/testCSV/1.csv"
;
std
::
string
csv_file2
=
datasets_root_path_
+
"/testCSV/size.csv"
;
std
::
vector
<
std
::
string
>
files
;
files
.
push_back
(
csv_file1
);
int64_t
total_rows
=
0
;
CsvOp
::
CountAllFileRows
(
files
,
false
,
&
total_rows
);
ASSERT_EQ
(
total_rows
,
3
);
files
.
clear
();
files
.
push_back
(
csv_file2
);
CsvOp
::
CountAllFileRows
(
files
,
false
,
&
total_rows
);
ASSERT_EQ
(
total_rows
,
5
);
files
.
clear
();
files
.
push_back
(
csv_file1
);
files
.
push_back
(
csv_file2
);
CsvOp
::
CountAllFileRows
(
files
,
false
,
&
total_rows
);
ASSERT_EQ
(
total_rows
,
8
);
files
.
clear
();
}
tests/ut/data/dataset/testCSV/1.csv
0 → 100644
浏览文件 @
b091f74c
1,2,3,4
5,6,7,8
9,10,11,12
tests/ut/data/dataset/testCSV/2.csv
0 → 100644
浏览文件 @
b091f74c
,"222",3,"4"""
"5",6,,"8"
9,10,"1""1",12
,,"",
,,,
a,b,c,""
a,b,c,d
tests/ut/data/dataset/testCSV/chinese.csv
0 → 100644
浏览文件 @
b091f74c
大家,早上好,中午好,下午好,晚上好
tests/ut/data/dataset/testCSV/embedded.csv
0 → 100644
浏览文件 @
b091f74c
"a,b","c""d","e
f"," g "
tests/ut/data/dataset/testCSV/exception.csv
0 → 100644
浏览文件 @
b091f74c
1,2,3,4
5,6,7,8
a,"c",d,"e
tests/ut/data/dataset/testCSV/header.csv
0 → 100644
浏览文件 @
b091f74c
col1,col2,col3,col4
a,b,c,d
\ No newline at end of file
tests/ut/data/dataset/testCSV/number.csv
0 → 100644
浏览文件 @
b091f74c
3,0.3,4,55.5
tests/ut/data/dataset/testCSV/quoted.csv
0 → 100644
浏览文件 @
b091f74c
"a","b","c","d"
tests/ut/data/dataset/testCSV/separated.csv
0 → 100644
浏览文件 @
b091f74c
a|b|c|d
tests/ut/data/dataset/testCSV/size.csv
0 → 100644
浏览文件 @
b091f74c
1,2,3,4
"a","b","c
","d
e"
5,6,7,8
9,10,11,12
a,"b
",c,"d
e"
tests/ut/python/dataset/test_datasets_csv.py
0 → 100644
浏览文件 @
b091f74c
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
mindspore.dataset
as
ds
import
numpy
as
np
import
pytest
DATA_FILE
=
'../data/dataset/testCSV/1.csv'
def
test_csv_dataset_basic
():
"""
Test CSV with repeat, skip and so on
"""
TRAIN_FILE
=
'../data/dataset/testCSV/1.csv'
buffer
=
[]
data
=
ds
.
CSVDataset
(
TRAIN_FILE
,
column_defaults
=
[
"0"
,
0
,
0.0
,
"0"
],
column_names
=
[
'1'
,
'2'
,
'3'
,
'4'
],
shuffle
=
False
)
data
=
data
.
repeat
(
2
)
data
=
data
.
skip
(
2
)
for
d
in
data
.
create_dict_iterator
():
buffer
.
append
(
d
)
assert
len
(
buffer
)
==
4
def
test_csv_dataset_one_file
():
data
=
ds
.
CSVDataset
(
DATA_FILE
,
column_defaults
=
[
"1"
,
"2"
,
"3"
,
"4"
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
append
(
d
)
assert
len
(
buffer
)
==
3
def
test_csv_dataset_all_file
():
APPEND_FILE
=
'../data/dataset/testCSV/2.csv'
data
=
ds
.
CSVDataset
(
[
DATA_FILE
,
APPEND_FILE
],
column_defaults
=
[
"1"
,
"2"
,
"3"
,
"4"
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
append
(
d
)
assert
len
(
buffer
)
==
10
def
test_csv_dataset_num_samples
():
data
=
ds
.
CSVDataset
(
DATA_FILE
,
column_defaults
=
[
"1"
,
"2"
,
"3"
,
"4"
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
,
num_samples
=
2
)
count
=
0
for
_
in
data
.
create_dict_iterator
():
count
+=
1
assert
count
==
2
def
test_csv_dataset_distribution
():
TEST_FILE
=
'../data/dataset/testCSV/1.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
"1"
,
"2"
,
"3"
,
"4"
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
,
num_shards
=
2
,
shard_id
=
0
)
count
=
0
for
_
in
data
.
create_dict_iterator
():
count
+=
1
assert
count
==
2
def
test_csv_dataset_quoted
():
TEST_FILE
=
'../data/dataset/testCSV/quoted.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
""
,
""
,
""
,
""
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
extend
([
d
[
'col1'
].
item
().
decode
(
"utf8"
),
d
[
'col2'
].
item
().
decode
(
"utf8"
),
d
[
'col3'
].
item
().
decode
(
"utf8"
),
d
[
'col4'
].
item
().
decode
(
"utf8"
)])
assert
buffer
==
[
'a'
,
'b'
,
'c'
,
'd'
]
def
test_csv_dataset_separated
():
TEST_FILE
=
'../data/dataset/testCSV/separated.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
field_delim
=
'|'
,
column_defaults
=
[
""
,
""
,
""
,
""
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
extend
([
d
[
'col1'
].
item
().
decode
(
"utf8"
),
d
[
'col2'
].
item
().
decode
(
"utf8"
),
d
[
'col3'
].
item
().
decode
(
"utf8"
),
d
[
'col4'
].
item
().
decode
(
"utf8"
)])
assert
buffer
==
[
'a'
,
'b'
,
'c'
,
'd'
]
def
test_csv_dataset_embedded
():
TEST_FILE
=
'../data/dataset/testCSV/embedded.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
""
,
""
,
""
,
""
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
extend
([
d
[
'col1'
].
item
().
decode
(
"utf8"
),
d
[
'col2'
].
item
().
decode
(
"utf8"
),
d
[
'col3'
].
item
().
decode
(
"utf8"
),
d
[
'col4'
].
item
().
decode
(
"utf8"
)])
assert
buffer
==
[
'a,b'
,
'c"d'
,
'e
\n
f'
,
' g '
]
def
test_csv_dataset_chinese
():
TEST_FILE
=
'../data/dataset/testCSV/chinese.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
""
,
""
,
""
,
""
,
""
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
,
'col5'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
extend
([
d
[
'col1'
].
item
().
decode
(
"utf8"
),
d
[
'col2'
].
item
().
decode
(
"utf8"
),
d
[
'col3'
].
item
().
decode
(
"utf8"
),
d
[
'col4'
].
item
().
decode
(
"utf8"
),
d
[
'col5'
].
item
().
decode
(
"utf8"
)])
assert
buffer
==
[
'大家'
,
'早上好'
,
'中午好'
,
'下午好'
,
'晚上好'
]
def
test_csv_dataset_header
():
TEST_FILE
=
'../data/dataset/testCSV/header.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
""
,
""
,
""
,
""
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
extend
([
d
[
'col1'
].
item
().
decode
(
"utf8"
),
d
[
'col2'
].
item
().
decode
(
"utf8"
),
d
[
'col3'
].
item
().
decode
(
"utf8"
),
d
[
'col4'
].
item
().
decode
(
"utf8"
)])
assert
buffer
==
[
'a'
,
'b'
,
'c'
,
'd'
]
def
test_csv_dataset_number
():
TEST_FILE
=
'../data/dataset/testCSV/number.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
0.0
,
0.0
,
0
,
0.0
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
buffer
=
[]
for
d
in
data
.
create_dict_iterator
():
buffer
.
extend
([
d
[
'col1'
].
item
(),
d
[
'col2'
].
item
(),
d
[
'col3'
].
item
(),
d
[
'col4'
].
item
()])
assert
np
.
allclose
(
buffer
,
[
3.0
,
0.3
,
4
,
55.5
])
def
test_csv_dataset_size
():
TEST_FILE
=
'../data/dataset/testCSV/size.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
0.0
,
0.0
,
0
,
0.0
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
assert
data
.
get_dataset_size
()
==
5
def
test_csv_dataset_exception
():
TEST_FILE
=
'../data/dataset/testCSV/exception.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
""
,
""
,
""
,
""
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
with
pytest
.
raises
(
Exception
)
as
err
:
for
_
in
data
.
create_dict_iterator
():
pass
assert
"Failed to parse CSV file"
in
str
(
err
.
value
)
def
test_csv_dataset_type_error
():
TEST_FILE
=
'../data/dataset/testCSV/exception.csv'
data
=
ds
.
CSVDataset
(
TEST_FILE
,
column_defaults
=
[
""
,
0
,
""
,
""
],
column_names
=
[
'col1'
,
'col2'
,
'col3'
,
'col4'
],
shuffle
=
False
)
with
pytest
.
raises
(
Exception
)
as
err
:
for
_
in
data
.
create_dict_iterator
():
pass
assert
"invalid argument of stoi"
in
str
(
err
.
value
)
if
__name__
==
"__main__"
:
test_csv_dataset_basic
()
test_csv_dataset_one_file
()
test_csv_dataset_all_file
()
test_csv_dataset_num_samples
()
test_csv_dataset_distribution
()
test_csv_dataset_quoted
()
test_csv_dataset_separated
()
test_csv_dataset_embedded
()
test_csv_dataset_chinese
()
test_csv_dataset_header
()
test_csv_dataset_number
()
test_csv_dataset_size
()
test_csv_dataset_exception
()
test_csv_dataset_type_error
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录