Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2795e492
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2795e492
编写于
4年前
作者:
Y
yanghaitao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
TextFileDataset
上级
18580a78
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
1175 addition
and
38 deletion
+1175
-38
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+35
-2
mindspore/ccsrc/dataset/api/de_pipeline.h
mindspore/ccsrc/dataset/api/de_pipeline.h
+4
-1
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+14
-1
mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt
...ore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
...re/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
+459
-0
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
...ore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
+263
-0
mindspore/dataset/__init__.py
mindspore/dataset/__init__.py
+3
-3
mindspore/dataset/engine/__init__.py
mindspore/dataset/engine/__init__.py
+2
-2
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+104
-26
mindspore/dataset/engine/iterators.py
mindspore/dataset/engine/iterators.py
+8
-2
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+22
-0
mindspore/dataset/transforms/nlp/__init__.py
mindspore/dataset/transforms/nlp/__init__.py
+20
-0
mindspore/dataset/transforms/nlp/utils.py
mindspore/dataset/transforms/nlp/utils.py
+35
-0
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-1
tests/ut/cpp/dataset/text_file_op_test.cc
tests/ut/cpp/dataset/text_file_op_test.cc
+112
-0
tests/ut/data/dataset/testTextFileDataset/1.txt
tests/ut/data/dataset/testTextFileDataset/1.txt
+3
-0
tests/ut/data/dataset/testTextFileDataset/2.txt
tests/ut/data/dataset/testTextFileDataset/2.txt
+2
-0
tests/ut/python/dataset/test_datasets_textfileop.py
tests/ut/python/dataset/test_datasets_textfileop.py
+87
-0
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
2795e492
...
@@ -28,10 +28,10 @@
...
@@ -28,10 +28,10 @@
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
#include "mindrecord/include/shard_shuffle.h"
#include "dataset/util/random.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
#include "dataset/util/status.h"
#include "utils/log_adapter.h"
#include "utils/log_adapter.h"
...
@@ -61,7 +61,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
...
@@ -61,7 +61,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{
kVoc
,
&
DEPipeline
::
ParseVOCOp
},
{
kVoc
,
&
DEPipeline
::
ParseVOCOp
},
{
kCifar10
,
&
DEPipeline
::
ParseCifar10Op
},
{
kCifar10
,
&
DEPipeline
::
ParseCifar10Op
},
{
kCifar100
,
&
DEPipeline
::
ParseCifar100Op
},
{
kCifar100
,
&
DEPipeline
::
ParseCifar100Op
},
{
kCelebA
,
&
DEPipeline
::
ParseCelebAOp
}};
{
kCelebA
,
&
DEPipeline
::
ParseCelebAOp
},
{
kTextFile
,
&
DEPipeline
::
ParseTextFileOp
}};
DEPipeline
::
DEPipeline
()
:
iterator_
(
nullptr
)
{
DEPipeline
::
DEPipeline
()
:
iterator_
(
nullptr
)
{
try
{
try
{
...
@@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
...
@@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
*
ptr
=
op
;
*
ptr
=
op
;
return
Status
::
OK
();
return
Status
::
OK
();
}
}
Status
DEPipeline
::
ParseTextFileOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
{
// Required arguments
std
::
shared_ptr
<
TextFileOp
::
Builder
>
builder
=
std
::
make_shared
<
TextFileOp
::
Builder
>
();
if
(
!
args
[
"dataset_files"
].
is_none
())
{
(
void
)
builder
->
SetTextFilesList
(
ToStringVector
(
args
[
"dataset_files"
]));
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Error: dataset_files is missing"
);
}
// Optional arguments
for
(
auto
arg
:
args
)
{
std
::
string
key
=
py
::
str
(
arg
.
first
);
py
::
handle
value
=
arg
.
second
;
if
(
!
value
.
is_none
())
{
if
(
key
==
"num_parallel_workers"
)
{
(
void
)
builder
->
SetNumWorkers
(
ToInt
(
value
));
}
else
if
(
key
==
"shuffle_files"
)
{
(
void
)
builder
->
SetShuffleFiles
(
ToBool
(
value
));
}
else
if
(
key
==
"num_samples"
)
{
(
void
)
builder
->
SetNumSamples
(
ToInt
(
value
));
}
else
if
(
key
==
"num_shards"
)
{
(
void
)
builder
->
SetNumDevices
(
ToInt
(
value
));
}
else
if
(
key
==
"shard_id"
)
{
(
void
)
builder
->
SetDeviceId
(
ToInt
(
value
));
}
}
}
std
::
shared_ptr
<
TextFileOp
>
op
;
RETURN_IF_NOT_OK
(
builder
->
Build
(
&
op
));
*
ptr
=
op
;
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace dataset
}
// namespace mindspore
}
// namespace mindspore
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/api/de_pipeline.h
浏览文件 @
2795e492
...
@@ -58,7 +58,8 @@ enum OpName {
...
@@ -58,7 +58,8 @@ enum OpName {
kVoc
,
kVoc
,
kCifar10
,
kCifar10
,
kCifar100
,
kCifar100
,
kCelebA
kCelebA
,
kTextFile
};
};
// The C++ binder class that we expose to the python script.
// The C++ binder class that we expose to the python script.
...
@@ -148,6 +149,8 @@ class DEPipeline {
...
@@ -148,6 +149,8 @@ class DEPipeline {
Status
ParseCelebAOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseCelebAOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseTextFileOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
private:
private:
// Execution tree that links the dataset operators.
// Execution tree that links the dataset operators.
std
::
shared_ptr
<
ExecutionTree
>
tree_
;
std
::
shared_ptr
<
ExecutionTree
>
tree_
;
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
2795e492
...
@@ -55,6 +55,7 @@
...
@@ -55,6 +55,7 @@
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/kernels/data/to_float16_op.h"
#include "dataset/util/random.h"
#include "dataset/util/random.h"
#include "mindrecord/include/shard_operator.h"
#include "mindrecord/include/shard_operator.h"
...
@@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) {
...
@@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) {
THROW_IF_ERROR
(
MnistOp
::
CountTotalRows
(
dir
,
numSamples
,
&
count
));
THROW_IF_ERROR
(
MnistOp
::
CountTotalRows
(
dir
,
numSamples
,
&
count
));
return
count
;
return
count
;
});
});
(
void
)
py
::
class_
<
TextFileOp
,
DatasetOp
,
std
::
shared_ptr
<
TextFileOp
>>
(
*
m
,
"TextFileOp"
)
.
def_static
(
"get_num_rows"
,
[](
const
py
::
list
&
files
)
{
int64_t
count
=
0
;
std
::
vector
<
std
::
string
>
filenames
;
for
(
auto
file
:
files
)
{
!
file
.
is_none
()
?
filenames
.
push_back
(
py
::
str
(
file
))
:
(
void
)
filenames
.
emplace_back
(
""
);
}
THROW_IF_ERROR
(
TextFileOp
::
CountAllFileRows
(
filenames
,
&
count
));
return
count
;
});
}
}
void
bindTensor
(
py
::
module
*
m
)
{
void
bindTensor
(
py
::
module
*
m
)
{
(
void
)
py
::
class_
<
GlobalContext
>
(
*
m
,
"GlobalContext"
)
(
void
)
py
::
class_
<
GlobalContext
>
(
*
m
,
"GlobalContext"
)
...
@@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
...
@@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.
value
(
"VOC"
,
OpName
::
kVoc
)
.
value
(
"VOC"
,
OpName
::
kVoc
)
.
value
(
"CIFAR10"
,
OpName
::
kCifar10
)
.
value
(
"CIFAR10"
,
OpName
::
kCifar10
)
.
value
(
"CIFAR100"
,
OpName
::
kCifar100
)
.
value
(
"CIFAR100"
,
OpName
::
kCifar100
)
.
value
(
"CELEBA"
,
OpName
::
kCelebA
);
.
value
(
"CELEBA"
,
OpName
::
kCelebA
)
.
value
(
"TEXTFILE"
,
OpName
::
kTextFile
);
(
void
)
py
::
enum_
<
InterpolationMode
>
(
m
,
"InterpolationMode"
,
py
::
arithmetic
())
(
void
)
py
::
enum_
<
InterpolationMode
>
(
m
,
"InterpolationMode"
,
py
::
arithmetic
())
.
value
(
"DE_INTER_LINEAR"
,
InterpolationMode
::
kLinear
)
.
value
(
"DE_INTER_LINEAR"
,
InterpolationMode
::
kLinear
)
...
...
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt
浏览文件 @
2795e492
...
@@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT
...
@@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT
manifest_op.cc
manifest_op.cc
cifar_op.cc
cifar_op.cc
celeba_op.cc
celeba_op.cc
text_file_op.cc
)
)
add_dependencies
(
engine-datasetops-source mindspore::protobuf
)
add_dependencies
(
engine-datasetops-source mindspore::protobuf
)
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
0 → 100644
浏览文件 @
2795e492
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "common/utils.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/core/config_manager.h"
#include "dataset/util/task_manager.h"
#include "dataset/util/wait_post.h"
#include "dataset/util/random.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/execution_tree.h"
namespace
mindspore
{
namespace
dataset
{
TextFileOp
::
Builder
::
Builder
()
:
builder_device_id_
(
0
),
builder_num_devices_
(
1
),
builder_num_samples_
(
0
),
builder_shuffle_files_
(
false
)
{
std
::
shared_ptr
<
ConfigManager
>
config_manager
=
GlobalContext
::
config_manager
();
builder_num_workers_
=
config_manager
->
num_parallel_workers
();
builder_op_connector_size_
=
config_manager
->
op_connector_size
();
builder_rows_per_buffer_
=
config_manager
->
rows_per_buffer
();
builder_worker_connector_size_
=
config_manager
->
worker_connector_size
();
}
Status
TextFileOp
::
Builder
::
ValidateInputs
()
const
{
std
::
string
err_msg
;
err_msg
+=
builder_num_workers_
<=
0
?
"Number of parallel workers should be greate than 0
\n
"
:
""
;
err_msg
+=
builder_device_id_
>=
builder_num_devices_
||
builder_num_devices_
<
1
?
"Wrong sharding configs
\n
"
:
""
;
return
err_msg
.
empty
()
?
Status
::
OK
()
:
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
err_msg
);
}
Status
TextFileOp
::
Builder
::
Build
(
std
::
shared_ptr
<
TextFileOp
>
*
op
)
{
RETURN_IF_NOT_OK
(
ValidateInputs
());
// Throttle the number of workers if we have more workers than files!
if
(
static_cast
<
size_t
>
(
builder_num_workers_
)
>
builder_text_files_list_
.
size
())
{
builder_num_workers_
=
builder_text_files_list_
.
size
();
MS_LOG
(
WARNING
)
<<
"TextFileOp operator parallelism reduced to "
<<
builder_num_workers_
<<
" workers."
;
}
builder_schema_
=
std
::
make_unique
<
DataSchema
>
();
RETURN_IF_NOT_OK
(
builder_schema_
->
AddColumn
(
ColDescriptor
(
"text"
,
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kFlexible
,
1
)));
std
::
shared_ptr
<
TextFileOp
>
text_file_op
=
std
::
make_shared
<
TextFileOp
>
(
builder_num_workers_
,
builder_rows_per_buffer_
,
builder_num_samples_
,
builder_worker_connector_size_
,
std
::
move
(
builder_schema_
),
builder_text_files_list_
,
builder_op_connector_size_
,
builder_shuffle_files_
,
builder_num_devices_
,
builder_device_id_
);
RETURN_IF_NOT_OK
(
text_file_op
->
Init
());
*
op
=
std
::
move
(
text_file_op
);
return
Status
::
OK
();
}
TextFileOp
::
TextFileOp
(
int32_t
num_workers
,
int64_t
rows_per_buffer
,
int64_t
num_samples
,
int32_t
worker_connector_size
,
std
::
unique_ptr
<
DataSchema
>
schema
,
std
::
vector
<
std
::
string
>
text_files_list
,
int32_t
op_connector_size
,
bool
shuffle_files
,
int32_t
num_device
,
int32_t
device_id
)
:
ParallelOp
(
num_workers
,
op_connector_size
),
device_id_
(
device_id
),
num_devices_
(
num_device
),
rows_per_buffer_
(
rows_per_buffer
),
num_samples_
(
num_samples
),
text_files_list_
(
std
::
move
(
text_files_list
)),
shuffle_files_
(
shuffle_files
),
data_schema_
(
std
::
move
(
schema
)),
all_num_rows_
(
0
),
num_rows_per_shard_
(
0
),
filename_index_
(
std
::
make_unique
<
StringIndex
>
()),
finished_reading_dataset_
(
false
),
load_io_block_queue_
(
true
),
load_jagged_connector_
(
true
)
{
worker_connector_size_
=
worker_connector_size
;
}
Status
TextFileOp
::
Init
()
{
RETURN_IF_NOT_OK
(
filename_index_
->
insert
(
text_files_list_
));
int32_t
safe_queue_size
=
static_cast
<
int32_t
>
(
std
::
ceil
(
text_files_list_
.
size
()
/
num_workers_
)
+
1
);
io_block_queues_
.
Init
(
num_workers_
,
safe_queue_size
);
for
(
int32_t
i
=
0
;
i
<
data_schema_
->
NumColumns
();
++
i
)
{
col_name_map_
[
data_schema_
->
column
(
i
).
name
()]
=
i
;
}
RETURN_IF_NOT_OK
(
ParallelOp
::
CreateWorkerConnector
(
worker_connector_size_
));
jagged_buffer_connector_
=
std
::
make_unique
<
JaggedConnector
>
(
num_workers_
,
1
,
worker_connector_size_
);
return
Status
::
OK
();
}
Status
TextFileOp
::
Reset
()
{
load_jagged_connector_
=
true
;
load_io_block_queue_
=
true
;
RETURN_IF_NOT_OK
(
ParallelOp
::
Reset
());
NotifyToFillIOBlockQueue
();
return
Status
::
OK
();
}
Status
TextFileOp
::
LoadTensor
(
const
std
::
string
&
line
,
std
::
unique_ptr
<
TensorQTable
>
*
tensor_table
,
int64_t
row
)
{
TensorRow
tRow
(
1
,
nullptr
);
(
*
tensor_table
)
->
push_back
(
std
::
move
(
tRow
));
std
::
shared_ptr
<
Tensor
>
tensor
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
tensor
,
data_schema_
->
column
(
0
).
tensorImpl
(),
TensorShape
(
std
::
vector
<
dsize_t
>
(
1
,
line
.
size
())),
data_schema_
->
column
(
0
).
type
(),
const_cast
<
unsigned
char
*>
(
reinterpret_cast
<
const
unsigned
char
*>
(
common
::
SafeCStr
(
line
)))));
(
**
tensor_table
)[
row
][
0
]
=
std
::
move
(
tensor
);
return
Status
::
OK
();
}
Status
TextFileOp
::
LoadFile
(
const
std
::
string
&
file
,
const
int64_t
start_offset
,
const
int64_t
end_offset
,
const
int32_t
worker_id
)
{
std
::
ifstream
handle
(
file
);
if
(
!
handle
.
is_open
())
{
RETURN_STATUS_UNEXPECTED
(
"Failed to open file "
+
file
);
}
int64_t
rows_each_buffer
=
0
;
int64_t
rows_total
=
0
;
std
::
string
line
;
std
::
unique_ptr
<
DataBuffer
>
cur_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
BufferFlags
::
kDeBFlagNone
);
cur_buffer
->
set_column_name_map
(
col_name_map_
);
std
::
unique_ptr
<
TensorQTable
>
tensor_table
=
std
::
make_unique
<
TensorQTable
>
();
while
(
getline
(
handle
,
line
))
{
// If read to the end offset of this file, break.
if
(
rows_total
>=
end_offset
)
{
break
;
}
// Skip line before start offset.
if
(
rows_total
<
start_offset
)
{
rows_total
++
;
continue
;
}
RETURN_IF_NOT_OK
(
LoadTensor
(
line
,
&
tensor_table
,
rows_each_buffer
));
rows_each_buffer
++
;
rows_total
++
;
if
(
rows_each_buffer
==
rows_per_buffer_
)
{
cur_buffer
->
set_tensor_table
(
std
::
move
(
tensor_table
));
RETURN_IF_NOT_OK
(
jagged_buffer_connector_
->
Add
(
worker_id
,
std
::
move
(
cur_buffer
)));
cur_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
BufferFlags
::
kDeBFlagNone
);
cur_buffer
->
set_column_name_map
(
col_name_map_
);
tensor_table
=
std
::
make_unique
<
TensorQTable
>
();
rows_each_buffer
=
0
;
}
}
if
(
rows_each_buffer
>
0
)
{
cur_buffer
->
set_tensor_table
(
std
::
move
(
tensor_table
));
RETURN_IF_NOT_OK
(
jagged_buffer_connector_
->
Add
(
worker_id
,
std
::
move
(
cur_buffer
)));
}
return
Status
::
OK
();
}
Status
TextFileOp
::
WorkerEntry
(
int32_t
worker_id
)
{
TaskManager
::
FindMe
()
->
Post
();
std
::
unique_ptr
<
FilenameBlock
>
io_block
;
RETURN_IF_NOT_OK
(
PopIoBlockQueue
(
worker_id
,
&
io_block
));
while
(
!
io_block
->
eof
())
{
if
(
!
io_block
->
eoe
())
{
if
(
load_jagged_connector_
)
{
std
::
string
filename
;
RETURN_IF_NOT_OK
(
io_block
->
GetFilename
(
&
filename
,
*
filename_index_
));
int64_t
start_offset
=
io_block
->
GetStartOffset
();
int64_t
end_offset
=
io_block
->
GetEndOffset
();
RETURN_IF_NOT_OK
(
LoadFile
(
filename
,
start_offset
,
end_offset
,
worker_id
));
}
}
else
{
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
jagged_buffer_connector_
->
Add
(
worker_id
,
std
::
move
(
eoe_buffer
)));
}
RETURN_IF_NOT_OK
(
PopIoBlockQueue
(
worker_id
,
&
io_block
));
}
return
Status
::
OK
();
}
// Pops an element from a queue in io_block_queues
Status
TextFileOp
::
PopIoBlockQueue
(
int32_t
index
,
std
::
unique_ptr
<
FilenameBlock
>
*
out_block
)
{
RETURN_IF_NOT_OK
(
io_block_queues_
[
index
]
->
PopFront
(
out_block
));
return
Status
::
OK
();
}
// Pushes an element to a queue in io_block_queues
Status
TextFileOp
::
PushIoBlockQueue
(
int32_t
index
,
std
::
unique_ptr
<
FilenameBlock
>
&&
io_block
)
{
RETURN_IF_NOT_OK
(
io_block_queues_
[
index
]
->
Add
(
std
::
move
(
io_block
)));
return
Status
::
OK
();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status
TextFileOp
::
PostEndOfData
()
{
for
(
int
i
=
0
;
i
<
num_workers_
;
++
i
)
{
std
::
unique_ptr
<
FilenameBlock
>
eof
=
std
::
make_unique
<
FilenameBlock
>
(
IOBlock
::
kDeIoBlockFlagEof
);
RETURN_IF_NOT_OK
(
PushIoBlockQueue
(
i
,
std
::
move
(
eof
)));
}
return
Status
::
OK
();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status
TextFileOp
::
PostEndOfEpoch
(
int32_t
queue_index
)
{
for
(
int
i
=
0
;
i
<
num_workers_
;
++
i
)
{
std
::
unique_ptr
<
FilenameBlock
>
eoe
=
std
::
make_unique
<
FilenameBlock
>
(
IOBlock
::
kDeIoBlockFlagEoe
);
RETURN_IF_NOT_OK
(
PushIoBlockQueue
((
queue_index
+
i
)
%
num_workers_
,
std
::
move
(
eoe
)));
}
return
Status
::
OK
();
}
static
void
ShuffleKeys
(
std
::
vector
<
int64_t
>
*
i_keys
,
uint32_t
seed
)
{
std
::
mt19937
rng
(
seed
);
std
::
shuffle
(
i_keys
->
begin
(),
i_keys
->
end
(),
rng
);
}
bool
TextFileOp
::
NeedPushFileToBlockQueue
(
const
std
::
string
&
file_name
,
int64_t
*
start_offset
,
int64_t
*
end_offset
,
const
int64_t
&
pre_count
)
{
*
start_offset
=
0
;
*
end_offset
=
0
;
bool
push
=
false
;
int64_t
start_index
=
device_id_
*
num_rows_per_shard_
;
if
(
device_id_
+
1
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Device id is invalid"
;
return
false
;
}
int64_t
end_index
=
(
static_cast
<
int64_t
>
(
device_id_
)
+
1
)
*
num_rows_per_shard_
;
if
(
pre_count
<=
start_index
&&
pre_count
+
filename_numrows_
[
file_name
]
>
start_index
)
{
*
start_offset
=
start_index
-
pre_count
;
push
=
true
;
if
(
pre_count
<
end_index
&&
pre_count
+
filename_numrows_
[
file_name
]
>=
end_index
)
{
*
end_offset
=
end_index
-
pre_count
;
}
else
{
*
end_offset
=
filename_numrows_
[
file_name
];
}
}
if
(
pre_count
>=
start_index
&&
pre_count
<
end_index
)
{
*
start_offset
=
0
;
push
=
true
;
if
(
pre_count
+
filename_numrows_
[
file_name
]
>=
end_index
)
{
*
end_offset
=
end_index
-
pre_count
;
}
else
{
*
end_offset
=
filename_numrows_
[
file_name
];
}
}
return
push
;
}
Status
TextFileOp
::
FillIOBlockQueue
(
const
std
::
vector
<
int64_t
>
&
i_keys
)
{
int32_t
queue_index
=
0
;
int64_t
pre_count
=
0
;
int64_t
start_offset
=
0
;
int64_t
end_offset
=
0
;
bool
finish
=
false
;
while
(
!
finish
)
{
std
::
vector
<
std
::
pair
<
std
::
string
,
int64_t
>>
file_index
;
if
(
!
i_keys
.
empty
())
{
for
(
auto
it
=
i_keys
.
begin
();
it
!=
i_keys
.
end
();
++
it
)
{
{
if
(
!
load_io_block_queue_
)
{
break
;
}
}
auto
file_it
=
filename_index_
->
Search
(
*
it
);
file_index
.
emplace_back
(
std
::
pair
<
std
::
string
,
int64_t
>
(
file_it
.
value
(),
*
it
));
}
}
else
{
for
(
auto
it
=
filename_index_
->
begin
();
it
!=
filename_index_
->
end
();
++
it
)
{
{
if
(
!
load_io_block_queue_
)
{
break
;
}
}
file_index
.
emplace_back
(
std
::
pair
<
std
::
string
,
int64_t
>
(
it
.
value
(),
it
.
key
()));
}
}
for
(
auto
file_info
:
file_index
)
{
if
(
NeedPushFileToBlockQueue
(
file_info
.
first
,
&
start_offset
,
&
end_offset
,
pre_count
))
{
auto
ioBlock
=
std
::
make_unique
<
FilenameBlock
>
(
file_info
.
second
,
start_offset
,
end_offset
,
IOBlock
::
kDeIoBlockNone
);
RETURN_IF_NOT_OK
(
PushIoBlockQueue
(
queue_index
,
std
::
move
(
ioBlock
)));
queue_index
=
(
queue_index
+
1
)
%
num_workers_
;
}
pre_count
+=
filename_numrows_
[
file_info
.
first
];
}
if
(
pre_count
<
(
static_cast
<
int64_t
>
(
device_id_
)
+
1
)
*
num_rows_per_shard_
)
{
finish
=
false
;
}
else
{
finish
=
true
;
}
}
RETURN_IF_NOT_OK
(
PostEndOfEpoch
(
queue_index
));
return
Status
::
OK
();
}
Status
TextFileOp
::
WaitToFillIOBlockQueue
()
{
// must be called first if called by worker spanwed by taskgroup
TaskManager
::
FindMe
()
->
Post
();
std
::
vector
<
int64_t
>
i_keys
;
if
(
shuffle_files_
)
{
for
(
auto
it
=
filename_index_
->
begin
();
it
!=
filename_index_
->
end
();
++
it
)
{
i_keys
.
push_back
(
it
.
key
());
}
}
uint32_t
seed
=
0
;
while
(
true
)
{
RETURN_IF_NOT_OK
(
io_block_queue_wait_post_
.
Wait
());
io_block_queue_wait_post_
.
Clear
();
if
(
finished_reading_dataset_
)
{
break
;
}
if
(
shuffle_files_
)
{
ShuffleKeys
(
&
i_keys
,
num_devices_
==
1
?
GetSeed
()
:
++
seed
);
}
RETURN_IF_NOT_OK
(
FillIOBlockQueue
(
i_keys
));
}
return
Status
::
OK
();
}
void
TextFileOp
::
NotifyToFillIOBlockQueue
()
{
io_block_queue_wait_post_
.
Set
();
}
Status
TextFileOp
::
operator
()()
{
RETURN_IF_NOT_OK
(
CalculateNumRowsPerShard
());
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK
(
tree_
->
LaunchWorkers
(
1
,
std
::
bind
(
&
TextFileOp
::
WaitToFillIOBlockQueue
,
this
)));
// Read data from disk into buffers
RETURN_IF_NOT_OK
(
tree_
->
LaunchWorkers
(
num_workers_
,
std
::
bind
(
&
TextFileOp
::
WorkerEntry
,
this
,
std
::
placeholders
::
_1
)));
// must be called after launching workers.
TaskManager
::
FindMe
()
->
Post
();
io_block_queue_wait_post_
.
Register
(
tree_
->
AllTasks
());
NotifyToFillIOBlockQueue
();
while
(
!
finished_reading_dataset_
)
{
int64_t
buffer_id
=
0
;
int32_t
workers_done
=
0
;
int64_t
rows_read
=
0
;
load_io_block_queue_
=
true
;
while
(
workers_done
<
num_workers_
)
{
std
::
unique_ptr
<
DataBuffer
>
buffer
;
RETURN_IF_NOT_OK
(
jagged_buffer_connector_
->
Pop
(
0
,
&
buffer
));
if
(
buffer
->
eoe
())
{
workers_done
++
;
}
else
if
(
num_samples_
==
0
||
rows_read
<
num_samples_
)
{
if
((
num_samples_
>
0
)
&&
(
rows_read
+
buffer
->
NumRows
()
>
num_samples_
))
{
int64_t
rowsToRemove
=
buffer
->
NumRows
()
-
(
num_samples_
-
rows_read
);
RETURN_IF_NOT_OK
(
buffer
->
SliceOff
(
rowsToRemove
));
}
rows_read
+=
buffer
->
NumRows
();
buffer
->
set_id
(
buffer_id
++
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
buffer
)));
}
else
{
// end of epoch
load_jagged_connector_
=
false
;
load_io_block_queue_
=
false
;
}
}
std
::
unique_ptr
<
DataBuffer
>
eoe_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eoe_buffer
)));
if
(
!
BitTest
(
op_ctrl_flags_
,
kDeOpRepeated
)
||
BitTest
(
op_ctrl_flags_
,
kDeOpLastRepeat
))
{
finished_reading_dataset_
=
true
;
NotifyToFillIOBlockQueue
();
}
else
{
jagged_buffer_connector_
->
DoReset
();
buffer_id
=
0
;
}
}
std
::
unique_ptr
<
DataBuffer
>
eof_buffer
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOF
);
RETURN_IF_NOT_OK
(
out_connector_
->
Add
(
0
,
std
::
move
(
eof_buffer
)));
RETURN_IF_NOT_OK
(
PostEndOfData
());
return
Status
::
OK
();
}
int64_t
TextFileOp
::
CountTotalRows
(
const
std
::
string
&
file
)
{
std
::
ifstream
handle
(
file
);
if
(
!
handle
.
is_open
())
{
MS_LOG
(
ERROR
)
<<
"Failed to open file: "
<<
file
;
return
0
;
}
std
::
string
line
;
int64_t
count
=
0
;
while
(
getline
(
handle
,
line
))
{
count
++
;
}
return
count
;
}
Status
TextFileOp
::
CalculateNumRowsPerShard
()
{
for
(
auto
it
=
filename_index_
->
begin
();
it
!=
filename_index_
->
end
();
++
it
)
{
int64_t
count
=
CountTotalRows
(
it
.
value
());
filename_numrows_
[
it
.
value
()]
=
count
;
all_num_rows_
+=
count
;
}
if
(
all_num_rows_
==
0
)
{
RETURN_STATUS_UNEXPECTED
(
"Number of rows can not be zero"
);
}
num_rows_per_shard_
=
static_cast
<
int64_t
>
(
std
::
ceil
(
all_num_rows_
*
1.0
/
num_devices_
));
MS_LOG
(
DEBUG
)
<<
"Number rows per shard is "
<<
num_rows_per_shard_
;
return
Status
::
OK
();
}
Status
TextFileOp
::
CountAllFileRows
(
const
std
::
vector
<
std
::
string
>
&
files
,
int64_t
*
count
)
{
std
::
shared_ptr
<
TextFileOp
>
op
;
*
count
=
0
;
RETURN_IF_NOT_OK
(
Builder
().
SetTextFilesList
(
files
).
Build
(
&
op
));
for
(
auto
file
:
files
)
{
*
count
+=
op
->
CountTotalRows
(
file
);
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
This diff is collapsed.
Click to expand it.
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
0 → 100644
浏览文件 @
2795e492
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
#include <memory>
#include <map>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "dataset/util/status.h"
#include "dataset/util/auto_index.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/jagged_connector.h"
namespace
mindspore
{
namespace
dataset
{
using
StringIndex
=
AutoIndexObj
<
std
::
string
>
;
class
TextFileOp
:
public
ParallelOp
{
public:
class
Builder
{
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder
();
// Default destructor
~
Builder
()
=
default
;
// Checks if the inputs of the builder is valid.
// @return Status - the error code returned.
Status
ValidateInputs
()
const
;
// Create the final object.
// @param op - dataset op.
// @return - the error code return.
Status
Build
(
std
::
shared_ptr
<
TextFileOp
>
*
op
);
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetNumWorkers
(
int32_t
num_workers
)
{
builder_num_workers_
=
num_workers
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetOpConnectorSize
(
int32_t
op_connector_size
)
{
builder_op_connector_size_
=
op_connector_size
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetRowsPerBuffer
(
int64_t
rows_per_buffer
)
{
builder_rows_per_buffer_
=
rows_per_buffer
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetNumDevices
(
int64_t
num_dev
)
{
builder_num_devices_
=
num_dev
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetDeviceId
(
int64_t
dev_id
)
{
builder_device_id_
=
dev_id
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetTextFilesList
(
const
std
::
vector
<
std
::
string
>
&
files_list
)
{
builder_text_files_list_
=
files_list
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetShuffleFiles
(
bool
shuffle_files
)
{
builder_shuffle_files_
=
shuffle_files
;
return
*
this
;
}
// Setter method.
// @return Builder - setter method returns reference to the builder.
Builder
&
SetNumSamples
(
int64_t
num_samples
)
{
builder_num_samples_
=
num_samples
;
return
*
this
;
}
private:
int32_t
builder_device_id_
;
int32_t
builder_num_devices_
;
int32_t
builder_num_workers_
;
int32_t
builder_op_connector_size_
;
int64_t
builder_rows_per_buffer_
;
int64_t
builder_num_samples_
;
int32_t
builder_worker_connector_size_
;
std
::
vector
<
std
::
string
>
builder_text_files_list_
;
bool
builder_shuffle_files_
;
std
::
unique_ptr
<
DataSchema
>
builder_schema_
;
};
// Constructor of TextFileOp
// @note The builder class should be used to call this constructor.
// @param num_workers - number of worker threads reading data from tf_file files.
// @param rows_per_buffer - number of rows that a full buffer will contain.
// @param total_num_rows - number of rows to read
// @param dataset_files_list - list of filepaths for the dataset files.
// @param data_schema - the data schema object.
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
TextFileOp
(
int32_t
num_workers
,
int64_t
rows_per_buffer
,
int64_t
num_samples
,
int32_t
worker_connector_size
,
std
::
unique_ptr
<
DataSchema
>
,
std
::
vector
<
std
::
string
>
text_files_list
,
int32_t
op_connector_size
,
bool
shuffle_files
,
int32_t
num_devices
,
int32_t
device_id
);
// Default destructor
~
TextFileOp
()
=
default
;
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status
Init
();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status
operator
()()
override
;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status
Reset
()
override
;
// Get total rows in files.
// @param files - all text files.
// @param count - number of rows.
// @return Status - the error coed returned.
static
Status
CountAllFileRows
(
const
std
::
vector
<
std
::
string
>
&
files
,
int64_t
*
count
);
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status
WorkerEntry
(
int32_t
worker_id
)
override
;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
// @param row - the id of the row filled in the tensor table.
// @return Status - the error code returned.
Status
LoadTensor
(
const
std
::
string
&
line
,
std
::
unique_ptr
<
TensorQTable
>
*
tensor_table
,
int64_t
row
);
// Reads a text file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status
LoadFile
(
const
std
::
string
&
file
,
const
int64_t
start_offset
,
const
int64_t
end_offset
,
const
int32_t
worker_id
);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status
CalculateNumRowsPerShard
();
// Count number of rows in each file.
// @param filename - text file name.
// @return int64_t - the total number of rows in file.
int64_t
CountTotalRows
(
const
std
::
string
&
file
);
// Notifies the thread which called FillIoBlockQueue to resume execution
void
NotifyToFillIOBlockQueue
();
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status
WaitToFillIOBlockQueue
();
// Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned.
Status
FillIOBlockQueue
(
const
std
::
vector
<
int64_t
>
&
i_keys
);
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool
NeedPushFileToBlockQueue
(
const
std
::
string
&
file_name
,
int64_t
*
start_offset
,
int64_t
*
end_offset
,
const
int64_t
&
pre_count
);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status
PopIoBlockQueue
(
int32_t
index
,
std
::
unique_ptr
<
FilenameBlock
>
*
out_block
);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status
PushIoBlockQueue
(
int32_t
index
,
std
::
unique_ptr
<
FilenameBlock
>
&&
io_block
);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status
PostEndOfData
();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status
PostEndOfEpoch
(
int32_t
queue_index
);
int32_t
device_id_
;
int32_t
num_devices_
;
int64_t
rows_per_buffer_
;
int64_t
num_samples_
;
std
::
vector
<
std
::
string
>
text_files_list_
;
bool
shuffle_files_
;
std
::
unique_ptr
<
DataSchema
>
data_schema_
;
int64_t
all_num_rows_
;
int64_t
num_rows_per_shard_
;
std
::
map
<
std
::
string
,
int64_t
>
filename_numrows_
;
std
::
unique_ptr
<
StringIndex
>
filename_index_
;
QueueList
<
std
::
unique_ptr
<
FilenameBlock
>>
io_block_queues_
;
WaitPost
io_block_queue_wait_post_
;
bool
finished_reading_dataset_
;
bool
load_io_block_queue_
;
bool
load_jagged_connector_
;
std
::
unordered_map
<
std
::
string
,
int32_t
>
col_name_map_
;
std
::
unique_ptr
<
JaggedConnector
>
jagged_buffer_connector_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_
This diff is collapsed.
Click to expand it.
mindspore/dataset/__init__.py
浏览文件 @
2795e492
...
@@ -20,8 +20,8 @@ can also create samplers with this module to sample data.
...
@@ -20,8 +20,8 @@ can also create samplers with this module to sample data.
from
.core.configuration
import
config
from
.core.configuration
import
config
from
.engine.datasets
import
StorageDataset
,
TFRecordDataset
,
ImageFolderDatasetV2
,
MnistDataset
,
MindDataset
,
\
from
.engine.datasets
import
StorageDataset
,
TFRecordDataset
,
ImageFolderDatasetV2
,
MnistDataset
,
MindDataset
,
\
GeneratorDataset
,
ManifestDataset
,
Cifar10Dataset
,
Cifar100Dataset
,
VOCDataset
,
CelebADataset
,
Schema
,
\
GeneratorDataset
,
ManifestDataset
,
Cifar10Dataset
,
Cifar100Dataset
,
VOCDataset
,
CelebADataset
,
TextFileDataset
,
\
Shuffle
,
zip
S
chema
,
S
huffle
,
zip
from
.engine.samplers
import
DistributedSampler
,
PKSampler
,
RandomSampler
,
SequentialSampler
,
SubsetRandomSampler
,
\
from
.engine.samplers
import
DistributedSampler
,
PKSampler
,
RandomSampler
,
SequentialSampler
,
SubsetRandomSampler
,
\
WeightedRandomSampler
WeightedRandomSampler
from
.engine.serializer_deserializer
import
serialize
,
deserialize
,
show
from
.engine.serializer_deserializer
import
serialize
,
deserialize
,
show
...
@@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show
...
@@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show
__all__
=
[
"config"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"StorageDataset"
,
__all__
=
[
"config"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"StorageDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"VOCDataset"
,
"Schema"
,
"DistributedSampler"
,
"PKSampler"
,
"RandomSampler"
,
"VOCDataset"
,
"
TextFileDataset"
,
"
Schema"
,
"DistributedSampler"
,
"PKSampler"
,
"RandomSampler"
,
"SequentialSampler"
,
"SubsetRandomSampler"
,
"WeightedRandomSampler"
,
"zip"
]
"SequentialSampler"
,
"SubsetRandomSampler"
,
"WeightedRandomSampler"
,
"zip"
]
This diff is collapsed.
Click to expand it.
mindspore/dataset/engine/__init__.py
浏览文件 @
2795e492
...
@@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
...
@@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"VOCDataset"
,
"
Schema"
,
"DistributedSampler"
,
"PKSampler"
,
"Random
Sampler"
,
"VOCDataset"
,
"
TextFileDataset"
,
"Schema"
,
"DistributedSampler"
,
"PK
Sampler"
,
"SequentialSampler"
,
"SubsetRandomSampler"
,
"WeightedRandomSampler"
]
"
RandomSampler"
,
"
SequentialSampler"
,
"SubsetRandomSampler"
,
"WeightedRandomSampler"
]
This diff is collapsed.
Click to expand it.
mindspore/dataset/engine/datasets.py
浏览文件 @
2795e492
...
@@ -29,7 +29,7 @@ from importlib import import_module
...
@@ -29,7 +29,7 @@ from importlib import import_module
import
numpy
as
np
import
numpy
as
np
from
mindspore._c_dataengine
import
DataType
,
TFReaderOp
,
ImageFolderOp
,
CifarOp
,
MnistOp
,
ManifestOp
,
\
from
mindspore._c_dataengine
import
DataType
,
TFReaderOp
,
ImageFolderOp
,
CifarOp
,
MnistOp
,
ManifestOp
,
\
MindRecordOp
,
CBatchInfo
MindRecordOp
,
TextFileOp
,
CBatchInfo
from
mindspore._c_expression
import
typing
from
mindspore._c_expression
import
typing
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
...
@@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator
...
@@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_repeat
,
check_skip
,
check_zip
,
check_rename
,
\
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_repeat
,
check_skip
,
check_zip
,
check_rename
,
\
check_take
,
check_project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_take
,
check_project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_tfrecorddataset
,
check_vocdataset
,
check_celebadataset
,
check_minddataset
,
check_generatordataset
,
\
check_tfrecorddataset
,
check_vocdataset
,
check_celebadataset
,
check_minddataset
,
check_generatordataset
,
\
check_zip_dataset
,
check_add_column
check_zip_dataset
,
check_add_column
,
check_textfiledataset
from
..core.datatypes
import
mstype_to_detype
,
mstypelist_to_detypelist
from
..core.datatypes
import
mstype_to_detype
,
mstypelist_to_detypelist
try
:
try
:
...
@@ -888,6 +888,29 @@ class SourceDataset(Dataset):
...
@@ -888,6 +888,29 @@ class SourceDataset(Dataset):
# No need for __init__ since it is the same as the super's init
# No need for __init__ since it is the same as the super's init
@
staticmethod
def
_find_files
(
patterns
):
"""
Utility function to search for files with the given glob patterns.
Args:
patterns (str or list[str]): string or list of patterns to be searched.
Returns:
List, files.
"""
def
flat
(
lists
):
return
list
(
np
.
array
(
lists
).
flatten
())
if
not
isinstance
(
patterns
,
list
):
patterns
=
[
patterns
]
file_list
=
flat
([
glob
.
glob
(
file
,
recursive
=
True
)
for
file
in
patterns
])
if
file_list
:
# not empty
return
file_list
raise
ValueError
(
"The list of path names matching the patterns is empty."
)
class
DatasetOp
(
Dataset
):
class
DatasetOp
(
Dataset
):
"""
"""
...
@@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset):
...
@@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset):
>>> # 3) get all rows from dataset_files with schema file "./schema.json":
>>> # 3) get all rows from dataset_files with schema file "./schema.json":
>>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
>>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json")
"""
"""
@
staticmethod
def
_find_files
(
patterns
):
"""
Utility function to search for files with the given glob patterns.
Args:
patterns (str or list[str]): string or list of patterns to be searched.
Returns:
List, files.
"""
def
flat
(
lists
):
return
list
(
np
.
array
(
lists
).
flatten
())
if
not
isinstance
(
patterns
,
list
):
patterns
=
[
patterns
]
file_list
=
flat
([
glob
.
glob
(
file
,
recursive
=
True
)
for
file
in
patterns
])
if
file_list
:
# not empty
return
file_list
raise
ValueError
(
"The list of path names matching the patterns is empty."
)
@
check_tfrecorddataset
@
check_tfrecorddataset
def
__init__
(
self
,
dataset_files
,
schema
=
None
,
columns_list
=
None
,
num_samples
=
None
,
num_parallel_workers
=
None
,
def
__init__
(
self
,
dataset_files
,
schema
=
None
,
columns_list
=
None
,
num_samples
=
None
,
num_parallel_workers
=
None
,
shuffle
=
Shuffle
.
GLOBAL
,
num_shards
=
None
,
shard_id
=
None
,
shard_equal_rows
=
False
):
shuffle
=
Shuffle
.
GLOBAL
,
num_shards
=
None
,
shard_id
=
None
,
shard_equal_rows
=
False
):
...
@@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset):
...
@@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset):
args
[
"num_shards"
]
=
self
.
num_shards
args
[
"num_shards"
]
=
self
.
num_shards
args
[
"shard_id"
]
=
self
.
shard_id
args
[
"shard_id"
]
=
self
.
shard_id
return
args
return
args
class
TextFileDataset
(
SourceDataset
):
"""
A source dataset that reads and parses datasets stored on disk in text format.
The generated dataset has one columns ['text'].
Args:
dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of
files. The list will be sorted in a lexicographical order.
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config).
shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
If shuffle is False, no shuffling will be performed;
If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL
Otherwise, there are two levels of shuffling:
- Shuffle.GLOBAL: Shuffle both the files and samples.
- Shuffle.FILES: Shuffle files only.
num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Examples:
>>> import mindspore.dataset as ds
>>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
>>> dataset = ds.TextFileDataset(dataset_files=dataset_files)
"""
@
check_textfiledataset
def
__init__
(
self
,
dataset_files
,
num_samples
=
None
,
num_parallel_workers
=
None
,
shuffle
=
Shuffle
.
GLOBAL
,
num_shards
=
None
,
shard_id
=
None
):
super
().
__init__
(
num_parallel_workers
)
self
.
dataset_files
=
self
.
_find_files
(
dataset_files
)
self
.
dataset_files
.
sort
()
self
.
num_samples
=
num_samples
if
not
isinstance
(
shuffle
,
(
bool
,
Shuffle
)):
raise
TypeError
(
"shuffle should be of boolean or enum 'Shuffle'."
)
if
not
isinstance
(
shuffle
,
Shuffle
):
if
shuffle
:
self
.
shuffle_level
=
Shuffle
.
GLOBAL
self
.
shuffle_files
=
True
else
:
self
.
shuffle_level
=
None
self
.
shuffle_files
=
False
else
:
self
.
shuffle_level
=
shuffle
self
.
shuffle_files
=
True
self
.
num_shards
=
num_shards
self
.
shard_id
=
shard_id
def
get_args
(
self
):
args
=
super
().
get_args
()
args
[
"dataset_files"
]
=
self
.
dataset_files
args
[
"num_samples"
]
=
self
.
num_samples
if
self
.
shuffle_files
is
not
None
:
args
[
"shuffle_files"
]
=
self
.
shuffle_files
args
[
"shuffle"
]
=
self
.
shuffle_level
args
[
"num_shards"
]
=
self
.
num_shards
args
[
"shard_id"
]
=
self
.
shard_id
return
args
def
get_dataset_size
(
self
):
"""
Get the number of batches in an epoch.
Return:
Number, number of batches.
"""
if
self
.
_dataset_size
is
None
:
num_rows
=
TextFileOp
.
get_num_rows
(
self
.
dataset_files
)
num_rows
=
get_num_rows
(
num_rows
,
self
.
num_shards
)
if
self
.
num_samples
is
None
:
return
num_rows
return
min
(
self
.
num_samples
,
num_rows
)
return
self
.
_dataset_size
This diff is collapsed.
Click to expand it.
mindspore/dataset/engine/iterators.py
浏览文件 @
2795e492
...
@@ -48,12 +48,16 @@ def alter_tree(node):
...
@@ -48,12 +48,16 @@ def alter_tree(node):
def
_alter_node
(
node
):
def
_alter_node
(
node
):
"""Performing some alteration to a dataset node. A common alteration is to insert a node."""
"""Performing some alteration to a dataset node. A common alteration is to insert a node."""
if
isinstance
(
node
,
de
.
TFRecordDataset
)
and
node
.
shuffle_level
==
de
.
Shuffle
.
GLOBAL
:
if
isinstance
(
node
,
(
de
.
TFRecordDataset
,
de
.
TextFileDataset
)
)
and
node
.
shuffle_level
==
de
.
Shuffle
.
GLOBAL
:
# Remove the connection between the parent's node to the current node because we are inserting a node.
# Remove the connection between the parent's node to the current node because we are inserting a node.
if
node
.
output
:
if
node
.
output
:
node
.
output
.
pop
()
node
.
output
.
pop
()
# Perform a fast scan for average rows per file
# Perform a fast scan for average rows per file
avg_rows_per_file
=
node
.
get_dataset_size
(
True
)
//
len
(
node
.
dataset_files
)
if
isinstance
(
node
,
de
.
TFRecordDataset
):
avg_rows_per_file
=
node
.
get_dataset_size
(
True
)
//
len
(
node
.
dataset_files
)
else
:
avg_rows_per_file
=
node
.
get_dataset_size
()
//
len
(
node
.
dataset_files
)
# Shuffle between 4 files with a minimum size of 10000 rows
# Shuffle between 4 files with a minimum size of 10000 rows
new_shuffle
=
node
.
shuffle
(
max
(
avg_rows_per_file
*
4
,
10000
))
new_shuffle
=
node
.
shuffle
(
max
(
avg_rows_per_file
*
4
,
10000
))
return
new_shuffle
return
new_shuffle
...
@@ -157,6 +161,8 @@ class Iterator:
...
@@ -157,6 +161,8 @@ class Iterator:
op_type
=
OpName
.
CIFAR100
op_type
=
OpName
.
CIFAR100
elif
isinstance
(
dataset
,
de
.
CelebADataset
):
elif
isinstance
(
dataset
,
de
.
CelebADataset
):
op_type
=
OpName
.
CELEBA
op_type
=
OpName
.
CELEBA
elif
isinstance
(
dataset
,
de
.
TextFileDataset
):
op_type
=
OpName
.
TEXTFILE
else
:
else
:
raise
ValueError
(
"Unsupported DatasetOp"
)
raise
ValueError
(
"Unsupported DatasetOp"
)
...
...
This diff is collapsed.
Click to expand it.
mindspore/dataset/engine/validators.py
浏览文件 @
2795e492
...
@@ -849,3 +849,25 @@ def check_add_column(method):
...
@@ -849,3 +849,25 @@ def check_add_column(method):
return
method
(
*
args
,
**
kwargs
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
return
new_method
def
check_textfiledataset
(
method
):
"""A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset)."""
@
wraps
(
method
)
def
new_method
(
*
args
,
**
kwargs
):
param_dict
=
make_param_dict
(
method
,
args
,
kwargs
)
nreq_param_int
=
[
'num_samples'
,
'num_parallel_workers'
,
'num_shards'
,
'shard_id'
]
# check dataset_files; required argument
dataset_files
=
param_dict
.
get
(
'dataset_files'
)
if
dataset_files
is
None
:
raise
ValueError
(
"dataset_files is not provided."
)
if
not
isinstance
(
dataset_files
,
(
str
,
list
)):
raise
TypeError
(
"dataset_files should be of type str or a list of strings."
)
check_param_type
(
nreq_param_int
,
param_dict
,
int
)
return
method
(
*
args
,
**
kwargs
)
return
new_method
This diff is collapsed.
Click to expand it.
mindspore/dataset/transforms/nlp/__init__.py
0 → 100644
浏览文件 @
2795e492
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is to support nlp augmentations. It includes two parts:
c_transforms and py_transforms. C_transforms is a high performance
image augmentation module which is developed with c++ opencv. Py_transforms
provide more kinds of image augmentations which is developed with python PIL.
"""
from
.utils
import
as_text
This diff is collapsed.
Click to expand it.
mindspore/dataset/transforms/nlp/utils.py
0 → 100644
浏览文件 @
2795e492
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some basic function for nlp
"""
import
numpy
as
np
def
as_text
(
array
,
encoding
=
'utf8'
):
"""
Convert data of array to unicode.
Args:
array (numpy array): Data of array should be ASCII values of each character after converted.
encoding (string): Indicating the charset for decoding.
Returns:
A 'str' object.
"""
if
not
isinstance
(
array
,
np
.
ndarray
):
raise
ValueError
(
'input should be a numpy array'
)
byte_array
=
bytearray
(
list
(
array
))
return
byte_array
.
decode
(
encoding
)
This diff is collapsed.
Click to expand it.
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
2795e492
...
@@ -65,7 +65,7 @@ SET(DE_UT_SRCS
...
@@ -65,7 +65,7 @@ SET(DE_UT_SRCS
cifar_op_test.cc
cifar_op_test.cc
celeba_op_test.cc
celeba_op_test.cc
take_op_test.cc
take_op_test.cc
)
text_file_op_test.cc
)
add_executable
(
de_ut_tests
${
DE_UT_SRCS
}
)
add_executable
(
de_ut_tests
${
DE_UT_SRCS
}
)
...
...
This diff is collapsed.
Click to expand it.
tests/ut/cpp/dataset/text_file_op_test.cc
0 → 100644
浏览文件 @
2795e492
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/util/status.h"
namespace
common
=
mindspore
::
common
;
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
class
MindDataTestTextFileOp
:
public
UT
::
DatasetOpTesting
{
};
TEST_F
(
MindDataTestTextFileOp
,
TestTextFileBasic
)
{
// Start with an empty execution tree
auto
tree
=
std
::
make_shared
<
ExecutionTree
>
();
std
::
string
dataset_path
;
dataset_path
=
datasets_root_path_
+
"/testTextFileDataset/1.txt"
;
std
::
shared_ptr
<
TextFileOp
>
op
;
TextFileOp
::
Builder
builder
;
builder
.
SetTextFilesList
({
dataset_path
})
.
SetRowsPerBuffer
(
16
)
.
SetNumWorkers
(
16
)
.
SetOpConnectorSize
(
2
);
Status
rc
=
builder
.
Build
(
&
op
);
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
tree
->
AssociateNode
(
op
);
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
tree
->
AssignRoot
(
op
);
ASSERT_TRUE
(
rc
.
IsOk
());
MS_LOG
(
INFO
)
<<
"Launching tree and begin iteration."
;
rc
=
tree
->
Prepare
();
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
tree
->
Launch
();
ASSERT_TRUE
(
rc
.
IsOk
());
// Start the loop of reading tensors from our pipeline
DatasetIterator
di
(
tree
);
TensorRow
tensor_list
;
rc
=
di
.
FetchNextTensorRow
(
&
tensor_list
);
ASSERT_TRUE
(
rc
.
IsOk
());
int
row_count
=
0
;
while
(
!
tensor_list
.
empty
())
{
// Display the tensor by calling the printer on it
for
(
int
i
=
0
;
i
<
tensor_list
.
size
();
i
++
)
{
std
::
ostringstream
ss
;
ss
<<
"("
<<
tensor_list
[
i
]
<<
"): "
<<
*
tensor_list
[
i
]
<<
std
::
endl
;
MS_LOG
(
INFO
)
<<
"Tensor print: "
<<
ss
.
str
()
<<
"."
;
}
rc
=
di
.
FetchNextTensorRow
(
&
tensor_list
);
ASSERT_TRUE
(
rc
.
IsOk
());
row_count
++
;
}
ASSERT_EQ
(
row_count
,
3
);
}
TEST_F
(
MindDataTestTextFileOp
,
TestTotalRows
)
{
std
::
string
tf_file1
=
datasets_root_path_
+
"/testTextFileDataset/1.txt"
;
std
::
string
tf_file2
=
datasets_root_path_
+
"/testTextFileDataset/2.txt"
;
std
::
vector
<
std
::
string
>
files
;
files
.
push_back
(
tf_file1
);
int64_t
total_rows
=
0
;
TextFileOp
::
CountAllFileRows
(
files
,
&
total_rows
);
ASSERT_EQ
(
total_rows
,
3
);
files
.
clear
();
files
.
push_back
(
tf_file2
);
TextFileOp
::
CountAllFileRows
(
files
,
&
total_rows
);
ASSERT_EQ
(
total_rows
,
2
);
files
.
clear
();
files
.
push_back
(
tf_file1
);
files
.
push_back
(
tf_file2
);
TextFileOp
::
CountAllFileRows
(
files
,
&
total_rows
);
ASSERT_EQ
(
total_rows
,
5
);
files
.
clear
();
}
This diff is collapsed.
Click to expand it.
tests/ut/data/dataset/testTextFileDataset/1.txt
0 → 100644
浏览文件 @
2795e492
This is a text file.
Be happy every day.
Good luck to everyone.
This diff is collapsed.
Click to expand it.
tests/ut/data/dataset/testTextFileDataset/2.txt
0 → 100644
浏览文件 @
2795e492
Another file.
End of file.
This diff is collapsed.
Click to expand it.
tests/ut/python/dataset/test_datasets_textfileop.py
0 → 100644
浏览文件 @
2795e492
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
import
mindspore.dataset.transforms.nlp.utils
as
nlp
DATA_FILE
=
"../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE
=
"../data/dataset/testTextFileDataset/*"
def
test_textline_dataset_one_file
():
data
=
ds
.
TextFileDataset
(
DATA_FILE
)
count
=
0
for
i
in
data
.
create_dict_iterator
():
logger
.
info
(
"{}"
.
format
(
i
[
"text"
]))
count
+=
1
assert
(
count
==
3
)
def
test_textline_dataset_all_file
():
data
=
ds
.
TextFileDataset
(
DATA_ALL_FILE
)
count
=
0
for
i
in
data
.
create_dict_iterator
():
logger
.
info
(
"{}"
.
format
(
i
[
"text"
]))
count
+=
1
assert
(
count
==
5
)
def
test_textline_dataset_totext
():
data
=
ds
.
TextFileDataset
(
DATA_ALL_FILE
,
shuffle
=
False
)
count
=
0
line
=
[
"This is a text file."
,
"Another file."
,
"Be happy every day."
,
"End of file."
,
"Good luck to everyone."
]
for
i
in
data
.
create_dict_iterator
():
str
=
nlp
.
as_text
(
i
[
"text"
])
assert
(
str
==
line
[
count
])
count
+=
1
assert
(
count
==
5
)
def
test_textline_dataset_num_samples
():
data
=
ds
.
TextFileDataset
(
DATA_FILE
,
num_samples
=
2
)
count
=
0
for
i
in
data
.
create_dict_iterator
():
count
+=
1
assert
(
count
==
2
)
def
test_textline_dataset_distribution
():
data
=
ds
.
TextFileDataset
(
DATA_ALL_FILE
,
num_shards
=
2
,
shard_id
=
1
)
count
=
0
for
i
in
data
.
create_dict_iterator
():
count
+=
1
assert
(
count
==
3
)
def
test_textline_dataset_repeat
():
data
=
ds
.
TextFileDataset
(
DATA_FILE
,
shuffle
=
False
)
data
=
data
.
repeat
(
3
)
count
=
0
line
=
[
"This is a text file."
,
"Be happy every day."
,
"Good luck to everyone."
,
"This is a text file."
,
"Be happy every day."
,
"Good luck to everyone."
,
"This is a text file."
,
"Be happy every day."
,
"Good luck to everyone."
]
for
i
in
data
.
create_dict_iterator
():
str
=
nlp
.
as_text
(
i
[
"text"
])
assert
(
str
==
line
[
count
])
count
+=
1
assert
(
count
==
9
)
def
test_textline_dataset_get_datasetsize
():
data
=
ds
.
TextFileDataset
(
DATA_FILE
)
size
=
data
.
get_dataset_size
()
assert
(
size
==
3
)
if
__name__
==
"__main__"
:
test_textline_dataset_one_file
()
test_textline_dataset_all_file
()
test_textline_dataset_totext
()
test_textline_dataset_num_samples
()
test_textline_dataset_distribution
()
test_textline_dataset_repeat
()
test_textline_dataset_get_datasetsize
()
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部