magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
Commit 46e9d79a
Authored 4 years ago by Cathy Wong
C++ API: Lexicographical order support for CLUE, CSV & TextFile Datasets
Parent: b8da525f
Branch: master
No related merge requests
Showing 4 changed files with 284 additions and 25 deletions (+284 -25)
mindspore/ccsrc/minddata/dataset/api/datasets.cc         +24 -10
tests/ut/cpp/dataset/c_api_dataset_clue_test.cc          +73 -6
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc           +63 -3
tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc      +124 -6
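All three Build() methods change in the same way, as the diffs below show: copy the caller's file list, sort the copy lexicographically, and pass the sorted copy to the underlying op, including its row-count and shuffle bookkeeping. A minimal sketch of that pattern, lifted out of the class context (the free function is hypothetical; the variable names follow the diff):

    #include <algorithm>
    #include <string>
    #include <vector>

    // Pattern introduced by this commit: the member list is never reordered;
    // a sorted copy feeds the dataset op, so pipeline output no longer depends
    // on the order in which the caller listed the files.
    std::vector<std::string> SortedFileList(const std::vector<std::string> &dataset_files) {
      std::vector<std::string> sorted_dataset_files = dataset_files;
      std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
      return sorted_dataset_files;  // e.g. {"2.txt", "1.txt"} -> {"1.txt", "2.txt"}
    }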
mindspore/ccsrc/minddata/dataset/api/datasets.cc

@@ -1009,9 +1009,14 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
   }
   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

+  // Sort the dataset files in a lexicographical order
+  std::vector<std::string> sorted_dataset_files = dataset_files_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
+
   std::shared_ptr<ClueOp> clue_op =
     std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
-                             dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
   RETURN_EMPTY_IF_ERROR(clue_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
     // Inject ShuffleOp

@@ -1019,10 +1024,10 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
     int64_t num_rows = 0;

     // First, get the number of rows in the dataset
-    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
+    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));

     // Add the shuffle op after this op
-    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
+    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }

@@ -1162,6 +1167,11 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
   std::vector<std::shared_ptr<DatasetOp>> node_ops;

   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

+  // Sort the dataset files in a lexicographical order
+  std::vector<std::string> sorted_dataset_files = dataset_files_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
+
   std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
   for (auto v : column_defaults_) {
     if (v->type == CsvType::INT) {

@@ -1177,8 +1187,8 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
   }

   std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
-    dataset_files_, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, num_samples_,
-    worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+    sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
+    num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
   RETURN_EMPTY_IF_ERROR(csv_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
     // Inject ShuffleOp

@@ -1186,10 +1196,10 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
     int64_t num_rows = 0;

     // First, get the number of rows in the dataset
-    RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
+    RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));

     // Add the shuffle op after this op
-    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
+    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }

@@ -1398,6 +1408,10 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
   bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

+  // Sort the dataset files in a lexicographical order
+  std::vector<std::string> sorted_dataset_files = dataset_files_;
+  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
+
   // Do internal Schema generation.
   auto schema = std::make_unique<DataSchema>();
   RETURN_EMPTY_IF_ERROR(

@@ -1405,7 +1419,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
   // Create and initalize TextFileOp
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
-    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
+    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
     connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
   RETURN_EMPTY_IF_ERROR(text_file_op->Init());

@@ -1415,10 +1429,10 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
     int64_t num_rows = 0;

     // First, get the number of rows in the dataset
-    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
+    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));

     // Add the shuffle op after this op
-    RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
+    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                        rows_per_buffer_, &shuffle_op));
     node_ops.push_back(shuffle_op);
   }
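One detail worth keeping in mind when reading the hunks above: std::sort on std::vector<std::string> uses byte-wise string comparison, so "lexicographical" here means ASCII order, not natural numeric order. A self-contained check (the file names are made up for illustration):

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    int main() {
      // Hypothetical file names; comparison is byte by byte.
      std::vector<std::string> files = {"append.csv", "2.txt", "10.txt", "1.csv"};
      std::sort(files.begin(), files.end());
      // '.' < '0' makes "1.csv" < "10.txt", and '1' < '2' makes "10.txt" < "2.txt";
      // digits sort before letters, so "append.csv" comes last.
      assert((files == std::vector<std::string>{"1.csv", "10.txt", "2.txt", "append.csv"}));
      return 0;
    }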
tests/ut/cpp/dataset/c_api_dataset_clue_test.cc

@@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetIFLYTEK) {
   iter->Stop();
 }

-TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFiles.";
+TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesA) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFilesA.";
   // Test CLUE Dataset with files shuffle, num_parallel_workers=1

   // Set configuration

@@ -373,7 +373,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
   GlobalContext::config_manager()->set_seed(135);
   GlobalContext::config_manager()->set_num_parallel_workers(1);

-  // Create a CLUE Dataset, with two text files
+  // Create a CLUE Dataset, with two text files, dev.json and train.json, in lexicographical order
   // Note: train.json has 3 rows
   // Note: dev.json has 3 rows
   // Use default of all samples

@@ -383,7 +383,7 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
   std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
   std::string task = "AFQMC";
   std::string usage = "train";
-  std::shared_ptr<Dataset> ds = CLUE({clue_file1, clue_file2}, task, usage, 0, ShuffleMode::kFiles);
+  std::shared_ptr<Dataset> ds = CLUE({clue_file2, clue_file1}, task, usage, 0, ShuffleMode::kFiles);
   EXPECT_NE(ds, nullptr);

   // Create an iterator over the result of the above dataset.

@@ -397,12 +397,79 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
   EXPECT_NE(row.find("sentence1"), row.end());
+  std::vector<std::string> expected_result = {"你有花呗吗",
+                                              "吃饭能用花呗吗",
+                                              "蚂蚁花呗支付金额有什么限制",
+                                              "蚂蚁借呗等额还款能否换成先息后本",
+                                              "蚂蚁花呗说我违约了",
+                                              "帮我看看本月花呗账单结清了没"};

   uint64_t i = 0;
   while (row.size() != 0) {
+    auto text = row["sentence1"];
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
     i++;
     iter->GetNextRow(&row);
   }

+  // Expect 3 + 3 = 6 samples
   EXPECT_EQ(i, 6);

   // Manually terminate the pipeline
   iter->Stop();

   // Restore configuration
   GlobalContext::config_manager()->set_seed(original_seed);
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }

+TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesB) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFilesB.";
+  // Test CLUE Dataset with files shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(135);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a CLUE Dataset, with two text files, train.json and dev.json, in non-lexicographical order
+  // Note: train.json has 3 rows
+  // Note: dev.json has 3 rows
+  // Use default of all samples
+  // They have the same keywords
+  // Set shuffle to files shuffle
+  std::string clue_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
+  std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
+  std::string task = "AFQMC";
+  std::string usage = "train";
+  std::shared_ptr<Dataset> ds = CLUE({clue_file1, clue_file2}, task, usage, 0, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("sentence1"), row.end());
+  std::vector<std::string> expected_result = {"你有花呗吗",
+                                              "吃饭能用花呗吗",
+                                              "蚂蚁花呗支付金额有什么限制",
+                                              "蚂蚁借呗等额还款能否换成先息后本",
+                                              "蚂蚁花呗说我违约了",
+                                              "帮我看看本月花呗账单结清了没"};
+
+  uint64_t i = 0;
  ...
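The A/B pairing above is the point of the commit: because Build() now sorts the file list before the seeded kFiles shuffle, {clue_file2, clue_file1} and {clue_file1, clue_file2} feed the identical sequence into the shuffle, so both tests can assert the same expected rows under seed 135. A minimal sketch of that invariant, using std::mt19937 and std::shuffle as stand-ins for whatever RNG the MindSpore pipeline actually uses (an assumption, not the real shuffle op):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <random>
    #include <string>
    #include <vector>

    // Sort first, then shuffle with a fixed seed: the result is independent
    // of the order in which the caller listed the files.
    std::vector<std::string> OrderFiles(std::vector<std::string> files, uint32_t seed) {
      std::sort(files.begin(), files.end());
      std::mt19937 rng(seed);
      std::shuffle(files.begin(), files.end(), rng);
      return files;
    }

    int main() {
      assert(OrderFiles({"dev.json", "train.json"}, 135) ==
             OrderFiles({"train.json", "dev.json"}, 135));
      return 0;
    }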
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc

@@ -359,8 +359,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetException) {
   EXPECT_EQ(ds5, nullptr);
 }

-TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFiles.";
+TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFilesA.";

   // Set configuration
   uint32_t original_seed = GlobalContext::config_manager()->seed();

@@ -369,7 +369,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
   GlobalContext::config_manager()->set_seed(130);
   GlobalContext::config_manager()->set_num_parallel_workers(4);

-  // Create a CSVDataset, with single CSV file
+  // Create a CSVDataset, with 2 CSV files, 1.csv and append.csv in lexicographical order
   std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
   std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
   std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};

@@ -418,6 +418,66 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }

+TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFilesB.";
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(130);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+
+  // Create a CSVDataset, with 2 CSV files, append.csv and 1.csv in non-lexicographical order
+  std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
+  std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
+  std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
+  std::shared_ptr<Dataset> ds = CSV({file2, file1}, ',', {}, column_names, -1, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("col1"), row.end());
+  std::vector<std::vector<std::string>> expected_result = {
+    {"13", "14", "15", "16"},
+    {"1", "2", "3", "4"},
+    {"17", "18", "19", "20"},
+    {"5", "6", "7", "8"},
+    {"21", "22", "23", "24"},
+    {"9", "10", "11", "12"},
+  };
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    for (int j = 0; j < column_names.size(); j++) {
+      auto text = row[column_names[j]];
+      std::string_view sv;
+      text->GetItemAt(&sv, {0});
+      std::string ss(sv);
+      MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+      EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
+    }
+    iter->GetNextRow(&row);
+    i++;
+  }
+
+  // Expect 6 samples
+  EXPECT_EQ(i, 6);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
 TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleGlobal.";
   // Test CSV Dataset with GLOBLE shuffle
  ...
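Two things are easy to miss in TestCSVDatasetShuffleFilesB: the expected rows alternate between append.csv content (values 13 to 24) and 1.csv content (values 1 to 12), which follows from the fixed seed 130 and the four parallel workers, and the file-list argument is deliberately passed in non-lexicographical order. The call shape the test exercises, as a hedged sketch (the include path and namespace are assumptions about the tree layout at this commit; in the tests these come from the fixture):

    // Sketch only; not a full program.
    #include <memory>
    #include <string>
    #include <vector>
    #include "minddata/dataset/include/datasets.h"  // assumed header location

    using namespace mindspore::dataset::api;  // assumed namespace

    std::shared_ptr<Dataset> BuildCsvPipeline(const std::string &root) {
      std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
      // File order no longer matters: Build() sorts the list before CsvOp sees it.
      return CSV({root + "/testCSV/append.csv", root + "/testCSV/1.csv"},  // non-lexicographical
                 ',', {}, column_names, -1, ShuffleMode::kFiles);
    }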
tests/ut/cpp/dataset/c_api_dataset_textfile_test.cc

@@ -165,8 +165,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail7) {
   EXPECT_EQ(ds, nullptr);
 }

-TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1.";
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1A) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1A.";
   // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1

   // Set configuration

@@ -176,7 +176,7 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
   GlobalContext::config_manager()->set_seed(654);
   GlobalContext::config_manager()->set_num_parallel_workers(1);

-  // Create a TextFile Dataset, with two text files
+  // Create a TextFile Dataset, with two text files, 1.txt then 2.txt, in lexicographical order.
   // Note: 1.txt has 3 rows
   // Note: 2.txt has 2 rows
   // Use default of all samples

@@ -223,6 +223,64 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }

+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1B) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1B.";
+  // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(654);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a TextFile Dataset, with two text files, 2.txt then 1.txt, in non-lexicographical order
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file2, tf_file1}, 0, ShuffleMode::kFalse);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.",
+                                              "Good luck to everyone.", "Another file.", "End of file."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
 TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse4Shard.";
   // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=4, shard coverage

@@ -280,8 +338,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }

-TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
-  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1.";
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1A) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1A.";
   // Test TextFile Dataset with files shuffle, num_parallel_workers=1

   // Set configuration

@@ -291,7 +349,7 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
   GlobalContext::config_manager()->set_seed(135);
   GlobalContext::config_manager()->set_num_parallel_workers(1);

-  // Create a TextFile Dataset, with two text files
+  // Create a TextFile Dataset, with two text files, 1.txt then 2.txt, in lexicographical order.
   // Note: 1.txt has 3 rows
   // Note: 2.txt has 2 rows
   // Use default of all samples

@@ -340,6 +398,66 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
   GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
 }

+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1B) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1B.";
+  // Test TextFile Dataset with files shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(135);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a TextFile Dataset, with two text files, 2.txt then 1.txt, in non-lexicographical order.
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  // Set shuffle to files shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file2, tf_file1}, 0, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {
+    "This is a text file.", "Be happy every day.", "Good luck to everyone.", "Another file.", "End of file.",
+  };
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
 TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles4) {
   MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles4.";
   // Test TextFile Dataset with files shuffle, num_parallel_workers=4
  ...
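Taken together, the A/B test pairs pin down the user-visible guarantee of this commit: the C++ dataset factories (CLUE, CSV, TextFile) are now insensitive to the order of their file-list argument. A hedged sketch of that guarantee for TextFile (same assumptions about include path and namespace as in the CSV sketch above; file paths shortened for illustration):

    // Sketch only; not a full program.
    #include <memory>
    #include "minddata/dataset/include/datasets.h"  // assumed header location

    using namespace mindspore::dataset::api;  // assumed namespace

    void OrderInsensitive() {
      // Both constructions describe the same pipeline: TextFileDataset::Build()
      // sorts the list lexicographically before TextFileOp sees it.
      std::shared_ptr<Dataset> ds_a = TextFile({"1.txt", "2.txt"}, 0, ShuffleMode::kFalse);
      std::shared_ptr<Dataset> ds_b = TextFile({"2.txt", "1.txt"}, 0, ShuffleMode::kFalse);
      // With ShuffleMode::kFalse, both iterate 1.txt's 3 rows, then 2.txt's 2 rows.
    }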