Commit a8cf83ac
Authored on June 15, 2020 by mindspore-ci-bot
Committed by Gitee on June 15, 2020

!1932 Add CLUE dataset
Merge pull request !1932 from jiangzhiwen/dataset/clue
Parents: ad035c4c, e0e167a0

Showing 33 changed files with 1676 additions and 12 deletions (+1676, -12)
mindspore/ccsrc/dataset/api/de_pipeline.cc                        +45   -1
mindspore/ccsrc/dataset/api/de_pipeline.h                         +4    -1
mindspore/ccsrc/dataset/api/python_bindings.cc                    +15   -1
mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt   +1    -0
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc       +551  -0
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h        +270  -0
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc  +1    -1
mindspore/dataset/__init__.py                                     +4    -4
mindspore/dataset/engine/__init__.py                              +1    -1
mindspore/dataset/engine/datasets.py                              +218  -2
mindspore/dataset/engine/iterators.py                             +4    -1
mindspore/dataset/engine/validators.py                            +35   -0
tests/ut/cpp/dataset/CMakeLists.txt                               +1    -0
tests/ut/cpp/dataset/clue_op_test.cc                              +117  -0
tests/ut/data/dataset/testCLUE/afqmc/dev.json                     +3    -0
tests/ut/data/dataset/testCLUE/afqmc/test.json                    +3    -0
tests/ut/data/dataset/testCLUE/afqmc/train.json                   +3    -0
tests/ut/data/dataset/testCLUE/cmnli/dev.json                     +3    -0
tests/ut/data/dataset/testCLUE/cmnli/test.json                    +3    -0
tests/ut/data/dataset/testCLUE/cmnli/train.json                   +3    -0
tests/ut/data/dataset/testCLUE/csl/dev.json                       +3    -0
tests/ut/data/dataset/testCLUE/csl/test.json                      +3    -0
tests/ut/data/dataset/testCLUE/csl/train.json                     +3    -0
tests/ut/data/dataset/testCLUE/iflytek/dev.json                   +3    -0
tests/ut/data/dataset/testCLUE/iflytek/test.json                  +3    -0
tests/ut/data/dataset/testCLUE/iflytek/train.json                 +3    -0
tests/ut/data/dataset/testCLUE/tnews/dev.json                     +3    -0
tests/ut/data/dataset/testCLUE/tnews/test.json                    +3    -0
tests/ut/data/dataset/testCLUE/tnews/train.json                   +3    -0
tests/ut/data/dataset/testCLUE/wsc/dev.json                       +3    -0
tests/ut/data/dataset/testCLUE/wsc/test.json                      +3    -0
tests/ut/data/dataset/testCLUE/wsc/train.json                     +3    -0
tests/ut/python/dataset/test_datasets_clue.py                     +355  -0
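
Taken together, the change adds a C++ ClueOp reader and surfaces it in Python as CLUEDataset. A minimal usage sketch, based on the docstring added in datasets.py below (the file path is a placeholder, and the dict-iterator call assumes the usual MindSpore dataset iteration API):

    import mindspore.dataset as ds

    # AFQMC 'train' usage yields columns sentence1, sentence2 and label (see task_dict below).
    dataset = ds.CLUEDataset(dataset_files=["/path/to/afqmc/train.json"], task='AFQMC', usage='train')
    for row in dataset.create_dict_iterator():
        print(row["sentence1"], row["sentence2"], row["label"])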
mindspore/ccsrc/dataset/api/de_pipeline.cc

@@ -31,6 +31,7 @@
 #include "dataset/engine/datasetops/source/celeba_op.h"
 #include "dataset/engine/datasetops/source/random_data_op.h"
 #include "dataset/engine/datasetops/source/text_file_op.h"
+#include "dataset/engine/datasetops/source/clue_op.h"
 #include "dataset/engine/datasetops/filter_op.h"
 #include "mindrecord/include/shard_category.h"
 #include "mindrecord/include/shard_distributed_sample.h"
@@ -72,7 +73,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
   {kCelebA, &DEPipeline::ParseCelebAOp},
   {kRandomData, &DEPipeline::ParseRandomDataOp},
   {kTextFile, &DEPipeline::ParseTextFileOp},
-  {kBuildVocab, &DEPipeline::ParseBuildVocabOp}};
+  {kBuildVocab, &DEPipeline::ParseBuildVocabOp},
+  {kClue, &DEPipeline::ParseClueOp}};

 DEPipeline::DEPipeline() : iterator_(nullptr) {
   try {
@@ -1210,6 +1212,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
   *ptr = op;
   return Status::OK();
 }

 Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
   for (auto p : py::reinterpret_borrow<py::dict>(value)) {
     if (!p.second.is_none()) {
@@ -1236,6 +1239,7 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) {
   }
   return Status::OK();
 }

 Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
   std::shared_ptr<BuildVocabOp::Builder> builder = std::make_shared<BuildVocabOp::Builder>();
   for (auto arg : args) {
@@ -1267,5 +1271,45 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr<Datas
   return Status::OK();
 }

+Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
+  std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
+  if (!args["dataset_files"].is_none()) {
+    (void)builder->SetClueFilesList(ToStringVector(args["dataset_files"]));
+  } else {
+    RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing");
+  }
+  // Optional arguments
+  for (auto arg : args) {
+    std::string key = py::str(arg.first);
+    py::handle value = arg.second;
+    if (!value.is_none()) {
+      if (key == "num_parallel_workers") {
+        (void)builder->SetNumWorkers(ToInt(value));
+      } else if (key == "shuffle_files") {
+        (void)builder->SetShuffleFiles(ToBool(value));
+      } else if (key == "num_samples") {
+        (void)builder->SetNumSamples(ToInt(value));
+      } else if (key == "num_shards") {
+        (void)builder->SetNumDevices(ToInt(value));
+      } else if (key == "shard_id") {
+        (void)builder->SetDeviceId(ToInt(value));
+      } else if (key == "cols_to_keyword") {
+        std::map<std::string, std::string> map_dict;
+        for (auto p : py::reinterpret_borrow<py::dict>(value)) {
+          if (!p.second.is_none()) {
+            map_dict.insert({ToString(p.first), ToString(p.second)});
+          } else {
+            map_dict.insert({ToString(p.first), ToString(p.first)});
+          }
+        }
+        (void)builder->SetColsKeyMap(map_dict);
+      }
+    }
+  }
+  std::shared_ptr<ClueOp> op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *ptr = op;
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
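
ParseClueOp above consumes the argument dictionary assembled by CLUEDataset.get_args() on the Python side. A sketch of its shape, with illustrative values (the exact values depend on the user's CLUEDataset call):

    # Roughly what DEPipeline::ParseClueOp receives from the Python layer.
    args = {
        "dataset_files": ["/path/to/afqmc/train.json"],  # required; missing -> RETURN_STATUS_UNEXPECTED
        "num_parallel_workers": 4,
        "shuffle_files": True,
        "num_samples": 0,
        "num_shards": 1,
        "shard_id": 0,
        # output column -> JSON keyword; nested keys use '/' (split in ClueOp::Builder::Build)
        "cols_to_keyword": {"sentence1": "sentence1", "sentence2": "sentence2", "label": "label"},
    }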
mindspore/ccsrc/dataset/api/de_pipeline.h

@@ -64,7 +64,8 @@ enum OpName {
   kCelebA,
   kRandomData,
   kTextFile,
-  kBuildVocab
+  kBuildVocab,
+  kClue
 };

 // The C++ binder class that we expose to the python script.
@@ -166,6 +167,8 @@ class DEPipeline {
   Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

+  Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
+
  private:
   // Execution tree that links the dataset operators.
   std::shared_ptr<ExecutionTree> tree_;
mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -55,6 +55,7 @@
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
 #include "dataset/engine/jagged_connector.h"
 #include "dataset/engine/datasetops/source/text_file_op.h"
+#include "dataset/engine/datasetops/source/clue_op.h"
 #include "dataset/engine/datasetops/source/voc_op.h"
 #include "dataset/engine/datasetops/source/coco_op.h"
 #include "dataset/engine/gnn/graph.h"
@@ -201,6 +202,18 @@ void bindDatasetOps(py::module *m) {
       THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count));
       return count;
     });

+  (void)py::class_<ClueOp, DatasetOp, std::shared_ptr<ClueOp>>(*m, "ClueOp")
+    .def_static("get_num_rows", [](const py::list &files) {
+      int64_t count = 0;
+      std::vector<std::string> filenames;
+      for (auto file : files) {
+        file.is_none() ? (void)filenames.emplace_back("") : filenames.push_back(py::str(file));
+      }
+      THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count));
+      return count;
+    });
+
   (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
     .def_static("get_num_rows",
                 [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
@@ -629,7 +642,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
     .value("RANDOMDATA", OpName::kRandomData)
     .value("BUILDVOCAB", OpName::kBuildVocab)
     .value("CELEBA", OpName::kCelebA)
-    .value("TEXTFILE", OpName::kTextFile);
+    .value("TEXTFILE", OpName::kTextFile)
+    .value("CLUE", OpName::kClue);

   (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
     .value("DE_JIEBA_MIX", JiebaMode::kMix)
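
The ClueOp.get_num_rows binding registered here is what CLUEDataset.get_dataset_size() calls (see datasets.py below). A sketch of direct use, with placeholder paths; it counts non-empty lines across all files, mirroring ClueOp::CountAllFileRows:

    from mindspore._c_dataengine import ClueOp

    # Total non-empty JSON lines across the given files.
    num_rows = ClueOp.get_num_rows(["/path/to/afqmc/train.json", "/path/to/afqmc/dev.json"])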
mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt

@@ -19,4 +19,5 @@ add_library(engine-datasetops-source OBJECT
     random_data_op.cc
     celeba_op.cc
     text_file_op.cc
+    clue_op.cc
     )
\ No newline at end of file
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
0 → 100644

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/source/clue_op.h"

#include <string>
#include <vector>
#include <fstream>
#include <iomanip>
#include <utility>

#include "dataset/core/config_manager.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/random.h"

namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  builder_num_workers_ = config_manager->num_parallel_workers();
  builder_op_connector_size_ = config_manager->op_connector_size();
  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
  builder_worker_connector_size_ = config_manager->worker_connector_size();
}

Status ClueOp::Builder::ValidateInputs() const {
  std::string err;
  err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
  err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
}

Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
  RETURN_IF_NOT_OK(ValidateInputs());

  // Throttle the number of workers if we have more workers than files!
  if (static_cast<size_t>(builder_num_workers_) > builder_clue_files_list_.size()) {
    builder_num_workers_ = builder_clue_files_list_.size();
    MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers.";
  }

  ColKeyMap ck_map;
  for (auto &p : builder_cols_to_keyword_) {
    ck_map.insert({p.first, split(p.second, '/')});
  }

  std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
    builder_device_id_);
  RETURN_IF_NOT_OK(clue_op->Init());
  *op = std::move(clue_op);

  return Status::OK();
}

std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;

  while (getline(ss, item, delim)) {
    res.push_back(item);
  }
  return res;
}

ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
               ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
               bool shuffle_files, int32_t num_device, int32_t device_id)
    : ParallelOp(num_workers, op_connector_size),
      rows_per_buffer_(rows_per_buffer),
      num_rows_per_shard_(0),
      all_num_rows_(0),
      num_samples_(num_samples),
      filename_index_(std::make_unique<StringIndex>()),
      clue_files_list_(std::move(clue_files_list)),
      load_jagged_connector_(true),
      cols_to_keyword_(cols_to_keyword),
      shuffle_files_(shuffle_files),
      finished_reading_dataset_(false),
      num_devices_(num_device),
      device_id_(device_id),
      load_io_block_queue_(true) {
  worker_connector_size_ = worker_connector_size;
}

Status ClueOp::Init() {
  RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));

  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(clue_files_list_.size() / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  // Set the column name mapping (base class field)
  int count = 0;
  for (auto &p : cols_to_keyword_) {
    column_name_id_map_[p.first] = count;
    count++;
  }

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
  jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  return Status::OK();
}

Status ClueOp::Reset() {
  load_jagged_connector_ = true;
  load_io_block_queue_ = true;

  RETURN_IF_NOT_OK(ParallelOp::Reset());
  NotifyToFillIOBlockQueue();

  return Status::OK();
}

Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
  TensorRow tRow(1, nullptr);
  (*tensor_table)->push_back(std::move(tRow));

  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar()));
  (**tensor_table)[row][0] = std::move(tensor);
  return Status::OK();
}

Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
  nlohmann::json cursor = js;
  for (int i = 0; i < key_chain.size(); i++) {
    if (cursor.find(key_chain[i]) != cursor.end()) {
      cursor = cursor[key_chain[i]];
    } else {
      RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]);
    }
  }
  std::string final_str = key_chain.back();
  switch (cursor.type()) {
    case nlohmann::detail::value_t::string:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::string>()}, TensorShape::CreateScalar()));
      break;
    case nlohmann::detail::value_t::number_integer:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(),
                                            DataType(DataType::DE_INT32)));
      (*t)->SetItemAt<int32_t>({0}, cursor.get<int32_t>());
      break;
    case nlohmann::detail::value_t::number_unsigned:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(),
                                            DataType(DataType::DE_INT32)));
      (*t)->SetItemAt<int32_t>({0}, cursor.get<uint32_t>());
      break;
    case nlohmann::detail::value_t::number_float:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(),
                                            DataType(DataType::DE_FLOAT32)));
      (*t)->SetItemAt<int32_t>({0}, cursor.get<float>());
      break;
    case nlohmann::detail::value_t::array:
      RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get<std::vector<std::string>>()}, TensorShape::CreateScalar()));
      break;
    default:
      break;
  }
  return Status::OK();
}

Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                        const int32_t worker_id) {
  std::ifstream handle(file);
  if (!handle.is_open()) {
    RETURN_STATUS_UNEXPECTED("Failed to open file " + file);
  }

  int64_t rows_each_buffer = 0;
  int64_t rows_total = 0;
  std::string line;
  std::unique_ptr<DataBuffer> cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
  std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();

  while (getline(handle, line)) {
    if (line.empty()) {
      continue;
    }
    // If read to the end offset of this file, break.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip line before start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }

    try {
      nlohmann::json js = nlohmann::json::parse(line);
      int cols_count = cols_to_keyword_.size();
      TensorRow tRow(cols_count, nullptr);
      tensor_table->push_back(std::move(tRow));
      int cout = 0;
      for (auto &p : cols_to_keyword_) {
        std::shared_ptr<Tensor> tensor;
        RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor));
        (*tensor_table)[rows_each_buffer][cout] = std::move(tensor);
        cout++;
      }
    } catch (const std::exception &err) {
      // Catch any exception and convert to Status return code
      RETURN_STATUS_UNEXPECTED("Failed to load json file");
    }
    // RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer));
    rows_each_buffer++;
    rows_total++;
    if (rows_each_buffer == rows_per_buffer_) {
      cur_buffer->set_tensor_table(std::move(tensor_table));
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));

      cur_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
      tensor_table = std::make_unique<TensorQTable>();
      rows_each_buffer = 0;
    }
  }

  if (rows_each_buffer > 0) {
    cur_buffer->set_tensor_table(std::move(tensor_table));
    RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer)));
  }

  return Status::OK();
}

Status ClueOp::operator()() {
  RETURN_IF_NOT_OK(CalculateNumRowsPerShard());

  // launch one thread, responsible for filling IoBlockQueue
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this)));

  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1)));

  // must be called after launching workers.
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
  NotifyToFillIOBlockQueue();

  while (!finished_reading_dataset_) {
    int64_t buffer_id = 0;
    int32_t workers_done = 0;
    int64_t rows_read = 0;
    load_io_block_queue_ = true;

    while (workers_done < num_workers_) {
      std::unique_ptr<DataBuffer> buffer;
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
      if (buffer->eoe()) {
        workers_done++;
      } else if (num_samples_ == 0 || rows_read < num_samples_) {
        if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
          int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
          RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
        }
        rows_read += buffer->NumRows();
        buffer->set_id(buffer_id++);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
      } else {
        // end of epoch
        load_jagged_connector_ = false;
        load_io_block_queue_ = false;
      }
    }

    std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));

    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
      finished_reading_dataset_ = true;
      NotifyToFillIOBlockQueue();
    } else {
      jagged_buffer_connector_->DoReset();
      buffer_id = 0;
    }
  }

  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  RETURN_IF_NOT_OK(PostEndOfData());

  return Status::OK();
}

Status ClueOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();

  std::unique_ptr<FilenameBlock> io_block;
  RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  while (!io_block->eof()) {
    if (!io_block->eoe()) {
      if (load_jagged_connector_) {
        std::string filename;
        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
        int64_t start_offset = io_block->GetStartOffset();
        int64_t end_offset = io_block->GetEndOffset();
        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
      }
    } else {
      std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
    }

    RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
  }

  return Status::OK();
}

// A print method typically used for debugging
void ClueOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as first line regardless if this summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <ClueOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
        << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
    for (int i = 0; i < clue_files_list_.size(); ++i) {
      out << " " << clue_files_list_[i];
    }
    out << "\n\n";
  }
}

// Pops an element from a queue in io_block_queues
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
  RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
  return Status::OK();
}

// Pushes an element to a queue in io_block_queues
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
  RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
  return Status::OK();
}

static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
  std::mt19937 rng(seed);
  std::shuffle(i_keys->begin(), i_keys->end(), rng);
}

Status ClueOp::WaitToFillIOBlockQueue() {
  // must be called first if called by worker spawned by taskgroup
  TaskManager::FindMe()->Post();

  std::vector<int64_t> i_keys;
  if (shuffle_files_) {
    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
      i_keys.push_back(it.key());
    }
  }
  uint32_t seed = 0;
  while (true) {
    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
    io_block_queue_wait_post_.Clear();

    if (finished_reading_dataset_) {
      break;
    }

    if (shuffle_files_) {
      ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
    }
    RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
  }
  return Status::OK();
}

Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        {
          if (!load_io_block_queue_) {
            break;
          }
        }
        auto file_it = filename_index_->Search(*it);
        file_index.emplace_back(std::pair<std::string, int64_t>(file_it.value(), *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        {
          if (!load_io_block_queue_) {
            break;
          }
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (auto file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto ioBlock =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
        queue_index = (queue_index + 1) % num_workers_;
      }

      pre_count += filename_numrows_[file_info.first];
    }

    if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
      finish = false;
    } else {
      finish = true;
    }
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }

bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                      const int64_t &pre_count) {
  *start_offset = 0;
  *end_offset = 0;
  bool push = false;
  int64_t start_index = device_id_ * num_rows_per_shard_;
  if (device_id_ + 1 < 0) {
    MS_LOG(ERROR) << "Device id is invalid";
    return false;
  }

  int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
    *start_offset = start_index - pre_count;
    push = true;
    if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  if (pre_count >= start_index && pre_count < end_index) {
    *start_offset = 0;
    push = true;
    if (pre_count + filename_numrows_[file_name] >= end_index) {
      *end_offset = end_index - pre_count;
    } else {
      *end_offset = filename_numrows_[file_name];
    }
  }

  return push;
}

// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
  for (int i = 0; i < num_workers_; ++i) {
    std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
    RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
  }

  return Status::OK();
}

Status ClueOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountTotalRows(it.value());
    filename_numrows_[it.value()] = count;
    all_num_rows_ += count;
  }
  if (all_num_rows_ == 0) {
    RETURN_STATUS_UNEXPECTED(
      "There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API "
      "validation first.");
  }

  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
  MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}

int64_t ClueOp::CountTotalRows(const std::string &file) {
  std::ifstream handle(file);
  if (!handle.is_open()) {
    MS_LOG(ERROR) << "Failed to open file: " << file;
    return 0;
  }

  std::string line;
  int64_t count = 0;
  while (getline(handle, line)) {
    if (!line.empty()) {
      count++;
    }
  }

  return count;
}

// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status ClueOp::PostEndOfData() {
  for (int i = 0; i < num_workers_; ++i) {
    std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
  }

  return Status::OK();
}

Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
  std::shared_ptr<ClueOp> op;
  *count = 0;
  RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op));
  for (auto file : files) {
    *count += op->CountTotalRows(file);
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
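
The sharding arithmetic above (CalculateNumRowsPerShard plus NeedPushFileToBlockQueue) gives each device the half-open row range [device_id * num_rows_per_shard, (device_id + 1) * num_rows_per_shard). A small Python sketch of the same computation, with illustrative row counts:

    import math

    all_num_rows = 10  # total non-empty lines across all files (illustrative)
    num_devices = 3
    num_rows_per_shard = math.ceil(all_num_rows * 1.0 / num_devices)  # 4, as in CalculateNumRowsPerShard

    for device_id in range(num_devices):
        start_index = device_id * num_rows_per_shard      # first row this shard reads
        end_index = (device_id + 1) * num_rows_per_shard  # one past the last row
        print(device_id, (start_index, end_index))        # (0, 4), (4, 8), (8, 12)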
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
0 → 100644

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_

#include <memory>
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

#include "dataset/util/auto_index.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/source/io_block.h"

namespace mindspore {
namespace dataset {
using StringIndex = AutoIndexObj<std::string>;
using ColKeyMap = std::map<std::string, std::vector<std::string>>;

class JaggedConnector;

class ClueOp : public ParallelOp {
 public:
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Checks if the inputs of the builder are valid.
    // @return Status - the error code returned.
    Status ValidateInputs() const;

    // Create the final object.
    // @param op - dataset op.
    // @return Status - the error code returned.
    Status Build(std::shared_ptr<ClueOp> *op);

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t op_connector_size) {
      builder_op_connector_size_ = op_connector_size;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
      builder_rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumDevices(int64_t num_dev) {
      builder_num_devices_ = num_dev;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetDeviceId(int64_t dev_id) {
      builder_device_id_ = dev_id;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetClueFilesList(const std::vector<std::string> &files_list) {
      builder_clue_files_list_ = files_list;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetShuffleFiles(bool shuffle_files) {
      builder_shuffle_files_ = shuffle_files;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetNumSamples(int64_t num_samples) {
      builder_num_samples_ = num_samples;
      return *this;
    }

    // Setter method.
    // @return Builder - setter method returns reference to the builder.
    Builder &SetColsKeyMap(const std::map<std::string, std::string> &cols_to_key) {
      builder_cols_to_keyword_ = cols_to_key;
      return *this;
    }

    // Split string based on a character delimiter
    // @return - the string vector
    std::vector<std::string> split(const std::string &s, char delim);

   private:
    int32_t builder_device_id_;
    int32_t builder_num_devices_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    int64_t builder_rows_per_buffer_;
    int64_t builder_num_samples_;
    int32_t builder_worker_connector_size_;
    std::vector<std::string> builder_clue_files_list_;
    bool builder_shuffle_files_;
    std::map<std::string, std::string> builder_cols_to_keyword_;
  };

  // Constructor of ClueOp
  ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
         ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
         bool shuffle_files, int32_t num_devices, int32_t device_id);

  // Default destructor
  ~ClueOp() = default;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // Instantiates the internal queues and connectors
  // @return Status - the error code returned
  Status Init();

  // Class functor operator () override.
  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - the error code returned.
  Status operator()() override;

  // Overrides base class reset method. Cleans up any state info from its previous execution and
  // reinitializes itself so that it can be executed again, as if it was just created.
  // @return Status - the error code returned.
  Status Reset() override;

  // Get total rows in files.
  // @param files - all clue files.
  // @param count - number of rows.
  // @return Status - the error code returned.
  static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);

 private:
  // The entry point for when workers are launched.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status WorkerEntry(int32_t worker_id) override;

  // Parses a single row and puts the data into a tensor table.
  // @param line - the content of the row.
  // @param tensor_table - the tensor table to put the parsed data in.
  // @param row - the id of the row filled in the tensor table.
  // @return Status - the error code returned.
  Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);

  // Reads a clue file and loads the data into multiple buffers.
  // @param file - the file to read.
  // @param start_offset - the start offset of file.
  // @param end_offset - the end offset of file.
  // @param worker_id - the id of the worker that is executing this function.
  // @return Status - the error code returned.
  Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
                  const int32_t worker_id);

  // Pops an element from a queue in IOBlockQueue.
  // @param index - the index of the queue to pop from.
  // @param out_block - the popped element.
  // @return Status - the error code returned.
  Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);

  // Pushes an element to a queue in IOBlockQueue.
  // @param index - the index of the queue to push to.
  // @param io_block - the element to push onto the queue.
  // @return Status - the error code returned.
  Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);

  // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
  // @return Status - the error code returned.
  Status WaitToFillIOBlockQueue();

  // Fill the IOBlockQueue.
  // @param i_keys - keys of files to fill into the IOBlockQueue.
  // @return Status - the error code returned.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);

  // Notifies the thread which called FillIoBlockQueue to resume execution
  void NotifyToFillIOBlockQueue();

  // Select file and push it to the block queue.
  // @param file_name - File name.
  // @param start_offset - Set if file contains the first sample of this shard.
  // @param end_offset - Set if file contains the last sample of this shard.
  // @param pre_count - Total rows of previous files.
  // @return bool - whether the file needs to be pushed.
  bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
                                const int64_t &pre_count);

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
  // pops this control indicator, it will wait until the next epoch starts and then resume execution.
  // @return Status - the error code returned.
  Status PostEndOfEpoch(int32_t queue_index);

  // Calculate number of rows in each shard.
  // @return Status - the error code returned.
  Status CalculateNumRowsPerShard();

  // Count number of rows in each file.
  // @param file - clue file name.
  // @return int64_t - the total number of rows in file.
  int64_t CountTotalRows(const std::string &file);

  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
  // When the worker pops this control indicator, it will shut itself down gracefully.
  // @return Status - the error code returned.
  Status PostEndOfData();

  // Gets the value referred to by a (possibly nested) key chain from a parsed json record.
  // @return Status - the error code returned.
  Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);

  int32_t device_id_;
  bool shuffle_files_;
  bool finished_reading_dataset_;
  int32_t num_devices_;
  int64_t rows_per_buffer_;
  bool load_io_block_queue_;
  int64_t num_rows_per_shard_;
  int64_t all_num_rows_;
  int64_t num_samples_;
  std::map<std::string, int64_t> filename_numrows_;
  std::unique_ptr<StringIndex> filename_index_;
  std::vector<std::string> clue_files_list_;
  WaitPost io_block_queue_wait_post_;
  std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
  bool load_jagged_connector_;
  ColKeyMap cols_to_keyword_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc

@@ -43,7 +43,7 @@ TextFileOp::Builder::Builder()
 Status TextFileOp::Builder::ValidateInputs() const {
   std::string err_msg;
-  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : "";
+  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
   err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : "";
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }
mindspore/dataset/__init__.py

@@ -21,7 +21,7 @@ can also create samplers with this module to sample data.
 from .core.configuration import config
 from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset, \
-    TextFileDataset, Schema, Shuffle, zip, RandomDataset
+    TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
     WeightedRandomSampler, Sampler
 from .engine.serializer_deserializer import serialize, deserialize, show
@@ -29,6 +29,6 @@ from .engine.graphdata import GraphData
 __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset",
-           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
-           "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
-           "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
+           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "NumpySlicesDataset", "VOCDataset",
+           "CocoDataset", "TextFileDataset", "CLUEDataset", "Schema", "DistributedSampler", "PKSampler",
+           "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
mindspore/dataset/engine/__init__.py

@@ -30,7 +30,7 @@ from ..core.configuration import config, ConfigurationManager
 __all__ = ["config", "ConfigurationManager", "zip", "ImageFolderDatasetV2", "MnistDataset",
-           "MindDataset", "GeneratorDataset", "TFRecordDataset",
+           "MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
            "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
            "VOCDataset", "CocoDataset", "TextFileDataset", "BuildVocabDataset", "Schema",
            "DistributedSampler", "PKSampler",
mindspore/dataset/engine/datasets.py

@@ -33,7 +33,7 @@ import copy
 import numpy as np

 from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
-    MindRecordOp, TextFileOp, VOCOp, CocoOp, CBatchInfo
+    MindRecordOp, TextFileOp, ClueOp, VOCOp, CocoOp, CBatchInfo
 from mindspore._c_expression import typing
 from mindspore import log as logger
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
     check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
-    check_split
+    check_split, check_cluedataset
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

 try:
@@ -4317,6 +4317,222 @@ class CelebADataset(MappableDataset):
         return self.sampler.is_sharded()


+class CLUEDataset(SourceDataset):
+    """
+    A source dataset that reads and parses CLUE datasets.
+
+    CLUE, the Chinese Language Understanding Evaluation Benchmark, is a collection of datasets, baselines,
+    pre-trained models, corpus and leaderboard. Here we bring in the classification tasks of CLUE, which are
+    AFQMC, TNEWS, IFLYTEK, CMNLI, WSC and CSL.
+
+    Args:
+        dataset_files (str or list[str]): String or list of files to be read, or glob strings to search for
+            a pattern of files. The list will be sorted in lexicographical order.
+        task (str, optional): The kind of task, one of 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' and 'CSL'
+            (default='AFQMC').
+        usage (str, optional): Whether to load the train, test or eval part of the dataset (default='train').
+        num_samples (int, optional): Number of samples (rows) to read (default=None, reads the full dataset).
+        num_parallel_workers (int, optional): Number of workers to read the data
+            (default=None, number set in the config).
+        shuffle (bool, Shuffle level, optional): Perform reshuffling of the data every epoch
+            (default=Shuffle.GLOBAL).
+            If shuffle is False, no shuffling will be performed;
+            If shuffle is True, the behavior is the same as setting shuffle to Shuffle.GLOBAL.
+            Otherwise, there are two levels of shuffling:
+
+            - Shuffle.GLOBAL: Shuffle both the files and samples.
+            - Shuffle.FILES: Shuffle files only.
+
+        num_shards (int, optional): Number of shards that the dataset should be divided into (default=None).
+        shard_id (int, optional): The shard ID within num_shards (default=None). This
+            argument should be specified only when num_shards is also specified.
+
+    Examples:
+        >>> import mindspore.dataset as ds
+        >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files
+        >>> dataset = ds.CLUEDataset(dataset_files=dataset_files, task='AFQMC', usage='train')
+    """
+
+    @check_cluedataset
+    def __init__(self, dataset_files, task='AFQMC', usage='train', num_samples=None,
+                 num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
+        super().__init__(num_parallel_workers)
+        self.dataset_files = self._find_files(dataset_files)
+        self.dataset_files.sort()
+        self.num_samples = num_samples
+        self.task_dict = {
+            'AFQMC': {
+                'train': {'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'},
+                'test': {'id': 'id', 'sentence1': 'sentence1', 'sentence2': 'sentence2'},
+                'eval': {'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'}
+            },
+            'CMNLI': {
+                'train': {'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'},
+                'test': {'id': 'id', 'sentence1': 'sentence1', 'sentence2': 'sentence2'},
+                'eval': {'sentence1': 'sentence1', 'sentence2': 'sentence2', 'label': 'label'}
+            },
+            'CSL': {
+                'train': {'id': 'id', 'abst': 'abst', 'keyword': 'keyword', 'label': 'label'},
+                'test': {'id': 'id', 'abst': 'abst', 'keyword': 'keyword'},
+                'eval': {'id': 'id', 'abst': 'abst', 'keyword': 'keyword', 'label': 'label'}
+            },
+            'IFLYTEK': {
+                'train': {'label': 'label', 'label_des': 'label_des', 'sentence': 'sentence'},
+                'test': {'id': 'id', 'sentence': 'sentence'},
+                'eval': {'label': 'label', 'label_des': 'label_des', 'sentence': 'sentence'}
+            },
+            'TNEWS': {
+                'train': {'label': 'label', 'label_desc': 'label_desc', 'sentence': 'sentence',
+                          'keywords': 'keywords'},
+                'test': {'id': 'id', 'sentence': 'sentence', 'keywords': 'keywords'},
+                'eval': {'label': 'label', 'label_desc': 'label_desc', 'sentence': 'sentence',
+                         'keywords': 'keywords'}
+            },
+            'WSC': {
+                'train': {'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index',
+                          'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text',
+                          'idx': 'idx', 'label': 'label', 'text': 'text'},
+                'test': {'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index',
+                         'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text',
+                         'idx': 'idx', 'text': 'text'},
+                'eval': {'span1_index': 'target/span1_index', 'span2_index': 'target/span2_index',
+                         'span1_text': 'target/span1_text', 'span2_text': 'target/span2_text',
+                         'idx': 'idx', 'label': 'label', 'text': 'text'}
+            }
+        }
+        self.cols_to_keyword = self.task_dict[task][usage]
+
+        if not isinstance(shuffle, (bool, Shuffle)):
+            raise TypeError("shuffle should be of boolean or enum 'Shuffle'.")
+        if not isinstance(shuffle, Shuffle):
+            if shuffle:
+                self.shuffle_level = Shuffle.GLOBAL
+                self.shuffle_files = True
+            else:
+                self.shuffle_level = None
+                self.shuffle_files = False
+        else:
+            self.shuffle_level = shuffle
+            self.shuffle_files = True
+
+        self.num_shards = num_shards
+        self.shard_id = shard_id
+
+    def get_args(self):
+        args = super().get_args()
+        args["dataset_files"] = self.dataset_files
+        args["num_samples"] = self.num_samples
+        if self.shuffle_files is not None:
+            args["shuffle_files"] = self.shuffle_files
+        args["shuffle"] = self.shuffle_level
+        args["num_shards"] = self.num_shards
+        args["shard_id"] = self.shard_id
+        args["cols_to_keyword"] = self.cols_to_keyword
+        return args
+
+    def get_dataset_size(self):
+        """
+        Get the number of batches in an epoch.
+
+        Return:
+            Number, number of batches.
+        """
+        if self._dataset_size is None:
+            num_rows = ClueOp.get_num_rows(self.dataset_files)
+            num_rows = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is None:
+                return num_rows
+            return min(self.num_samples, num_rows)
+        return self._dataset_size
+
+    def is_shuffled(self):
+        return self.shuffle_files
+
+    def is_sharded(self):
+        if self.num_shards is not None:
+            return self.num_shards > 1
+
+        return False
+
+
 class TextFileDataset(SourceDataset):
     """
     A source dataset that reads and parses datasets stored on disk in text format.
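
Note how the WSC entries in task_dict map output columns to '/'-separated key chains such as 'target/span1_index'; the C++ side splits on '/' (ClueOp::Builder::split) and walks the parsed JSON one key at a time (ClueOp::GetValue). The equivalent lookup, sketched in Python against a record shaped like the wsc test data added in this commit:

    record = {"target": {"span1_index": 0, "span1_text": "小明"}, "idx": 0, "text": "小明呢,他在哪?"}

    cursor = record
    for key in "target/span1_index".split('/'):
        if key not in cursor:
            raise KeyError("Failed to find key: " + key)  # same error path as ClueOp::GetValue
        cursor = cursor[key]
    print(cursor)  # 0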
mindspore/dataset/engine/iterators.py

@@ -50,7 +50,8 @@ def alter_tree(node):
 def _alter_node(node):
     """Performing some alteration to a dataset node. A common alteration is to insert a node."""
-    if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL:
+    if isinstance(node, (de.TFRecordDataset, de.TextFileDataset, de.CLUEDataset)) \
+            and node.shuffle_level == de.Shuffle.GLOBAL:
         # Remove the connection between the parent's node to the current node because we are inserting a node.
         if node.output:
             node.output.pop()
@@ -179,6 +180,8 @@ class Iterator:
             op_type = OpName.TEXTFILE
         elif isinstance(dataset, de.BuildVocabDataset):
             op_type = OpName.BUILDVOCAB
+        elif isinstance(dataset, de.CLUEDataset):
+            op_type = OpName.CLUE
         else:
             raise ValueError("Unsupported DatasetOp")
mindspore/dataset/engine/validators.py

@@ -1075,6 +1075,41 @@ def check_add_column(method):
     return new_method


+def check_cluedataset(method):
+    """A wrapper that wraps a parameter checker around the original Dataset (CLUEDataset)."""
+
+    @wraps(method)
+    def new_method(*args, **kwargs):
+        param_dict = make_param_dict(method, args, kwargs)
+
+        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
+
+        # check dataset_files; required argument
+        dataset_files = param_dict.get('dataset_files')
+        if dataset_files is None:
+            raise ValueError("dataset_files is not provided.")
+        if not isinstance(dataset_files, (str, list)):
+            raise TypeError("dataset_files should be of type str or a list of strings.")
+
+        # check task
+        task_param = param_dict.get('task')
+        if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
+            raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
+
+        # check usage
+        usage_param = param_dict.get('usage')
+        if usage_param not in ['train', 'test', 'eval']:
+            raise ValueError("usage should be train, test or eval")
+
+        check_param_type(nreq_param_int, param_dict, int)
+
+        check_sampler_shuffle_shard_options(param_dict)
+
+        return method(*args, **kwargs)
+
+    return new_method
+
+
 def check_textfiledataset(method):
     """A wrapper that wraps a parameter checker to the original Dataset(TextFileDataset)."""
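
A quick sketch of the failure this decorator produces; the check runs before CLUEDataset.__init__ executes, so no files need to exist for the task validation to fire (the file path below is a placeholder):

    import mindspore.dataset as ds

    try:
        # 'SQUAD' is not a CLUE task, so the validator raises before any file I/O.
        ds.CLUEDataset(dataset_files=["/path/to/afqmc/train.json"], task='SQUAD', usage='train')
    except ValueError as e:
        print(e)  # task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL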
tests/ut/cpp/dataset/CMakeLists.txt

@@ -65,6 +65,7 @@ SET(DE_UT_SRCS
     cifar_op_test.cc
     celeba_op_test.cc
     take_op_test.cc
+    clue_op_test.cc
     text_file_op_test.cc
     filter_op_test.cc
     concat_op_test.cc
tests/ut/cpp/dataset/clue_op_test.cc
0 → 100644

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include <vector>

#include "dataset/core/client.h"
#include "common/common.h"
#include "common/utils.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
#include "dataset/engine/datasetops/source/clue_op.h"
#include "dataset/util/status.h"

namespace common = mindspore::common;

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestCLUEOp : public UT::DatasetOpTesting {};

TEST_F(MindDataTestCLUEOp, TestCLUEBasic) {
  // Start with an empty execution tree
  auto tree = std::make_shared<ExecutionTree>();

  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testCLUE/afqmc/train.json";
  std::map<std::string, std::string> key_map;
  key_map["sentence1"] = "sentence1";
  key_map["sentence2"] = "sentence2";
  key_map["label"] = "label";

  std::shared_ptr<ClueOp> op;
  ClueOp::Builder builder;
  builder.SetClueFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetNumWorkers(16)
    .SetOpConnectorSize(2)
    .SetColsKeyMap(key_map);

  Status rc = builder.Build(&op);
  ASSERT_TRUE(rc.IsOk());

  rc = tree->AssociateNode(op);
  ASSERT_TRUE(rc.IsOk());

  rc = tree->AssignRoot(op);
  ASSERT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration.";
  rc = tree->Prepare();
  ASSERT_TRUE(rc.IsOk());

  rc = tree->Launch();
  ASSERT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  ASSERT_TRUE(rc.IsOk());

  int row_count = 0;
  while (!tensor_list.empty()) {
    // Display the tensor by calling the printer on it
    for (int i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
    }

    rc = di.FetchNextTensorRow(&tensor_list);
    ASSERT_TRUE(rc.IsOk());
    row_count++;
  }

  ASSERT_EQ(row_count, 3);
}

TEST_F(MindDataTestCLUEOp, TestTotalRows) {
  std::string tf_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
  std::string tf_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
  std::vector<std::string> files;
  files.push_back(tf_file1);
  int64_t total_rows = 0;
  ClueOp::CountAllFileRows(files, &total_rows);
  ASSERT_EQ(total_rows, 3);
  files.clear();

  files.push_back(tf_file2);
  ClueOp::CountAllFileRows(files, &total_rows);
  ASSERT_EQ(total_rows, 3);
  files.clear();

  files.push_back(tf_file1);
  files.push_back(tf_file2);
  ClueOp::CountAllFileRows(files, &total_rows);
  ASSERT_EQ(total_rows, 6);
  files.clear();
}
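
The companion Python tests (tests/ut/python/dataset/test_datasets_clue.py, +355 lines) are not expanded in this view. A hypothetical minimal check in the spirit of TestTotalRows above could look like this; the relative path is a guess modeled on the C++ test's datasets_root_path_:

    import mindspore.dataset as ds

    def test_clue_num_rows():
        # afqmc/train.json added in this commit contains exactly 3 records
        data = ds.CLUEDataset("../data/dataset/testCLUE/afqmc/train.json", task='AFQMC', usage='train')
        assert data.get_dataset_size() == 3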
tests/ut/data/dataset/testCLUE/afqmc/dev.json
0 → 100644

{"sentence1": "你有花呗吗", "sentence2": "我的花呗没额度了", "label": "0"}
{"sentence1": "吃饭能用花呗吗", "sentence2": "花呗太方便了", "label": "0"}
{"sentence1": "蚂蚁花呗支付金额有什么限制", "sentence2": "我到实体店消费用花呗支付受金额限制", "label": "1"}
tests/ut/data/dataset/testCLUE/afqmc/test.json
0 → 100644

{"id": 0, "sentence1": "借呗取消的时间", "sentence2": "蚂蚁借呗恢复的月数"}
{"id": 1, "sentence1": "网商贷用什么方法转变成借呗", "sentence2": "什么手段能将网商贷切换为借呗"}
{"id": 2, "sentence1": "我的借呗为什么开通不了", "sentence2": "我为啥没法开通借呗"}
tests/ut/data/dataset/testCLUE/afqmc/train.json
0 → 100644

{"sentence1": "蚂蚁借呗等额还款能否换成先息后本", "sentence2": "借呗可以先息到期还本吗", "label": "0"}
{"sentence1": "蚂蚁花呗说我违约了", "sentence2": "蚂蚁花呗违约行为是啥", "label": "0"}
{"sentence1": "帮我看看本月花呗账单结清了没", "sentence2": "上月的花呗账单", "label": "0"}
tests/ut/data/dataset/testCLUE/cmnli/dev.json
0 → 100644

{"sentence1": "每个人都有权利", "sentence2": "每个人都有福利", "label": "neutral"}
{"sentence1": "有时候我喜欢他,但我也喜欢看到有人打他", "sentence2": "说实话,我有点喜欢他,但还是喜欢看到有人打他。", "label": "entailment"}
{"sentence1": "我最喜欢的餐馆是离你最近的一家", "sentence2": "我最喜欢的餐馆离你家至少一百英里远。", "label": "contradiction"}
tests/ut/data/dataset/testCLUE/cmnli/test.json
0 → 100644

{"id": 0, "sentence1": "今天,全球都在看着最新航天飞机的处女航。", "sentence2": "全世界都在看最新的航天飞机发射。"}
{"id": 1, "sentence1": "而我们把竹篮放在一个地方,把玻璃瓶放在另一处,把书放在另一处,满了要把它放到车里", "sentence2": "我们没有分开任何东西,都把它全扔进一个箱子里。"}
{"id": 2, "sentence1": "她占用了我的很多时间,她给我读了很多关于灵异的故事,我觉得很无聊。", "sentence2": "我喜欢和她一起读鬼故事。"}
tests/ut/data/dataset/testCLUE/cmnli/train.json
0 → 100644
{"sentence1": "你应该给这件衣服定一个价格。", "sentence2": "不同的衣服有不同的价格。", "label": "neutral"}
{"sentence1": "我怎么知道他要说什么", "sentence2": "他说什么我并不知道。", "label": "entailment"}
{"sentence1": "向左。", "sentence2": "向右。", "label": "contradiction"}
tests/ut/data/dataset/testCLUE/csl/dev.json
0 → 100644
{"id": 1, "abst": "这是第一段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
{"id": 2, "abst": "这是第二段很长的文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "1"}
{"id": 3, "abst": "这是第三段很长的文本", "keyword": ["1", "2", "3"], "label": "0"}
tests/ut/data/dataset/testCLUE/csl/test.json
0 → 100644
{"id": 2415, "abst": "长文本1", "keyword": ["关键词1", "关键词2"]}
{"id": 2565, "abst": "长文本2", "keyword": ["关键词1", "关键词2", "关键词3"]}
{"id": 2625, "abst": "长文本3", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"]}
tests/ut/data/dataset/testCLUE/csl/train.json
0 → 100644
{"id": 1, "abst": "这是一段长文本", "keyword": ["关键词1", "关键词2", "关键词3", "关键词4"], "label": "0"}
{"id": 2, "abst": "这是一段长文本", "keyword": ["关键词5", "关键词6", "关键词7", "关键词8"], "label": "0"}
{"id": 3, "abst": "这是一段长文本", "keyword": ["关键词9", "关键词10", "关键词11", "关键词12"], "label": "0"}
tests/ut/data/dataset/testCLUE/iflytek/dev.json
0 → 100644
{"label": "110", "label_des": "社区超市", "sentence": "这是第一段文本"}
{"label": "70", "label_des": "工具", "sentence": "这是第二段文本"}
{"label": "10", "label_des": "社区服务", "sentence": "这是第三段文本"}
tests/ut/data/dataset/testCLUE/iflytek/test.json
0 → 100644
{"id": 0, "sentence": "文本1"}
{"id": 1, "sentence": "文本2"}
{"id": 2, "sentence": "文本3"}
tests/ut/data/dataset/testCLUE/iflytek/train.json
0 → 100644
{"label": "11", "label_des": "薅羊毛", "sentence": "第一个文本"}
{"label": "95", "label_des": "借贷", "sentence": "第二个文本"}
{"label": "74", "label_des": "违章", "sentence": "第三个文本"}
tests/ut/data/dataset/testCLUE/tnews/dev.json
0 → 100644
{"label": "102", "label_desc": "news_entertainment", "sentence": "新闻1", "keywords": "关键词一,关键词二,关键词三,关键词四"}
{"label": "110", "label_desc": "news_military", "sentence": "新闻2", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}
{"label": "104", "label_desc": "news_finance", "sentence": "新闻3", "keywords": "关键词一,关键词二,关键词三,关键词四,关键词五"}
tests/ut/data/dataset/testCLUE/tnews/test.json
0 → 100644
{"id": 0, "sentence": "新闻1", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5"}
{"id": 1, "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4"}
{"id": 2, "sentence": "新闻3", "keywords": ""}
tests/ut/data/dataset/testCLUE/tnews/train.json
0 → 100644
{"label": "108", "label_desc": "news_edu", "sentence": "新闻1", "keywords": ""}
{"label": "104", "label_desc": "news_finance", "sentence": "新闻2", "keywords": "关键词1,关键词2,关键词3,关键词4,关键词5,关键词6"}
{"label": "106", "label_desc": "news_house", "sentence": "新闻3", "keywords": ""}
tests/ut/data/dataset/testCLUE/wsc/dev.json
0 → 100755
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"}
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"}
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"}
tests/ut/data/dataset/testCLUE/wsc/test.json
0 → 100755
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?"}
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场"}
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业"}
tests/ut/data/dataset/testCLUE/wsc/train.json
0 → 100755
{"target": {"span1_index": 0, "span1_text": "小明", "span2_index": 4, "span2_text": "他"}, "idx": 0, "text": "小明呢,他在哪?", "label": "true"}
{"target": {"span1_index": 0, "span1_text": "小红", "span2_index": 9, "span2_text": "他"}, "idx": 1, "text": "小红刚刚看到小明,他在操场", "label": "false"}
{"target": {"span1_index": 6, "span1_text": "小张", "span2_index": 8, "span2_text": "你"}, "idx": 2, "text": "等小明回来,小张你叫他交作业", "label": "true"}
tests/ut/python/dataset/test_datasets_clue.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds


def test_clue():
    """
    Test CLUE with repeat, skip and so on
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_num_shards():
    """
    Test num_shards param of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 1


def test_clue_num_samples():
    """
    Test num_samples param of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
    count = 0
    for _ in data.create_dict_iterator():
        count += 1
    assert count == 2


def test_textline_dataset_get_datasetsize():
    """
    Test get_dataset_size of CLUE dataset
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

    data = ds.TextFileDataset(TRAIN_FILE)
    size = data.get_dataset_size()
    assert size == 3


def test_clue_afqmc():
    """
    Test AFQMC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # evaluation
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_cmnli():
    """
    Test CMNLI for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
    TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'sentence1': d['sentence1'].item().decode("utf8"),
            'sentence2': d['sentence2'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'],
            'sentence1': d['sentence1'],
            'sentence2': d['sentence2']
        })
    assert len(buffer) == 3


def test_clue_csl():
    """
    Test CSL for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
    TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']]
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'abst': d['abst'].item().decode("utf8"),
            'keyword': [i.item().decode("utf8") for i in d['keyword']],
            'label': d['label'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_iflytek():
    """
    Test IFLYTEK for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
    TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_des': d['label_des'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8")
        })
    assert len(buffer) == 3


def test_clue_tnews():
    """
    Test TNEWS for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
    TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords': d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'id': d['id'],
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords': d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
    for d in data.create_dict_iterator():
        buffer.append({
            'label': d['label'].item().decode("utf8"),
            'label_desc': d['label_desc'].item().decode("utf8"),
            'sentence': d['sentence'].item().decode("utf8"),
            'keywords': d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
        })
    assert len(buffer) == 3


def test_clue_wsc():
    """
    Test WSC for train, test and evaluation
    """
    TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
    TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
    EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'

    # train
    buffer = []
    data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # test
    buffer = []
    data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3

    # eval
    buffer = []
    data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
    for d in data.create_dict_iterator():
        buffer.append({
            'span1_index': d['span1_index'],
            'span2_index': d['span2_index'],
            'span1_text': d['span1_text'].item().decode("utf8"),
            'span2_text': d['span2_text'].item().decode("utf8"),
            'idx': d['idx'],
            'label': d['label'].item().decode("utf8"),
            'text': d['text'].item().decode("utf8")
        })
    assert len(buffer) == 3


if __name__ == "__main__":
    test_clue()
    test_clue_afqmc()
    test_clue_cmnli()
    test_clue_csl()
    test_clue_iflytek()
    test_clue_tnews()
    test_clue_wsc()
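A note on the .item().decode("utf8") pattern used throughout these tests: create_dict_iterator() yields numpy values, so a scalar string column arrives as an array whose .item() is a bytes object that must be decoded before comparison (and d['keywords'].size > 0 guards the empty-string case in TNEWS). A minimal standalone sketch against the same fixture, using only API calls already exercised above:

import mindspore.dataset as ds

data = ds.CLUEDataset('../data/dataset/testCLUE/afqmc/train.json',
                      task='AFQMC', usage='train', shuffle=False)
for row in data.create_dict_iterator():
    # Each string column is a numpy array wrapping bytes; unpack, then decode.
    print(row['sentence1'].item().decode("utf8"))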