magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 7f39b5cf, authored Aug 07, 2020 by Cathy Wong

C++ API Support for TextFile Dataset and Unit Tests

Parent: 4f75adb1
8 changed files with 932 additions and 7 deletions (+932, -7)
mindspore/ccsrc/minddata/dataset/api/datasets.cc         +122  -1
mindspore/ccsrc/minddata/dataset/core/constants.h          +3  -0
mindspore/ccsrc/minddata/dataset/include/datasets.h       +58  -1
mindspore/dataset/engine/datasets.py                       +1  -1
tests/ut/cpp/dataset/CMakeLists.txt                        +1  -0
tests/ut/cpp/dataset/c_api_dataset_filetext_test.cc      +596  -0
tests/ut/cpp/dataset/text_file_op_test.cc                 +28  -0
tests/ut/python/dataset/test_datasets_textfileop.py      +123  -4
mindspore/ccsrc/minddata/dataset/api/datasets.cc
@@ -26,6 +26,7 @@
 #include "minddata/dataset/engine/datasetops/source/coco_op.h"
 #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
 #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
+#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
 #include "minddata/dataset/engine/datasetops/source/voc_op.h"
 // Dataset operator headers (in alphabetical order)
 #include "minddata/dataset/engine/datasetops/batch_op.h"
@@ -95,6 +96,7 @@ Dataset::Dataset() {
   num_workers_ = cfg->num_parallel_workers();
   rows_per_buffer_ = cfg->rows_per_buffer();
   connector_que_size_ = cfg->op_connector_size();
+  worker_connector_size_ = cfg->worker_connector_size();
 }

 // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
@@ -140,7 +142,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
 std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
                                                 std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
                                                 std::map<std::string, int32_t> class_indexing) {
-  // This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false.
+  // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
   bool recursive = false;

   // Create logical representation of ImageFolderDataset.
@@ -163,6 +165,16 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
                                          const std::shared_ptr<Dataset> &datasets2) {
   std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));

   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
 }

+// Function to create a TextFileDataset.
+std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
+                                          ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
+  auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
+
+  // Call derived class validation method.
+  return ds->ValidateParams() ? ds : nullptr;
+}
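The factory is intentionally thin: construct the node, run the derived-class validation, and surface failure as a null pointer. A minimal usage sketch under that contract (the path below is a placeholder, not a file from this repo):

// Hypothetical call site; "/path/to/1.txt" is a placeholder path.
std::shared_ptr<TextFileDataset> ds =
    TextFile({"/path/to/1.txt"}, /*num_samples=*/0, ShuffleMode::kFalse,
             /*num_shards=*/1, /*shard_id=*/0);
if (ds == nullptr) {
  // ValidateParams() rejected the arguments: empty file list, unreadable
  // file, or inconsistent sharding settings (see ValidateParams below).
}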
@@ -340,6 +352,34 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
   return std::make_shared<RandomSamplerObj>(replacement, num_samples);
 }

+// Helper function to compute a default shuffle size
+int64_t ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows) {
+  const int64_t average_files_multiplier = 4;
+  const int64_t shuffle_max = 10000;
+  int64_t avg_rows_per_file = 0;
+  int64_t shuffle_size = 0;
+
+  // Adjust the num rows per shard if sharding was given
+  if (num_devices > 0) {
+    if (num_rows % num_devices == 0) {
+      num_rows = num_rows / num_devices;
+    } else {
+      num_rows = (num_rows / num_devices) + 1;
+    }
+  }
+
+  // Cap based on total rows directive. Some ops do not have this and give value of 0.
+  if (total_rows > 0) {
+    num_rows = std::min(num_rows, total_rows);
+  }
+
+  // get the average per file
+  avg_rows_per_file = num_rows / num_files;
+
+  shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
+  return shuffle_size;
+}
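Note that std::max makes shuffle_max a floor rather than a ceiling: small datasets are still given a 10000-row shuffle buffer, while large ones scale with the per-file average. A worked example with illustrative numbers (not taken from this commit's tests):

// num_files = 2, num_devices = 1, num_rows = 5, total_rows = 0:
//   sharding keeps num_rows = 5 (5 % 1 == 0), the total-rows cap is skipped,
//   avg_rows_per_file = 5 / 2 = 2, candidate size = 2 * 4 = 8,
//   so shuffle_size = std::max(8, 10000) = 10000.
int64_t shuffle_size = ComputeShuffleSize(2, 1, 5, 0);  // == 10000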
 // Helper function to validate dataset params
 bool ValidateCommonDatasetParams(std::string dataset_dir) {
   if (dataset_dir.empty()) {
@@ -613,6 +653,87 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
   return node_ops;
 }

+// Constructor for TextFileDataset
+TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
+                                 int32_t num_shards, int32_t shard_id)
+    : dataset_files_(dataset_files),
+      num_samples_(num_samples),
+      shuffle_(shuffle),
+      num_shards_(num_shards),
+      shard_id_(shard_id) {}
+
+bool TextFileDataset::ValidateParams() {
+  if (dataset_files_.empty()) {
+    MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
+    return false;
+  }
+
+  for (auto file : dataset_files_) {
+    std::ifstream handle(file);
+    if (!handle.is_open()) {
+      MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
+      return false;
+    }
+  }
+
+  if (num_samples_ < 0) {
+    MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
+    return false;
+  }
+
+  if (num_shards_ <= 0) {
+    MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
+    return false;
+  }
+
+  if (shard_id_ < 0 || shard_id_ >= num_shards_) {
+    MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
+    return false;
+  }
+
+  return true;
+}
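Since TextFile() returns nullptr whenever ValidateParams() fails, each rejection above is observable through the public API. A hedged sketch (paths are placeholders; the first failing check wins):

auto bad1 = TextFile({});                   // empty dataset_files -> nullptr
auto bad2 = TextFile({"/path/1.txt"}, -1);  // num_samples < 0 -> nullptr
auto bad3 = TextFile({"/path/1.txt"}, 0, ShuffleMode::kGlobal,
                     /*num_shards=*/2, /*shard_id=*/2);  // shard_id not in [0, 2) -> nullptr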
+// Function to build TextFileDataset
+std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
+  // A vector containing shared pointer to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+
+  // Do internal Schema generation.
+  auto schema = std::make_unique<DataSchema>();
+  RETURN_EMPTY_IF_ERROR(
+      schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
+
+  // Create and initialize TextFileOp
+  std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
+      num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
+      connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
+  RETURN_EMPTY_IF_ERROR(text_file_op->Init());
+
+  if (shuffle_ == ShuffleMode::kGlobal) {
+    // Inject ShuffleOp
+    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
+    int64_t shuffle_size = 0;
+    int64_t num_rows = 0;
+
+    // First, get the number of rows in the dataset and then compute the shuffle size
+    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
+    shuffle_size = ComputeShuffleSize(dataset_files_.size(), num_shards_, num_rows, 0);
+    MS_LOG(INFO) << "TextFileDataset::Build - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
+
+    // Add the shuffle op after this op
+    shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size_, true, rows_per_buffer_);
+    node_ops.push_back(shuffle_op);
+  }
+
+  // Add TextFileOp
+  node_ops.push_back(text_file_op);
+  return node_ops;
+}
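For reference, the ops come back parent-first: with ShuffleMode::kGlobal the vector holds the injected ShuffleOp followed by the TextFileOp leaf; otherwise only the leaf. A sketch of how a caller could observe this (hypothetical variable names, placeholder paths):

std::shared_ptr<TextFileDataset> ds =
    TextFile({"/path/1.txt", "/path/2.txt"}, 0, ShuffleMode::kGlobal);
std::vector<std::shared_ptr<DatasetOp>> ops = ds->Build();
// ops.size() == 2: ops[0] is the ShuffleOp, ops[1] is the TextFileOp leaf.
// With ShuffleMode::kFalse or kFiles, ops would hold only the TextFileOp.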
 // Constructor for VOCDataset
 VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
                        const std::map<std::string, int32_t> &class_index, bool decode,
mindspore/ccsrc/minddata/dataset/core/constants.h
@@ -35,6 +35,9 @@ enum class DatasetType { kUnknown, kArrow, kTf };
 // Possible flavours of Tensor implementations
 enum class TensorImpl { kNone, kFlexible, kCv, kNP };

+// Possible values for shuffle
+enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
+
 // Possible values for Border types
 enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 };
mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -23,6 +23,7 @@
 #include <map>
 #include <utility>
 #include <string>
+#include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/include/iterator.h"
 #include "minddata/dataset/include/samplers.h"
@@ -47,6 +48,7 @@ class Cifar100Dataset;
 class CocoDataset;
 class ImageFolderDataset;
 class MnistDataset;
+class TextFileDataset;
 class VOCDataset;
 // Dataset Op classes (in alphabetical order)
 class BatchDataset;
@@ -83,7 +85,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
 std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
                                         std::shared_ptr<SamplerObj> sampler = nullptr);

 /// \brief Function to create a Cifar100 Dataset
-/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label']
+/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
 ///     will be used to randomly iterate the entire dataset
@@ -143,6 +145,25 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
 std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
                                          const std::shared_ptr<Dataset> &datasets2);

+/// \brief Function to create a TextFileDataset
+/// \notes The generated dataset has one column ['text']
+/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
+///     will be sorted in a lexicographical order.
+/// \param[in] num_samples The number of samples to be included in the dataset.
+///     (Default = 0 means all samples.)
+/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
+///     Can be any of:
+///     ShuffleMode.kFalse - No shuffling is performed.
+///     ShuffleMode.kFiles - Shuffle files only.
+///     ShuffleMode.kGlobal - Shuffle both the files and samples.
+/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
+/// \param[in] shard_id The shard ID within num_shards. This argument should be
+///     specified only when num_shards is also specified. (Default = 0)
+/// \return Shared pointer to the current TextFileDataset
+std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
+                                          ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
+                                          int32_t shard_id = 0);
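With the defaults above, the shortest call reads every sample with global shuffling. A usage sketch; the iterator calls are assumptions based on the Dataset/Iterator interfaces this header pulls in (iterator.h), and the path is a placeholder:

std::shared_ptr<TextFileDataset> ds = TextFile({"/path/to/corpus.txt"});
std::shared_ptr<Iterator> iter = ds->CreateIterator();  // assumed Dataset API
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);  // each row carries the single 'text' column
while (!row.empty()) {
  iter->GetNextRow(&row);
}
iter->Stop();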
 /// \brief Function to create a VOCDataset
 /// \notes The generated dataset has multi-columns :
 ///    - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
@@ -289,10 +310,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   int32_t num_workers_;
   int32_t rows_per_buffer_;
   int32_t connector_que_size_;
+  int32_t worker_connector_size_;
 };

+/* ####################################### Derived Dataset classes ################################# */
+
 // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
 // (In alphabetical order)

 class CelebADataset : public Dataset {
  public:
   /// \brief Constructor
@@ -318,6 +343,8 @@ class CelebADataset : public Dataset {
   std::set<std::string> extensions_;
   std::shared_ptr<SamplerObj> sampler_;
 };

 // DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
 // (In alphabetical order)

 class Cifar10Dataset : public Dataset {
  public:
@@ -435,6 +462,33 @@ class MnistDataset : public Dataset {
   std::shared_ptr<SamplerObj> sampler_;
 };

+/// \class TextFileDataset
+/// \brief A Dataset derived class to represent TextFile dataset
+class TextFileDataset : public Dataset {
+ public:
+  /// \brief Constructor
+  TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
+                  int32_t num_shards, int32_t shard_id);
+
+  /// \brief Destructor
+  ~TextFileDataset() = default;
+
+  /// \brief a base class override function to create the required runtime dataset op objects for this class
+  /// \return The list of shared pointers to the newly created DatasetOps
+  std::vector<std::shared_ptr<DatasetOp>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return bool true if all the params are valid
+  bool ValidateParams() override;
+
+ private:
+  std::vector<std::string> dataset_files_;
+  int32_t num_samples_;
+  int32_t num_shards_;
+  int32_t shard_id_;
+  ShuffleMode shuffle_;
+};
+
 class VOCDataset : public Dataset {
  public:
   /// \brief Constructor
@@ -467,6 +521,9 @@ class VOCDataset : public Dataset {
   std::shared_ptr<SamplerObj> sampler_;
 };

+// DERIVED DATASET CLASSES FOR DATASET OPS
+// (In alphabetical order)
+
 class BatchDataset : public Dataset {
  public:
   /// \brief Constructor
mindspore/dataset/engine/datasets.py
@@ -5012,7 +5012,7 @@ class CSVDataset(SourceDataset):
 class TextFileDataset(SourceDataset):
     """
     A source dataset that reads and parses datasets stored on disk in text format.
-    The generated dataset has one columns ['text'].
+    The generated dataset has one column ['text'].

     Args:
         dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
tests/ut/cpp/dataset/CMakeLists.txt
@@ -97,6 +97,7 @@ SET(DE_UT_SRCS
     c_api_dataset_ops_test.cc
     c_api_dataset_cifar_test.cc
     c_api_dataset_coco_test.cc
+    c_api_dataset_filetext_test.cc
     c_api_dataset_voc_test.cc
     c_api_datasets_test.cc
     c_api_dataset_iterator_test.cc
tests/ut/cpp/dataset/c_api_dataset_filetext_test.cc
new file (mode 100644)
This diff is collapsed and not shown in this view.
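The 596 new lines are not reproduced here. Judging from the sibling c_api_dataset_*_test.cc suites and the Python expectations below (1.txt holds three lines), a representative case might look like this hypothetical sketch; the fixture name, helper members, and counts are assumptions, not the file's actual contents:

TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) {
  // Create a TextFile Dataset over one file, all samples, no shuffling.
  std::string file = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::shared_ptr<Dataset> ds = TextFile({file}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Iterate and count rows in the single 'text' column.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    iter->GetNextRow(&row);
  }
  EXPECT_EQ(i, 3);  // 1.txt is assumed to hold three lines (see Python tests)
  iter->Stop();
}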
tests/ut/cpp/dataset/text_file_op_test.cc
@@ -89,6 +89,23 @@ TEST_F(MindDataTestTextFileOp, TestTextFileBasic) {
   ASSERT_EQ(row_count, 3);
 }

+TEST_F(MindDataTestTextFileOp, TestTextFileFileNotExist) {
+  // Start with an empty execution tree
+  auto tree = std::make_shared<ExecutionTree>();
+
+  std::string dataset_path = datasets_root_path_ + "/does/not/exist/0.txt";
+
+  std::shared_ptr<TextFileOp> op;
+  TextFileOp::Builder builder;
+  builder.SetTextFilesList({dataset_path}).SetRowsPerBuffer(16).SetNumWorkers(16).SetOpConnectorSize(2);
+
+  Status rc = builder.Build(&op);
+  ASSERT_TRUE(rc.IsOk());
+}
+
 TEST_F(MindDataTestTextFileOp, TestTotalRows) {
   std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
   std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
@@ -110,3 +127,14 @@ TEST_F(MindDataTestTextFileOp, TestTotalRows) {
   ASSERT_EQ(total_rows, 5);
   files.clear();
 }
+
+TEST_F(MindDataTestTextFileOp, TestTotalRowsFileNotExist) {
+  std::string tf_file1 = datasets_root_path_ + "/does/not/exist/0.txt";
+  std::vector<std::string> files;
+  files.push_back(tf_file1);
+
+  int64_t total_rows = 0;
+  TextFileOp::CountAllFileRows(files, &total_rows);
+  ASSERT_EQ(total_rows, 0);
+}
tests/ut/python/dataset/test_datasets_textfileop.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import pytest
 import mindspore.dataset as ds
 from mindspore import log as logger
-from util import config_get_set_num_parallel_workers
+from util import config_get_set_num_parallel_workers, config_get_set_seed

 DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
@@ -39,10 +40,54 @@ def test_textline_dataset_all_file():
    assert count == 5


-def test_textline_dataset_totext():
+def test_textline_dataset_num_samples_zero():
    data = ds.TextFileDataset(DATA_FILE, num_samples=0)
    count = 0
    for i in data.create_dict_iterator():
        logger.info("{}".format(i["text"]))
        count += 1
    assert count == 3


def test_textline_dataset_shuffle_false4():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(987)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
    count = 0
    line = ["This is a text file.", "Another file.",
            "Be happy every day.", "End of file.",
            "Good luck to everyone."]
    for i in data.create_dict_iterator():
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_false1():
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(987)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
    count = 0
    line = ["This is a text file.", "Be happy every day.",
            "Good luck to everyone.", "Another file.",
            "End of file."]
    for i in data.create_dict_iterator():
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_files4():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(135)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    count = 0
    line = ["This is a text file.", "Another file.",
            "Be happy every day.", "End of file.",
            "Good luck to everyone."]
    for i in data.create_dict_iterator():
@@ -50,8 +95,60 @@ def test_textline_dataset_totext():
        assert strs == line[count]
        count += 1
    assert count == 5
-    # Restore configuration num_parallel_workers
+    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_files1():
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(135)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
    count = 0
    line = ["This is a text file.", "Be happy every day.",
            "Good luck to everyone.", "Another file.",
            "End of file."]
    for i in data.create_dict_iterator():
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_global4():
    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
    original_seed = config_get_set_seed(246)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    count = 0
    line = ["Another file.", "Good luck to everyone.",
            "End of file.", "This is a text file.",
            "Be happy every day."]
    for i in data.create_dict_iterator():
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_shuffle_global1():
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
    original_seed = config_get_set_seed(246)
    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
    count = 0
    line = ["Another file.", "Good luck to everyone.",
            "This is a text file.", "End of file.",
            "Be happy every day."]
    for i in data.create_dict_iterator():
        strs = i["text"].item().decode("utf8")
        assert strs == line[count]
        count += 1
    assert count == 5
    # Restore configuration
    ds.config.set_num_parallel_workers(original_num_parallel_workers)
    ds.config.set_seed(original_seed)


def test_textline_dataset_num_samples():
@@ -94,11 +191,33 @@ def test_textline_dataset_to_device():
    data = data.to_device()
    data.send()


def test_textline_dataset_exceptions():
    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset(DATA_FILE, num_samples=-1)
    assert "Input num_samples is not within the required interval" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("does/not/exist/no.txt")
    assert "The following patterns did not match any files" in str(error_info.value)

    with pytest.raises(ValueError) as error_info:
        _ = ds.TextFileDataset("")
    assert "The following patterns did not match any files" in str(error_info.value)


if __name__ == "__main__":
    test_textline_dataset_one_file()
    test_textline_dataset_all_file()
-    test_textline_dataset_totext()
+    test_textline_dataset_num_samples_zero()
    test_textline_dataset_shuffle_false4()
    test_textline_dataset_shuffle_false1()
    test_textline_dataset_shuffle_files4()
    test_textline_dataset_shuffle_files1()
    test_textline_dataset_shuffle_global4()
    test_textline_dataset_shuffle_global1()
    test_textline_dataset_num_samples()
    test_textline_dataset_distribution()
    test_textline_dataset_repeat()
    test_textline_dataset_get_datasetsize()
    test_textline_dataset_to_device()
    test_textline_dataset_exceptions()