magicwindyyd / mindspore, commit 8e4c0a9d (fork of MindSpore / mindspore, in sync with upstream)
Commit 8e4c0a9d
Authored on Jul 20, 2020 by mindspore-ci-bot; committed via Gitee on Jul 20, 2020

!3212 GetDatasize feature

Merge pull request !3212 from anzhengqi/epochs-ready

Parents: bae2f964, 008b91b2
Showing 94 changed files with 5,260 additions and 397 deletions.
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc  +30 -5
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h  +9 -2
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc  +4 -2
mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc  +20 -27
mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h  +3 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt  +1 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc  +26 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h  +21 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc  +15 -1
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h  +5 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc  +16 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h  +9 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc  +48 -37
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h  +13 -7
mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc  +130 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h  +82 -0
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc  +3 -1
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h  +4 -4
mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc  +2 -1
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc  +12 -3
mindspore/ccsrc/minddata/dataset/engine/execution_tree.h  +7 -1
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt  +1 -0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc  +17 -0
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h  +10 -0
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc  +71 -17
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h  +19 -5
mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc  +82 -0
mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h  +75 -0
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc  +15 -8
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h  +2 -1
mindspore/ccsrc/pipeline/jit/pipeline.cc  +3 -0
mindspore/dataset/engine/datasets.py  +24 -37
mindspore/dataset/engine/iterators.py  +29 -7
mindspore/train/_utils.py  +5 -3
mindspore/train/dataset_helper.py  +87 -72
mindspore/train/model.py  +27 -16
model_zoo/alexnet/train.py  +1 -1
model_zoo/deepfm/train.py  +2 -2
model_zoo/deeplabv3/train.py  +1 -1
model_zoo/faster_rcnn/train.py  +1 -1
model_zoo/googlenet/train.py  +1 -1
model_zoo/lenet/train.py  +1 -2
model_zoo/lenet_quant/train.py  +1 -1
model_zoo/lstm/train.py  +1 -1
model_zoo/mass/train.py  +6 -6
model_zoo/mobilenetv2/train.py  +2 -2
model_zoo/mobilenetv2_quant/train.py  +1 -1
model_zoo/mobilenetv3/train.py  +2 -2
model_zoo/official/nlp/bert/run_classifier.py  +4 -5
model_zoo/official/nlp/bert/run_ner.py  +4 -5
model_zoo/official/nlp/bert/run_pretrain.py  +5 -4
model_zoo/official/nlp/bert/run_squad.py  +4 -5
model_zoo/official/nlp/bert/src/dataset.py  +0 -1
model_zoo/official/nlp/transformer/src/dataset.py  +1 -5
model_zoo/official/nlp/transformer/train.py  +5 -5
model_zoo/resnet/train.py  +2 -2
model_zoo/resnet_thor/train.py  +1 -1
model_zoo/ssd/train.py  +1 -1
model_zoo/vgg16/train.py  +1 -1
model_zoo/wide_and_deep/train.py  +1 -1
model_zoo/wide_and_deep/train_and_eval.py  +2 -2
model_zoo/wide_and_deep/train_and_eval_auto_parallel.py  +4 -4
model_zoo/wide_and_deep/train_and_eval_distribute.py  +2 -2
model_zoo/yolov3_resnet18/train.py  +1 -1
tests/dataset_mock.py  +8 -1
tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py  +1 -1
tests/st/model_zoo_tests/transformer/test_transformer.py  +5 -5
tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py  +2 -2
tests/st/model_zoo_tests/wide_and_deep/train_and_test_multinpu_ci_data_parallel.py  +2 -2
tests/st/model_zoo_tests/yolov3/test_yolov3.py  +3 -3
tests/st/networks/models/bert/test_bert_tdt_lossscale.py  +7 -7
tests/st/networks/models/deeplabv3/test_deeplabv3.py  +1 -1
tests/st/networks/models/resnet50/test_resnet50_imagenet.py  +4 -4
tests/st/tbe_networks/resnet_cifar.py  +1 -1
tests/st/tbe_networks/test_resnet_cifar_1p.py  +1 -1
tests/st/tbe_networks/test_resnet_cifar_8p.py  +1 -1
tests/ut/cpp/dataset/CMakeLists.txt  +2 -1
tests/ut/cpp/dataset/cache_op_test.cc  +15 -19
tests/ut/cpp/dataset/epoch_ctrl_op_test.cc  +639 -0
tests/ut/cpp/dataset/repeat_op_test.cc  +2 -1
tests/ut/python/dataset/test_cache_map.py  +6 -0
tests/ut/python/dataset/test_datasets_tfrecord.py  +2 -2
tests/ut/python/dataset/test_deviceop_cpu.py  +13 -1
tests/ut/python/dataset/test_epoch_ctrl.py  +608 -0
tests/ut/python/dataset/test_five_crop.py  +1 -1
tests/ut/python/dataset/test_get_size.py  +8 -8
tests/ut/python/dataset/test_iterator.py  +1 -1
tests/ut/python/dataset/test_repeat.py  +45 -0
tests/ut/python/dataset/test_zip.py  +11 -11
tests/ut/python/log  +2770 -0
tests/ut/python/parallel/test_auto_parallel_resnet.py  +3 -0
tests/ut/python/parallel/test_bias_add.py  +3 -0
tests/ut/python/parallel/test_gather_v2_primitive.py  +3 -0
tests/ut/python/train/test_dataset_helper.py  +107 -0
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc  (+30 -5)
@@ -25,6 +25,8 @@
 #include "minddata/dataset/engine/dataset_iterator.h"
 #include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
 #include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/datasetops/device_queue_op.h"
+#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
 #include "minddata/dataset/engine/datasetops/filter_op.h"
 #include "minddata/dataset/engine/datasetops/source/celeba_op.h"
 #include "minddata/dataset/engine/datasetops/source/cifar_op.h"
@@ -84,7 +86,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
   {kRandomData, &DEPipeline::ParseRandomDataOp},
   {kTextFile, &DEPipeline::ParseTextFileOp},
   {kBuildVocab, &DEPipeline::ParseBuildVocabOp},
-  {kClue, &DEPipeline::ParseClueOp}};
+  {kClue, &DEPipeline::ParseClueOp},
+  {kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}};

 DEPipeline::DEPipeline() : iterator_(nullptr) {
   try {
@@ -166,8 +169,8 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
 Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }

 // Function to launch the tree execution.
-Status DEPipeline::LaunchTreeExec() {
-  RETURN_IF_NOT_OK(tree_->Prepare());
+Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
+  RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
   RETURN_IF_NOT_OK(tree_->Launch());
   iterator_ = std::make_unique<DatasetIterator>(tree_);
   if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
@@ -252,6 +255,16 @@ int DEPipeline::GetRepeatCount() const { return repeat_num_; }
 float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }

+Status DEPipeline::StopSend() {
+  // tree_.root() must be DeviceQueueOp
+  DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
+  if (op == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp");
+  }
+  op->StopSend();
+  return Status::OK();
+}
+
 int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }

 bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
@@ -804,6 +817,18 @@ Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp>
   return Status::OK();
 }

+Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
+                                    std::shared_ptr<DatasetOp> *bottom) {
+  if (args["count"].is_none()) {
+    std::string err_msg = "Error: count is invalid or not set.";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::shared_ptr<EpochCtrlOp> op;
+  RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op));
+  *top = op;
+  return Status::OK();
+}
+
 Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                     std::shared_ptr<DatasetOp> *bottom) {
   std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
@@ -973,8 +998,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data
       (void)builder->SetDeviceType(ToString(value));
     } else if (key == "device_id") {
       (void)builder->SetDeviceId(ToInt(value));
-    } else if (key == "num_batch") {
-      (void)builder->SetNumBatch(ToInt(value));
+    } else if (key == "send_epoch_end") {
+      (void)builder->SetSendEpochEnd(ToBool(value));
     }
   }
 }
mindspore/ccsrc/minddata/dataset/api/de_pipeline.h  (+9 -2)
@@ -70,7 +70,8 @@ enum OpName {
   kRandomData,
   kTextFile,
   kBuildVocab,
-  kClue
+  kClue,
+  kEpochCtrl
 };

 // The C++ binder class that we expose to the python script.
@@ -90,7 +91,7 @@ class DEPipeline {
   Status AssignRootNode(const DsOpPtr &dataset_op);

   // Function to launch the tree execution.
-  Status LaunchTreeExec();
+  Status LaunchTreeExec(int32_t num_epochs);

   // Get a row of data as dictionary of column name to the value.
   Status GetNextAsMap(py::dict *output);
@@ -143,6 +144,10 @@ class DEPipeline {
   Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                     std::shared_ptr<DatasetOp> *bottom);

+  Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
+                          std::shared_ptr<DatasetOp> *bottom);
+
   Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

   Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

   Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
@@ -189,6 +194,8 @@ class DEPipeline {
   Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

+  Status StopSend();
+
   Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);

  private:
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc  (+4 -2)
@@ -159,7 +159,7 @@ void bindDEPipeline(py::module *m) {
          [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
     .def("SetBatchParameters",
          [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
-    .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
+    .def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
     .def("GetNextAsMap", [](DEPipeline &de) {
       py::dict out;
@@ -188,6 +188,7 @@ void bindDEPipeline(py::module *m) {
     .def("GetBatchSize", &DEPipeline::GetBatchSize)
     .def("GetNumClasses", &DEPipeline::GetNumClasses)
     .def("GetRepeatCount", &DEPipeline::GetRepeatCount)
+    .def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
     .def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
       THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
       return true;
@@ -999,7 +1000,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
     .value("BUILDVOCAB", OpName::kBuildVocab)
     .value("CELEBA", OpName::kCelebA)
     .value("TEXTFILE", OpName::kTextFile)
-    .value("CLUE", OpName::kClue);
+    .value("CLUE", OpName::kClue)
+    .value("EPOCHCTRL", OpName::kEpochCtrl);

   (void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
     .value("DE_JIEBA_MIX", JiebaMode::kMix)
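Note: the two "LaunchTreeExec" registrations above share one Python name; pybind11 treats repeated .def calls under the same name as an overload set and dispatches on the Python-side arguments. The following is a minimal, self-contained sketch of that pattern only (the Pipeline class and its members here are hypothetical stand-ins, not MindSpore's DEPipeline):

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for the C++ object being bound.
struct Pipeline {
  int epochs = -1;
  void Launch() { epochs = -1; }                         // legacy zero-argument form
  void Launch(int num_epochs) { epochs = num_epochs; }   // new form carrying the epoch count
};

PYBIND11_MODULE(pipeline_demo, m) {
  py::class_<Pipeline>(m, "Pipeline")
    .def(py::init<>())
    // Registering the same name twice creates an overload set:
    // Pipeline().LaunchTreeExec() hits the first lambda, LaunchTreeExec(3) the second.
    .def("LaunchTreeExec", [](Pipeline &p) { p.Launch(); })
    .def("LaunchTreeExec", [](Pipeline &p, int num_epochs) { p.Launch(num_epochs); });
}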
mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc  (+20 -27)
@@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
   out_map->clear();

   TensorRow curr_row;
+  MS_LOG(INFO) << "get next as map start.";
   RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
+  MS_LOG(INFO) << "fetchNextTensor success.";

   // Return empty map if there's no data
   if (curr_row.empty()) {
@@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
   // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
   // want to iterate again.
   if (eof_handled_) {
-    return Status::OK();
+    std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
+    RETURN_STATUS_UNEXPECTED(err);
   }

   // Check if we need to get a new DataBuffer to iterate.
@@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
     // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
     // handle eoe and eof messages here.
     //
-    // An eoe buffer means we have iterated fully to the end of the tree.
-    // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
-    // all operators.
+    // An eoe buffer means we have iterated an epoch.
+    // The next buffer in the pipeline might be an EOF or a databuffer for next epoch
     if (curr_buffer_->eoe()) {
-      MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
-      // Before returning the last empty vector, fetch the eof buffer which should be the last
-      // buffer, and then free it.
-      RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
-      if (!curr_buffer_->eof()) {
-        RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
-      }
-      eof_handled_ = true;
-      curr_buffer_.reset();  // explicitly free the eof buffer
-      // Set tree to Finished state
-      root_->Tree()->SetFinished();
+      MS_LOG(INFO) << "End of data iteration.";
+      curr_buffer_.reset();  // explicitly free the eoe buffer
       return Status::OK();
     }

+    // An eof buffer means it is the end of execution and all operators are shutting down.
+    // Because there is no more data to return to the caller, this will change `eof_handled_` state and
+    // returns status unexpected error.
     if (curr_buffer_->eof()) {
-      // An eof by itself, without being preceded by an eoe, is possible if a repeat operator
-      // exists below us in the stack. Repeat operator eats eoe's but eventually allows the
-      // flow of an eof up the pipeline by itself.
       eof_handled_ = true;
       curr_buffer_.reset();  // explicitly free the eof buffer
       // Set tree to Finished state
       root_->Tree()->SetFinished();
-      return Status::OK();
+      std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
+      RETURN_STATUS_UNEXPECTED(err);
     }
   }
@@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
   // Once eof is handled, always return empty row. Class must be destroyed and recreated if you
   // want to iterate again.
   if (eof_handled_) {
-    return Status::OK();
+    std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
+    RETURN_STATUS_UNEXPECTED(err);
   }

   // Check if we need to get a new DataBuffer to iterate.
   if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
+    // GetNextInput() depends on current_op's EoeReceived. So, EOE buffer might be already be handled and
+    // this child iterator might not see EOE buffer.
     RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));

-    // Unlike the DatasetIterator, this child iterator does not quit after eoe.
-    // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
+    // If an eoe is picked up here, we simply return an empty vector and it's up to the
     // caller to decide what it wants to do next.
     if (curr_buffer_->eoe()) {
       MS_LOG(DEBUG) << "Child iterator picked up EOE.";
+      end_epoch_ = true;
       return Status::OK();
+    } else {
+      end_epoch_ = false;
     }

     if (curr_buffer_->eof()) {
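Note: the iterator changes above alter the end-of-data contract: an EOE buffer now marks the end of one epoch and is surfaced to the caller as an empty row (so iteration can resume for the next epoch), while any fetch after the EOF buffer is reported as an error rather than a silent empty result. A simplified, self-contained model of that contract (illustrative only, not the DatasetIterator class):

#include <deque>
#include <stdexcept>
#include <utility>
#include <vector>

enum class Marker { kRow, kEoe, kEof };

struct Buffer {
  Marker kind;
  std::vector<int> row;  // payload for kRow buffers
};

class IteratorModel {
 public:
  explicit IteratorModel(std::deque<Buffer> stream) : stream_(std::move(stream)) {}

  // An empty vector signals "end of epoch"; throwing models RETURN_STATUS_UNEXPECTED.
  std::vector<int> FetchNextRow() {
    if (eof_handled_ || stream_.empty()) {
      throw std::runtime_error("EOF buffer encountered. Fetching beyond the specified number of epochs.");
    }
    Buffer buf = std::move(stream_.front());
    stream_.pop_front();
    if (buf.kind == Marker::kEoe) {
      return {};  // epoch boundary: return an empty row, the caller may keep iterating
    }
    if (buf.kind == Marker::kEof) {
      eof_handled_ = true;  // remember shutdown and refuse further fetches
      throw std::runtime_error("EOF buffer encountered. Fetching beyond the specified number of epochs.");
    }
    return buf.row;
  }

 private:
  std::deque<Buffer> stream_;
  bool eof_handled_ = false;
};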
mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h  (+3 -0)
@@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase {
   // @return The string to column id mapping.
   std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;

+  // Return T/F if end of epoch
+  bool end_of_epoch() { return end_epoch_; }
+
  private:
   DatasetOp *current_op_;  // The parent operator. We consume from it's children.
   int32_t child_idx_;      // The specific child this iterator will fetch from.
mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt  (+1 -0)
@@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
     shuffle_op.cc
     zip_op.cc
     concat_op.cc
+    epoch_ctrl_op.cc
     cache_base_op.cc
     cache_lookup_op.cc
     cache_op.cc
mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc  (+26 -0)
@@ -17,11 +17,13 @@
 #include "minddata/dataset/engine/datasetops/build_vocab_op.h"

 #include <algorithm>
+#include <iomanip>
 #include <limits>
 #include <string>
 #include <unordered_map>
 #include <utility>
 #include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/opt/pass.h"

 namespace mindspore {
 namespace dataset {
@@ -202,5 +204,29 @@ BuildVocabOp::Builder::Builder()
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_connector_size_ = cfg->op_connector_size();
 }

+// A print method typically used for debugging
+void BuildVocabOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as first line regardless if this summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <BuildVocabOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nCode is needed here to show more info about the op."
+        << "\n\n";
+  }
+}
+
+// Pre-Visitor accept method for NodePass
+Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call the pre-visitation
+  return p->PreRunOnNode(shared_from_base<BuildVocabOp>(), modified);
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h  (+21 -0)
@@ -131,6 +131,21 @@ class BuildVocabOp : public ParallelOp {
   ~BuildVocabOp() = default;

+  /// \brief A print method typically used for debugging
+  /// \param[out] out The output stream to write output to
+  /// \param[in] show_all A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  /// \brief Stream output operator overload
+  /// \notes This allows you to write the debug print info using stream operators
+  /// \param[out] out Reference to the output stream being overloaded
+  /// \param[in] vop - reference to the BuildVocabOp to display
+  /// \return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) {
+    vop.Print(out, false);
+    return out;
+  }
+
   Status WorkerEntry(int32_t worker_id) override;

   // collect the work product from each worker
@@ -152,6 +167,12 @@ class BuildVocabOp : public ParallelOp {
   Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }

+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
  private:
   const int32_t interval_;
   bool special_first_;
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc  (+15 -1)
@@ -96,7 +96,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
       RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
     }
   }
-  RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
+  RETURN_IF_NOT_OK(EofReceived(worker_id));
   return Status::OK();
 }
 Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
@@ -298,5 +298,19 @@ Status CacheMergeOp::EoeReceived(int32_t worker_id) {
   }
   return Status::OK();
 }

+// Base-class override for handling cases when an eof is received.
+Status CacheMergeOp::EofReceived(int32_t worker_id) {
+  // If we are not in a repeated path, then the merge op gets a eof by itself, without first
+  // getting an eoe.  However, the logic demands that all epochs close with an eoe first before eof.
+  // Thus, generate an eoe first, before flowing up the eof in the non-repeated case.  Base class
+  // provides that for us.
+  if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
+    MS_LOG(DEBUG) << "Cache merge sending eoe";
+    RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
+  }
+  MS_LOG(DEBUG) << "Cache merge sending eof";
+  return DatasetOp::EofReceived(worker_id);
+}
 }  // namespace dataset
 }  // namespace mindspore
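Note: the new EofReceived override encodes the pipeline convention that every epoch must be closed by an EOE before the final EOF, so an op that sees a bare EOF outside a repeated path synthesizes the EOE itself before forwarding the EOF. A tiny illustrative model of that rule only (not the CacheMergeOp code):

#include <iostream>
#include <string>
#include <vector>

// Returns the markers an op would emit upstream when it sees a bare EOF.
std::vector<std::string> OnEofReceived(bool in_repeated_path) {
  std::vector<std::string> sent;
  if (!in_repeated_path) {
    sent.push_back("EOE");  // close the (single) epoch first
  }
  sent.push_back("EOF");    // then flow the shutdown marker up the pipeline
  return sent;
}

int main() {
  for (const auto &msg : OnEofReceived(/*in_repeated_path=*/false)) std::cout << msg << "\n";  // EOE then EOF
  return 0;
}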
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h  (+5 -0)
@@ -176,6 +176,11 @@ class CacheMergeOp : public ParallelOp {
   /// \return Status object
   Status EoeReceived(int32_t worker_id) override;

+  /// \brief Base-class override for handling cases when an eof is received.
+  /// \param worker_id - The worker id
+  /// \return Status - The error code return
+  Status EofReceived(int32_t worker_id) override;
+
  protected:
   Status ComputeColMap() override;
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc  (+16 -0)
@@ -26,6 +26,7 @@
 #include "minddata/dataset/engine/execution_tree.h"
 #include "minddata/dataset/engine/datasetops/device_queue_op.h"
 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
 #include "minddata/dataset/engine/data_buffer.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/engine/opt/pass.h"
@@ -102,6 +103,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
   }
   return Status::OK();
 }
+// Removes child operator in this operator.
+Status DatasetOp::RemoveChildren() {
+  for (const auto &child : child_) {
+    child->RemoveParent(this);
+  }
+  child_.clear();
+
+  return Status::OK();
+}

 // Adds a parent operator to this operator
 void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
@@ -185,6 +195,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
   }
 }

+// Getter function to get all of our children.
+std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
+
+// Getter function to get all of our parents.
+std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
+
 // Creates the connector within this operator
 void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
   MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h  (+9 -0)
@@ -76,6 +76,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return Status eerror code returned
   Status Remove();

+  // Removes child operator in this operator.
+  Status RemoveChildren();
+
   /// \brief Getter function to get a shared pointer to our child
   /// \param[in] child_index An operator can have n children. Indicates which child to return.
   /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
@@ -86,6 +89,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
   void Parent(DatasetOp **parent, int32_t parent_index) const;

+  // Getter function to get all of our children.
+  std::vector<std::shared_ptr<DatasetOp>> children() const;
+
+  // Getter function to get all of our parents.
+  std::vector<DatasetOp *> parents() const;
+
   // Inserts a operator as the parent current op.
   // Inserted op will become the sole parent of the current op.
   // The existing parent of the current op will be transferred to the inserted op.
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc  (+48 -37)
@@ -25,19 +25,21 @@
 #include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/engine/perf/profiling.h"
 #include "minddata/dataset/engine/perf/device_queue_tracing.h"
+#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
 #include "minddata/dataset/util/status.h"
 #include "minddata/dataset/util/task_manager.h"

 namespace mindspore {
 namespace dataset {
 DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id,
-                             int32_t prefetch_size, int32_t op_connector_size, int64_t num_batch)
+                             int32_t prefetch_size, int32_t op_connector_size, bool send_epoch_end)
     : PipelineOp(op_connector_size),
       channel_name_(channel_name),
       device_type_(device_type),
       device_id_(device_id),
       prefetch_size_(prefetch_size),
-      num_batch_(num_batch) {}
+      send_epoch_end_(send_epoch_end),
+      stop_send_(false) {}

 DeviceQueueOp::~DeviceQueueOp() {}
@@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size)
     : builder_prefetch_size_(prefetch_size),
       builder_device_id_(0),
       builder_device_type_(DeviceType::CPU),
-      builder_channel_name_(""),
-      builder_num_batch_(0) {
+      builder_channel_name_("") {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_op_connector_size_ = cfg->op_connector_size();
 }
@@ -64,6 +65,18 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) {
   return Status::OK();
 }

+Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
+  // this method checks if the buffer meets the conditions to be sent to TDT
+  if (buffer->NumRows() != 0) {
+    TensorRow row;
+    buffer->GetRow(0, &row);
+    for (const auto &item : row) {
+      CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
+    }
+  }
+  return Status::OK();
+}
+
 Status DeviceQueueOp::operator()() {
   TaskManager::FindMe()->Post();
@@ -82,23 +95,10 @@ Status DeviceQueueOp::operator()() {
   return Status::OK();
 }

-Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
-  // this method checks if the buffer meets the conditions to be sent to TDT
-  if (buffer->NumRows() != 0) {
-    TensorRow row;
-    buffer->GetRow(0, &row);
-    for (const auto &item : row) {
-      CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
-    }
-  }
-  return Status::OK();
-}
-
 #ifdef ENABLE_TDTQUE
 Status DeviceQueueOp::SendDataToAscend() {
   MS_LOG(INFO) << "Device queue, sending data to Ascend.";
   int64_t total_batch = 0;
   bool is_break_loop = false;
   double batch_start_time, end_time;
   int32_t batch_cost, tdt_cost;
   int32_t connector_size = 0;
@@ -115,15 +115,20 @@ Status DeviceQueueOp::SendDataToAscend() {
   std::unique_ptr<DataBuffer> current_buffer;
   RETURN_IF_NOT_OK(GetNextInput(&current_buffer));

-  while (!current_buffer->eof() && !is_break_loop) {
-    while (!current_buffer->eoe() && !is_break_loop) {
+  while (!current_buffer->eof()) {
+    while (!current_buffer->eoe()) {
       RETURN_IF_NOT_OK(CheckExceptions(current_buffer));
       TensorRow currRow;
-      for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) {
+      for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
         RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
         auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
         if (status == TdtStatus::FAILED) {
-          return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
+          if (stop_send_) {
+            MS_LOG(INFO) << "stop_send received";
+            return Status::OK();
+          } else {
+            return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
+          }
         }

         if (isProfilingEnable) {
@@ -140,9 +145,6 @@ Status DeviceQueueOp::SendDataToAscend() {
           profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size);
         }
         total_batch++;
-        if (num_batch_ > 0 && total_batch == num_batch_) {
-          is_break_loop = true;
-        }
       }
       if (isProfilingEnable) {
         connector_size = ChildOpConnectorSize();
@@ -150,6 +152,19 @@
       }
       RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
     }
+    if (current_buffer->eoe() && send_epoch_end_) {
+      TensorRow currRow;
+      auto status =
+        tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
+      if (status == TdtStatus::FAILED) {
+        if (stop_send_) {
+          MS_LOG(INFO) << "stop_send received";
+          return Status::OK();
+        } else {
+          return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
+        }
+      }
+    }
     if (isProfilingEnable) {
       connector_size = ChildOpConnectorSize();
       connector_capacity = ChildOpConnectorCapacity();
@@ -158,7 +173,7 @@
   }

   tree_->SetFinished();
-  MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
+  MS_LOG(INFO) << "Device queue total batch is " << total_batch;

   return Status::OK();
 }
@@ -196,9 +211,6 @@ Status DeviceQueueOp::SendDataToGPU() {
       }
       RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
       total_batch++;
-      if (num_batch_ > 0 && total_batch == num_batch_) {
-        is_break_loop = true;
-      }
     }
     if (!TaskManager::FindMe()->Interrupted())
       RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
@@ -211,12 +223,10 @@
       is_break_loop = true;
   }

-  MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
+  MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";

   GpuBufferMgr::GetInstance().Close(handle);
   GpuBufferMgr::GetInstance().CloseConfirm();
   return Status::OK();
 }
@@ -240,8 +250,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
       if (ret == BlockQueueStatus_T::ERROR_INPUT) {
         return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
       } else {
-        MS_LOG(WARNING) << "Retry pushing data...";
-        continue;
+        if (!stop_send_) {
+          MS_LOG(WARNING) << "Retry pushing data...";
+          continue;
+        }
+        break;
       }
     } else {
       break;
@@ -283,13 +296,11 @@ Status DeviceQueueOp::SendDataToCPU() {
       MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << ".";
       MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << ".";
       total_batch++;
-      if (num_batch_ > 0 && total_batch == num_batch_) {
-        break;
-      }
+      if (stop_send_) break;
     }
   }

-  MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
+  MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";

   return Status::OK();
 }
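Note: the Ascend send loop now keys its behaviour off two new flags rather than a fixed batch count: when send_epoch_end_ is set, an end-of-sequence marker is pushed at each epoch boundary, and a push failure while stop_send_ has been requested is treated as a clean stop instead of an error. A compact, self-contained sketch of that control flow only (illustrative; the real loop also handles profiling and TDT specifics):

#include <functional>
#include <queue>
#include <stdexcept>
#include <vector>

struct Batch {
  std::vector<float> data;
  bool eoe = false;  // end-of-epoch marker
  bool eof = false;  // end-of-stream marker
};

// `push` returns false to model a device-queue push failure.
void SendLoop(std::queue<Batch> &pipeline, const std::function<bool(const Batch &)> &push,
              bool send_epoch_end, const bool &stop_send) {
  while (!pipeline.empty() && !pipeline.front().eof) {
    while (!pipeline.empty() && !pipeline.front().eoe) {
      Batch b = pipeline.front();
      pipeline.pop();
      if (!push(b)) {
        if (stop_send) return;                    // StopSend() was requested: exit without error
        throw std::runtime_error("push failed");  // otherwise propagate the failure
      }
    }
    if (!pipeline.empty() && pipeline.front().eoe) {
      pipeline.pop();                             // consume the epoch marker
      if (send_epoch_end) push(Batch{{}, /*eoe=*/true, /*eof=*/false});  // signal epoch end downstream
    }
  }
}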
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h  (+13 -7)
@@ -21,6 +21,7 @@
 #include <vector>

 #include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include "minddata/dataset/engine/datasetops/repeat_op.h"
 #include "minddata/dataset/util/status.h"

 #ifdef ENABLE_TDTQUE
@@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp {
       return *this;
     }

-    Builder &SetNumBatch(int64_t num_batch) {
-      builder_num_batch_ = num_batch;
+    Builder &SetSendEpochEnd(bool send_epoch_end) {
+      builder_send_epoch_end_ = send_epoch_end;
       return *this;
     }
@@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp {
     // to call this Build() method. It will instantiate the DeviceQueueOp
     // and return it to caller as a shared pointer.
     Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
-      *ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
-                                             builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_);
+      *ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
+                                             builder_prefetch_size_, builder_op_connector_size_,
+                                             builder_send_epoch_end_);
       return Status::OK();
     }
@@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp {
     int32_t builder_device_id_;
     DeviceType builder_device_type_;
     std::string builder_channel_name_;
-    int64_t builder_num_batch_;
     int32_t builder_op_connector_size_;
+    bool builder_send_epoch_end_;
   };

   //  Name: constructor
   //  Description
   DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
-                int32_t op_connector_size, int64_t num_batch);
+                int32_t op_connector_size, bool send_epoch_end);

   //  Name: destructor
   //  Description
@@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp {
   const int32_t get_prefetch_size() { return prefetch_size_; }

+  void StopSend() { stop_send_ = true; }
+
   //  Name: Print()
   //  Description: A function that prints info about the node
   void Print(std::ostream &out,              // In: The output stream to print to
@@ -149,6 +153,7 @@ class DeviceQueueOp : public PipelineOp {
   //  Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
   Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const;

+ private:
 #ifdef ENABLE_TDTQUE
   Status SendDataToAscend();
 #endif
@@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp {
   DeviceType device_type_;
   const int32_t device_id_;
   const int32_t prefetch_size_;
-  const int64_t num_batch_;
+  const bool send_epoch_end_;
+  bool stop_send_;

 #ifdef ENABLE_TDTQUE
   std::shared_ptr<TdtPlugin> tdtInstancePtr;
mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.cc  (new file, +130 -0)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iomanip>
#include <iostream>
#include <utility>

#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
// The builder "build" method creates the final object.
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
  return Status::OK();
}

// Constructor
EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) : RepeatOp(num_epoch) { MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; }

// Destructor
EpochCtrlOp::~EpochCtrlOp() {}

// A print method typically used for debugging
void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as first line regardless if this summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <EpochCtrlOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << " [epochs: " << max_repeats_ << "]\n";
  } else {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
        << "\nLeaf Nodes in execution path:";
    if (!eoe_ops_.empty()) {
      for (size_t i = 0; i < eoe_ops_.size(); i++) {
        out << "\n  Operator: " << eoe_ops_[i]->id();
      }
    } else {
      out << " None.";
    }
    out << "\n\n";
  }
}

Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node.");
  }

  std::unique_ptr<DataBuffer> buf;

  // `retry_if_eoe` is false because EpochCtrlOp does not eat EOE.
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, false));

  // Only intercept EOE for EoeReceived processing, after that the EOE is forwarded to next op.
  // Other databuffers containing data or EOF will simply be forwarded.
  // EOF can simply be forwarded because this op does not spawn any thread, thus does not require clean up.
  if (buf->eoe()) {
    RETURN_IF_NOT_OK(EoeReceived(worker_id));
  }

  *p_buffer = std::move(buf);
  return Status::OK();
}

Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
  repeat_count_++;
  MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
                << ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;

  // If we've reached the requested epoch count, then flag the leaf nodes
  // to tell them they've got one more epoch to perform. When they reach the end
  // of the last epoch, they quit rather than loop again.
  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
    for (auto &eoe_op : eoe_ops_) {
      MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
      eoe_op->set_control_flag(kDeOpLastRepeat);
    }
  }

  // This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
  state_ = OpState::kDeOpIdle;

  if (repeat_count_ != max_repeats_) {
    for (auto &eoe_op : eoe_ops_) {
      MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
      RETURN_IF_NOT_OK(eoe_op->Reset());
    }
  }

  return Status::OK();
}

// Pre-Visitor accept method for NodePass
Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call the pre-visitation
  return p->PreRunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}

// Visitor accept method for NodePass
Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call the pre-visitation
  return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
}  // namespace dataset
}  // namespace mindspore
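Note: the epoch accounting above is the core of the feature: each EOE bumps the epoch counter, leaf ops are told "this is your last pass" one epoch before the requested count is reached, and they are reset between epochs so iteration can continue. The following is a self-contained model of just that accounting (hypothetical class and member names, not the real EpochCtrlOp):

#include <iostream>
#include <vector>

constexpr int kInfiniteEpochs = -1;

struct LeafOp {
  bool last_repeat = false;
  int resets = 0;
  void Reset() { ++resets; }
};

class EpochCtrlModel {
 public:
  EpochCtrlModel(int max_epochs, std::vector<LeafOp *> leaves)
      : max_epochs_(max_epochs), leaves_(std::move(leaves)) {}

  void OnEoe() {
    ++epoch_count_;
    // One epoch before the end, tell the leaves they have exactly one more pass to do.
    if (max_epochs_ != kInfiniteEpochs && epoch_count_ == max_epochs_ - 1) {
      for (auto *leaf : leaves_) leaf->last_repeat = true;
    }
    // Between epochs (i.e. not after the final one), drive a reset so the next epoch can start.
    if (epoch_count_ != max_epochs_) {
      for (auto *leaf : leaves_) leaf->Reset();
    }
  }

 private:
  int max_epochs_;
  int epoch_count_ = 0;
  std::vector<LeafOp *> leaves_;
};

int main() {
  LeafOp leaf;
  EpochCtrlModel ctrl(3, {&leaf});
  for (int e = 0; e < 3; ++e) ctrl.OnEoe();
  std::cout << "resets=" << leaf.resets << " last_repeat=" << leaf.last_repeat << "\n";  // resets=2 last_repeat=1
  return 0;
}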
mindspore/ccsrc/minddata/dataset/engine/datasetops/epoch_ctrl_op.h  (new file, +82 -0)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_

#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/pipeline_op.h"

namespace mindspore {
namespace dataset {
class EpochCtrlOp : public RepeatOp {
 public:
  class Builder : public RepeatOp::Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @param count - The number of repeats to do
    // @return This is a constructor.
    explicit Builder(int32_t count) : RepeatOp::Builder(count) {}

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // @return shared_ptr to the new EpochCtrlOp object
    Status Build(std::shared_ptr<EpochCtrlOp> *);
  };

  // Contructor
  explicit EpochCtrlOp(int32_t num_epoch);

  // Destructor
  ~EpochCtrlOp();

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // This function returns the buffer that is at the top of our output connector. The caller is
  // typically our parent node, when the parent is asking us to provide the next buffer of data.
  // Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us
  // will simply bounce you to get a buffer from our child.
  // Epoch Control Op does not eat the EOE, it will pass the EOE to the next op.
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

  // Base-class override for handling cases when an eoe is received.
  // @param worker_id - The worker id
  Status EoeReceived(int32_t worker_id) override;

  /// \brief Base-class override for NodePass pre-visit acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status PreAccept(NodePass *p, bool *modified) override;

  /// \brief Base-class override for NodePass visitor acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc  (+3 -1)
@@ -132,6 +132,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
   // Invoke a reset against the eoe nodes only.
   for (auto &eoe_op : eoe_ops_) {
+    MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
     RETURN_IF_NOT_OK(eoe_op->Reset());
   }
@@ -167,8 +168,9 @@ int32_t RepeatOp::num_consumers() const {
 Status RepeatOp::Reset() {
   // If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
   // In that case, we now have to bounce the reset down to our own eoe ops.
-  MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
+  MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
   for (auto &eoe_op : eoe_ops_) {
+    MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id();
     RETURN_IF_NOT_OK(eoe_op->Reset());
   }
   state_ = OpState::kDeOpRunning;
mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h  (+4 -4)
@@ -46,7 +46,7 @@ class RepeatOp : public PipelineOp {
     // @return shared_ptr to the new RepeatOp object
     Status Build(std::shared_ptr<RepeatOp> *);

-   private:
+   protected:
    int32_t build_max_repeats_;

    Status SanityCheck() const;
@@ -131,11 +131,11 @@ class RepeatOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return "RepeatOp"; }

-  /// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
-  /// \param[in] eoe_op The input leaf/eoe operator to add to the list
+  // \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
+  // \param[in] eoe_op The input leaf/eoe operator to add to the list
   void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }

- private:
+ protected:
  int32_t max_repeats_;  // The number of repeats that the user requested
  int32_t repeat_count_;  // A counter for the current number of executed repeats
  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;  // List of operators that can generate EOE underneath this repeat.
mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc  (+2 -1)
@@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) {
   if (eof_) {
     return Status::OK();
   }
+  // One of our child iterators encounter EOE. Returns and proceed with draining phase.
   if (new_row.empty()) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
+    return Status::OK();
   }

   // Pack this first row into our tensor table
mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc  (+12 -3)
@@ -23,6 +23,7 @@
 #include "minddata/dataset/engine/opt/pre/removal_pass.h"
 #include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
 #include "minddata/dataset/engine/opt/post/repeat_pass.h"
+#include "minddata/dataset/engine/opt/pre/injection_pass.h"
 #include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
 #include "minddata/dataset/engine/perf/profiling.h"
 #include "minddata/dataset/engine/perf/monitor.h"
@@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
   if (op->tree_ == this) {
     return Status::OK();
   }
-  if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
+  if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) {
     std::string err_msg =
       "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
       " Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " +
-      std::to_string(static_cast<int>(kDeTStateBuilding));
+      std::to_string(static_cast<int>(kDeTStateBuilding)) + " or " + std::to_string(static_cast<int>(kDeTStatePrepare));
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
@@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
 // For example, repeatOp inlining
 //
 // @return Status - The error code return
-Status ExecutionTree::Prepare() {
+Status ExecutionTree::Prepare(int32_t num_epochs) {
+  num_epochs_ = num_epochs;
+
   // Pre optimization compulsory transformation
   RETURN_IF_NOT_OK(this->PrepareTreePreAction());
@@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
   std::vector<std::unique_ptr<Pass>> pre_actions;
   // Construct pre actions
   MS_LOG(INFO) << "Running pre pass loops.";
+  pre_actions.push_back(std::make_unique<InjectionPass>());
   pre_actions.push_back(std::make_unique<RemovalPass>());
   pre_actions.push_back(std::make_unique<CacheTransformPass>());
   // Apply pre action passes
@@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() {
                           " Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
     RETURN_STATUS_UNEXPECTED(err_msg);
   }

+  if (root_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");
+  }
+
   // Start the recursive prepare
   RETURN_IF_NOT_OK(this->PrepareNode(root_));
   tree_state_ = kDeTStateReady;
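Note: Prepare(num_epochs) now records the epoch count on the tree and runs an additional pre-action pass before the existing ones. The sketch below models only the pass-pipeline shape, ordering the new injection step ahead of removal and cache transformation; the pass names mirror the diff but the classes here are simplified stand-ins, not the MindSpore implementations:

#include <iostream>
#include <memory>
#include <vector>

struct Tree { int num_epochs = -1; };

struct Pass {
  virtual ~Pass() = default;
  virtual bool Run(Tree *tree) = 0;  // returns whether the tree was modified
};

struct InjectionPass : Pass {
  bool Run(Tree *) override { std::cout << "injection pass\n"; return true; }
};
struct RemovalPass : Pass {
  bool Run(Tree *) override { std::cout << "removal pass\n"; return false; }
};
struct CacheTransformPass : Pass {
  bool Run(Tree *) override { std::cout << "cache transform pass\n"; return false; }
};

bool PrepareTreePreAction(Tree *tree) {
  std::vector<std::unique_ptr<Pass>> pre_actions;
  pre_actions.push_back(std::make_unique<InjectionPass>());       // new: runs first
  pre_actions.push_back(std::make_unique<RemovalPass>());
  pre_actions.push_back(std::make_unique<CacheTransformPass>());
  bool modified = false;
  for (auto &pass : pre_actions) modified = pass->Run(tree) || modified;
  return modified;
}

int main() {
  Tree t;
  t.num_epochs = 2;   // mirrors Prepare(num_epochs) storing the count on the tree
  PrepareTreePreAction(&t);
  return 0;
}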
mindspore/ccsrc/minddata/dataset/engine/execution_tree.h  (+7 -1)
@@ -176,7 +176,7 @@ class ExecutionTree {
   // For example, repeatOp inlining
   //
   // @return Status - The error code return
-  Status Prepare();
+  Status Prepare(int num_epochs = -1);

   // Compulsory transformation/action pre optimization.
   // @return Status - The error code return
@@ -193,6 +193,7 @@
   // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
   // walk the tree to perform modifications to the tree or specific nodes within the tree to get
   // it ready for execution.
+  // @param Total number of epochs that will be run on this tree
   // @return Status - The error code return
   Status PrepareDeprecated();
@@ -231,6 +232,10 @@
   // Optional optimizations status
   bool OptimizationEnabled() const { return optimize_; }

+  // Getter function to get the total number of epochs to be run on this tree.
+  // @return total number of epochs
+  int32_t num_epochs() { return num_epochs_; }
+
  private:
   // A helper functions for doing the recursive printing
   // @param dataset_op - The dataset op to print
@@ -245,6 +250,7 @@
   int32_t id_count_;       // Counter for generating operator id's
   uint32_t prepare_flags_;  // Flags used during tree prepare
   TreeState tree_state_;   // Tracking the current tree state
+  int32_t num_epochs_;     // Total number of epochs to run for this tree
   std::unique_ptr<Monitor> perf_monitor_;  // Performance Monitor
   std::unique_ptr<ProfilingManager> profiling_manager_;  // Profiling manager
   bool optimize_;          // Flag to enable optional optimizations
mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt  (+1 -0)
@@ -5,6 +5,7 @@ add_library(engine-opt OBJECT
     post/repeat_pass.cc
     pre/cache_pass.cc
     pre/cache_transform_pass.cc
+    pre/injection_pass.cc
     pre/removal_nodes.cc
     pre/removal_pass.cc
     optional/tensor_op_fusion_pass.cc
mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc  (+17 -0)
@@ -16,11 +16,13 @@
 #include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/engine/datasetops/batch_op.h"
+#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
 #include "minddata/dataset/engine/datasetops/cache_op.h"
 #include "minddata/dataset/engine/datasetops/cache_merge_op.h"
 #include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
 #include "minddata/dataset/engine/datasetops/dataset_op.h"
 #include "minddata/dataset/engine/datasetops/device_queue_op.h"
+#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
 #include "minddata/dataset/engine/datasetops/map_op.h"
 #include "minddata/dataset/engine/datasetops/project_op.h"
 #include "minddata/dataset/engine/datasetops/rename_op.h"
@@ -230,6 +232,11 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
   return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
 }

+Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
 Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
   // Fallback to base class visitor by default
   return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
@@ -244,5 +251,15 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified
   // Fallback to base class visitor by default
   return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
 }

+Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
...
...
@@ -77,6 +77,10 @@ class CacheMergeOp;
class CacheLookupOp;

class EpochCtrlOp;

class BuildVocabOp;

// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
...
...
@@ -190,12 +194,18 @@ class NodePass : public Pass {
  virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);

  virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);

  virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);

  virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);

  virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);

  virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);

  virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);

 private:
  // Helper function to perform DFS visit
  Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc
...
...
@@ -20,6 +20,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
namespace mindspore {
namespace dataset {
...
...
@@ -28,6 +29,9 @@ RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(fa
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  // Create a new stack for eoe operators and push onto our stack of stacks.
  std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
  eoe_op_stacks_.push(std::move(new_stack));
  // If we are already repeated, then this is a nested repeat.
  if (is_repeated_) {
    nested_repeats_++;
...
...
@@ -36,6 +40,18 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified)
  return Status::OK();
}

// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
  // EpochCtrl is derived from RepeatOp. Generally it should do the identical setup
  // that RepeatOp does. However, epoch control is actually simpler because it can
  // only exist as the root node so it doesn't need all the nested code.
  // Create a new stack for eoe operators and push onto our stack of stacks.
  std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
  eoe_op_stacks_.push(std::move(new_stack));
  is_repeated_ = true;
  return Status::OK();
}

// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
  // Turn on the flag that we're under a merge op
...
...
@@ -47,13 +63,24 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
  std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
  while (leaf_op != nullptr) {
    node->AddToEoeList(leaf_op);
    leaf_op = PopFromEOEOpStack();
  }
  // At this point, we are done with the save area stack. It's a unique pointer to an empty stack
  // at this time, so we can pop it to get rid of it.
  eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
  if (!current_stack->empty()) {
    RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
  }
  eoe_op_stacks_.pop();
  // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
  // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
  // and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed
  // from the save area, because the merge op above us may also take action on it later for a different
  // case when there is no repeat in the merge leg.
  if (is_merge_ && cache_lookup_) {
    cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
    node->AddToEoeList(std::move(cache_lookup_));
...
...
@@ -65,16 +92,29 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    AddToEOEOpStack(node);
    nested_repeats_--;
  }
  // If we are not nested, or we were the top-most repeat, now we clear the flag
  if (nested_repeats_ == 0) {
  } else {
    // If we are not nested, or we were the top-most repeat, now we clear the flag
    if (nested_repeats_ != 0) {
      RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
    }
    is_repeated_ = false;
  }
  return Status::OK();
}

// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
  // Pop the leaf ops from the save-area stack and add them to the eoe node tracking
  std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
  while (leaf_op != nullptr) {
    node->AddToEoeList(leaf_op);
    leaf_op = PopFromEOEOpStack();
  }
  is_repeated_ = false;
  return Status::OK();
}

// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  if (is_repeated_) {
...
...
@@ -118,9 +158,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
  // Setting the flag is needed since we didn't call the base class DatasetOp version
  if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
  if (is_repeated_) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    // If there was not any repeat in the merge cache miss leg, then the cache_lookup
    // would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
    if (cache_lookup_) {
      AddToEOEOpStack(std::move(cache_lookup_));
    }
  }
  cache_lookup_.reset();  // If we are not repeated then the saved lookup is no longer needed or used
  is_merge_ = false;
  cache_lookup_.reset();  // If a repeat op did not consume this then it's no longer needed
  return Status::OK();
}
...
...
@@ -135,25 +182,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
  // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
  if (is_repeated_) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    AddToEOEOpStack(node);
  } else {
    // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
    // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
    // into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
    cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
    // Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that.
  }
  // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
  // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
  // into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
  // Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
  // add the lookup to the eoe stack
  cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
  return Status::OK();
}

// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
  eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
  current_stack->push(dataset_op);
}

// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
  std::shared_ptr<DatasetOp> top_op = nullptr;
  if (!eoe_stack_.empty()) {
    top_op = eoe_stack_.top();
    eoe_stack_.pop();
  eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
  if (current_stack != nullptr && !current_stack->empty()) {
    top_op = current_stack->top();
    current_stack->pop();
  }
  return top_op;
}
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h
...
...
@@ -30,6 +30,8 @@ namespace dataset {
/// to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass {
 public:
  using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>;

  /// \brief Constructor
  RepeatPass();
...
...
@@ -39,6 +41,12 @@ class RepeatPass : public NodePass {
  /// \return Status The error code return
  Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;

  /// \brief Identifies the subtree below this node as being in a repeated path of the tree.
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;

  /// \brief Identifies the subtree below this node as being in a cache merge path
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
...
...
@@ -51,6 +59,12 @@ class RepeatPass : public NodePass {
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;

  /// \brief Hooks up any identified eoe nodes under this repeat.
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;

  /// \brief CacheOp removes previous leaf ops and replaces them with itself
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
...
...
@@ -86,11 +100,11 @@ class RepeatPass : public NodePass {
  /// \return shared_ptr to the popped operator
  std::shared_ptr<DatasetOp> PopFromEOEOpStack();

  bool is_repeated_;                                  // T/F if we are processing under a repeat
  bool is_merge_;                                     // T/F if we are processing under a cache merge op
  int32_t nested_repeats_;                            // A counter for nested repeats
  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;  // A save area for leaf/eoe ops
  std::shared_ptr<DatasetOp> cache_lookup_;           // A save area for a cache lookup op
  bool is_repeated_;                                          // T/F if we are processing under a repeat
  bool is_merge_;                                             // T/F if we are processing under a cache merge op
  int32_t nested_repeats_;                                    // A counter for nested repeats
  std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_;   // A save area for leaf/eoe ops (with nesting)
  std::shared_ptr<DatasetOp> cache_lookup_;                   // A save area for a cache lookup op
};
}  // namespace dataset
}  // namespace mindspore
...
...
mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
namespace mindspore {
namespace dataset {

// constructor
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}

// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
  if (injection_pass_) {
    injection_pass_->epoch_ctrl_bypass_ = true;
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
  }
}

// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  if (injection_pass_) {
    injection_pass_->epoch_ctrl_bypass_ = true;
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
  }
}

// constructor
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}

// Runs an injection pass to inject in operators needed at the pre pass stage
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
  MS_LOG(INFO) << "Pre pass: Injection pass started.";

  // First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
  // The finder can make updates to the InjectionPass object.
  InjectionPass::InjectionFinder finder(this);
  finder.Run(tree, modified);

  // The first injection logic is to check if we should inject the epoch control op as the root node.
  // Do not inject the op if the number of epochs is 1.
  int32_t num_epochs = tree->num_epochs();
  if (num_epochs != 1 && !epoch_ctrl_bypass_) {
    std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
    RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
    RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
    std::shared_ptr<DatasetOp> node = tree->root();
    if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
      tree->root()->InsertAsParent(epoch_ctrl_op);
    } else {
      tree->root()->child(0)->InsertAsParent(epoch_ctrl_op);
    }
  }

  MS_LOG(INFO) << "Pre pass: Injection pass complete.";
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
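The rule this pass implements is small enough to restate outside the C++ for reference. The sketch below is illustrative only: inject_epoch_ctrl, make_epoch_ctrl_op, is_device_queue_op and the tree object are hypothetical stand-ins, not the ExecutionTree API.

def inject_epoch_ctrl(tree, num_epochs, epoch_ctrl_bypass):
    """Illustrative mirror of InjectionPass::RunOnTree; all helpers here are hypothetical."""
    if num_epochs == 1 or epoch_ctrl_bypass:
        return  # single-epoch runs, and trees containing BuildVocabOp/CacheOp, keep their root unchanged
    epoch_ctrl = make_epoch_ctrl_op(num_epochs)   # hypothetical factory for an EpochCtrlOp node
    root = tree.root()
    if is_device_queue_op(root):                  # hypothetical type check
        # A DeviceQueueOp must stay at the top, so epoch control is inserted directly below it.
        root.child(0).insert_as_parent(epoch_ctrl)
    else:
        root.insert_as_parent(epoch_ctrl)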
mindspore/ccsrc/minddata/dataset/engine/opt/pre/injection_pass.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {

class DatasetOp;

/// \class InjectionPass injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
///     parsing.
class InjectionPass : public TreePass {
  /// \class InjectionFinder
  /// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
  ///     operators that need to be injected. It is run first by the main injection pass to find out what operators
  ///     it may need to inject.
  class InjectionFinder : public NodePass {
   public:
    /// \brief Constructor
    explicit InjectionFinder(InjectionPass *injection_pass);

    /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;

    /// \brief Temporary code to prevent the injection of epoch control when cache op is present.
    ///     Remove this code in cache op phase 2
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

   private:
    InjectionPass *injection_pass_;
  };

 public:
  /// \brief Constructor
  InjectionPass();

  /// \brief Runs an injection pass to inject in operators needed at the pre pass stage
  /// \param[inout] tree The tree to operate on.
  /// \param[inout] Indicate of the tree was modified.
  /// \return Status The error code return
  Status RunOnTree(ExecutionTree *tree, bool *modified) override;

 private:
  bool epoch_ctrl_bypass_;
};
}  // namespace dataset
}  // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
...
...
@@ -29,20 +29,27 @@ std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
  return instance_ptr_;
}

TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) {
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
                              tdt::TdtDataType tdt_type) {
  MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
  std::vector<DataItem> items;
  double start_time;
  auto ret = translate(ts_row, items);
  if (ret != SUCCESS) {
    MS_LOG(ERROR) << "TDT converting tensor failed!";
    return FAILED;
  if (tdt_type == tdt::TDT_TENSOR) {
    auto ret = translate(ts_row, items);
    if (ret != SUCCESS) {
      MS_LOG(ERROR) << "TDT converting tensor failed!";
      return FAILED;
    }
  } else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) {
    DataItem data_item;
    data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE;
    items.emplace_back(data_item);
    MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE";
  }
  if (profiling) {
    start_time = ProfilingTime::GetCurMilliSecond();
  }
  if (tdt::TdtHostPushData(channel_name, items) != 0) {
    MS_LOG(ERROR) << "TDT pushing data failed!";
    return FAILED;
  }
  if (profiling) {
...
...
@@ -122,8 +129,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i
    data_item.dataPtr_ =
      std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
    items.emplace_back(data_item);
    MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is "
                  << ts->Size() << ".";
    MS_LOG(INFO) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes
                 << ", data length is " << ts->Size() << ".";
  }
  return SUCCESS;
}
...
...
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h
...
...
@@ -38,7 +38,8 @@ class TdtPlugin {
 public:
  static std::shared_ptr<TdtPlugin> GetInstance();

  TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time);
  TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
                     tdt::TdtDataType tdt_type = tdt::TDT_TENSOR);

 private:
  TdtPlugin() {}
...
...
mindspore/ccsrc/pipeline/jit/pipeline.cc
...
...
@@ -797,6 +797,9 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
    (void)InitBackend();
  }
#endif
  if (iter_num == -1) {
    iter_num = INT32_MAX;
  }
  if (name == kMsConvert || name == kMsVm) {
    return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
  }
...
...
mindspore/dataset/engine/datasets.py
...
...
@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
    check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
    check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
    check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
    check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

try:
...
...
@@ -946,14 +946,14 @@ class Dataset:
            raise TypeError("apply_func must return a dataset.")
        return dataset

    @check_positive_int32
    def device_que(self, prefetch_size=None):
    def device_que(self, prefetch_size=None, send_epoch_end=True):
        """
        Return a transferredDataset that transfer data through device.

        Args:
            prefetch_size (int, optional): prefetch number of records ahead of the
                user's request (default=None).
            send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)

        Note:
            If device is Ascend, features of data will be transferred one by one. The limitation
...
...
@@ -962,15 +962,14 @@ class Dataset:
        Return:
            TransferDataset, dataset for transferring.
        """
        return self.to_device()
        return self.to_device(send_epoch_end=send_epoch_end)

    @check_positive_int32
    def to_device(self, num_batch=None):
    def to_device(self, send_epoch_end=True):
        """
        Transfer data through CPU, GPU or Ascend devices.

        Args:
            num_batch (int, optional): limit the number of batch to be sent to device (default=None).
            send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)

        Note:
            If device is Ascend, features of data will be transferred one by one. The limitation
...
...
@@ -982,19 +981,9 @@ class Dataset:
        Raises:
            TypeError: If device_type is empty.
            ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
            ValueError: If num_batch is not positive or larger than int_max.
            ValueError: If dataset size is None or 0.
            RuntimeError: If dataset is unknown.
            RuntimeError: If distribution file path is given but failed to read.
        """
        if self.get_dataset_size() is None or 0:
            raise ValueError("dataset size is None or 0.")

        if num_batch is None:
            num_batch = self.get_dataset_size()
            repeat_count = self.get_repeat_count()
            num_batch = num_batch * repeat_count

        queue_name = str(uuid.uuid1())

        if context:
...
...
@@ -1008,9 +997,6 @@ class Dataset:
        if device_type not in ('Ascend', 'GPU', 'CPU'):
            raise ValueError("Only support CPU, Ascend, GPU")

        if num_batch == 0:
            raise ValueError("num_batch is 0.")

        def get_distribution(output_dataset):
            dev_id = 0
            if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
...
...
@@ -1032,7 +1018,7 @@ class Dataset:
            distribution_path, device_id = get_distribution(self)
        if distribution_path == "":
            return TransferDataset(self, queue_name, device_id, device_type, num_batch)
            return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
        try:
            with open(distribution_path, 'r') as distribution_f:
                dist = json.load(distribution_f)
...
...
@@ -1042,7 +1028,7 @@ class Dataset:
        except Exception:
            raise RuntimeError("Distribution file failed to read")

        return TransferDataset(self, queue_name, device_id, device_type, num_batch)
        return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)

    @check_save
    def save(self, file_name, num_files=1, file_type='mindrecord'):
...
...
@@ -1072,7 +1058,7 @@ class Dataset:
        return SaveOp(self).save(file_names, file_type)

    def create_tuple_iterator(self, columns=None):
    def create_tuple_iterator(self, columns=None, num_epochs=-1):
        """
        Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
...
...
@@ -1098,9 +1084,9 @@ class Dataset:
        """
        if self._noop_mode():
            return DummyIterator(self, 'tuple')
        return TupleIterator(self, columns)
        return TupleIterator(self, columns, num_epochs)

    def create_dict_iterator(self):
    def create_dict_iterator(self, num_epochs=-1):
        """
        Create an Iterator over the dataset.
...
...
@@ -1123,7 +1109,7 @@ class Dataset:
        """
        if self._noop_mode():
            return DummyIterator(self, 'dict')
        return DictIterator(self)
        return DictIterator(self, num_epochs)

    def __iter__(self):
        """Create an Iterator over the dataset."""
...
...
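A short usage sketch of the new num_epochs argument (the tiny GeneratorDataset source below is illustrative; any pipeline behaves the same way):

import numpy as np
import mindspore.dataset as ds

data = [(np.array([i], dtype=np.int32),) for i in range(4)]
dataset = ds.GeneratorDataset(data, column_names=["x"])

# Declaring the epoch count up front lets the pipeline be built once and reused,
# instead of re-launching the execution tree for every epoch.
iterator = dataset.create_dict_iterator(num_epochs=2)
for epoch in range(2):
    for row in iterator:
        print(epoch, row["x"])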
@@ -1149,7 +1135,7 @@ class Dataset:
            self._batch_size = device_iter.get_batch_size()
            self._num_classes = device_iter.num_classes()
            self._repeat_count = device_iter.get_repeat_count()
            device_iter.release()
            device_iter.stop()

    def output_shapes(self):
        """
...
...
@@ -2085,7 +2071,7 @@ class RepeatDataset(DatasetOp):
        """
        child_size = self.children[0].get_dataset_size()
        if child_size is not None:
            return child_size
            return child_size * self.count
        return None

    def get_repeat_count(self):
...
...
@@ -2097,7 +2083,6 @@ class RepeatDataset(DatasetOp):
        """
        return self.count


class SkipDataset(DatasetOp):
    """
    The result of applying Skip operator to the input Dataset.
...
...
@@ -2317,10 +2302,10 @@ class TransferDataset(DatasetOp):
        queue_name (str): Name of device queue.
        device_id (int): Id of device.
        device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
        num_batch (int): limit the number of batch to be sent to device (default=None).
        send_epoch_end (bool, optional): Whether send end of sequence to device or not.(default=True)
    """

    def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
    def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
        super().__init__()
        self.children.append(input_dataset)
        input_dataset.parent.append(self)
...
...
@@ -2328,7 +2313,7 @@ class TransferDataset(DatasetOp):
        self._input_indexs = input_dataset.input_indexs
        self._device_type = device_type
        self._device_id = device_id
        self.__num_batch = num_batch
        self._send_epoch_end = send_epoch_end
        self.iterator = None

    def get_args(self):
...
...
@@ -2336,13 +2321,13 @@ class TransferDataset(DatasetOp):
        args["queue_name"] = self.queue_name
        args["device_type"] = self._device_type
        args["device_id"] = self._device_id
        args["num_batch"] = self.__num_batch
        args["send_epoch_end"] = self._send_epoch_end
        return args

    def create_dict_iterator(self):
    def create_dict_iterator(self, num_epochs=-1):
        raise RuntimeError("TransferDataset is not iterable")

    def create_tuple_iterator(self, columns=None):
    def create_tuple_iterator(self, columns=None, num_epochs=-1):
        raise RuntimeError("TransferDataset is not iterable")

    def __iter__(self):
...
...
@@ -2354,12 +2339,14 @@ class TransferDataset(DatasetOp):
    def output_types(self):
        raise RuntimeError("TransferDataset does not support output_types")

    def send(self):
    def send(self, num_epochs=-1):
        # need to keep iterator alive so the executionTree is not destroyed
        if self._noop_mode():
            return
        self.iterator = TupleIterator(self)
        self.iterator = TupleIterator(self, num_epochs=-1)

    def stop_send(self):
        self.iterator.depipeline.StopSend()


class RangeDataset(MappableDataset):
    """
...
...
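Taken together, the TransferDataset changes read roughly as follows in use. This is a hedged sketch: it assumes an Ascend or GPU context is already configured and that `dataset` is an existing pipeline.

# Push the pipeline output into the device queue; send_epoch_end controls whether an
# end-of-sequence flag is pushed after each epoch.
transfer = dataset.device_que(send_epoch_end=True)
transfer.send(num_epochs=1)    # start feeding the device
# ... the training graph consumes rows from transfer.queue_name here ...
transfer.stop_send()           # release the sink when training is finished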
mindspore/dataset/engine/iterators.py
...
...
@@ -29,7 +29,6 @@ from . import datasets as de
ITERATORS_LIST = list()

def _cleanup():
    """Release all the Iterator."""
    for itr_ref in ITERATORS_LIST:
...
...
@@ -60,7 +59,6 @@ def _alter_node(node):
        node.iterator_bootstrap()
    return node

class Iterator:
    """
    General Iterator over a dataset.
...
...
@@ -69,10 +67,21 @@ class Iterator:
        dataset: Dataset to be iterated over
    """

    def __init__(self, dataset):
    def __init__(self, dataset, num_epochs=-1):
        self.num_epochs = num_epochs
        ITERATORS_LIST.append(weakref.ref(self))
        # create a copy of tree and work on it.
        self.dataset = copy.deepcopy(dataset)
        self.parent_subtree = []

        # The dataset passed into the iterator is not the root of the tree.
        # Trim the tree by saving the parent subtree into self.parent_subtree and
        # restore it after launching our c++ pipeline.
        if self.dataset.parent:
            logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
            self.parent_subtree = self.dataset.parent
            self.dataset.parent = []

        self.dataset = alter_tree(self.dataset)
        if not self.__is_tree():
            raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
...
...
@@ -83,9 +92,17 @@ class Iterator:
        root = self.__convert_node_postorder(self.dataset)
        self.depipeline.AssignRootNode(root)
        self.depipeline.LaunchTreeExec()
        self.depipeline.LaunchTreeExec(self.num_epochs)
        self._index = 0

    def stop(self):
        """
        Manually terminate python iterator instead of relying on out of scope destruction.
        """
        logger.info("terminating python iterator. This will also terminate c++ pipeline.")
        if hasattr(self, 'depipeline') and self.depipeline:
            del self.depipeline

    def __is_tree_node(self, node):
        """Check if a node is tree node."""
        if not node.children:
...
...
@@ -214,9 +231,14 @@ class Iterator:
    @abstractmethod
    def get_next(self):
        pass
        raise RuntimeError("Calling base class Iterator's get_next is invalid.")

    def __next__(self):
        if not self.depipeline:
            logger.warning("Iterator does not have a running c++ pipeline." +
                           "It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
            raise RuntimeError("Iterator does not have a running c++ pipeline.")

        data = self.get_next()
        if not data:
            if self._index == 0:
...
...
@@ -293,12 +315,12 @@ class TupleIterator(Iterator):
    def check_node_type(self, node):
        pass

    def __init__(self, dataset, columns=None):
    def __init__(self, dataset, columns=None, num_epochs=-1):
        if columns is not None:
            if not isinstance(columns, list):
                columns = [columns]
            dataset = dataset.project(columns)
        super().__init__(dataset)
        super().__init__(dataset, num_epochs)

    def __iter__(self):
        return self
...
...
mindspore/train/_utils.py
...
...
@@ -57,7 +57,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
    # transform data format
    dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
    exec_dataset = exec_dataset.device_que()
    send_epoch_end = bool(dataset_size == -1)
    exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end)

    _executor.init_dataset(exec_dataset.queue_name,
                           dataset_size,
...
...
@@ -126,7 +127,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
def _to_tensor(elem, scaling_sens=None):
    """Conver numpy to tensor, adapt to minddata feed solution."""
    """Convert numpy to tensor, adapt to feed the data from host solution."""
    lst = []
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
...
...
@@ -145,7 +146,8 @@ def _to_tensor(elem, scaling_sens=None):
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
    """Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution."""
    """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
    from host solution."""
    lst = []
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
...
...
mindspore/train/dataset_helper.py
...
...
@@ -16,7 +16,7 @@
import math
import os

from mindspore._checkparam import check_bool
from mindspore._checkparam import check_bool, check_int
from .. import context
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
    _construct_tensor_list, _to_full_shapes, _to_full_tensor
...
...
@@ -42,17 +42,23 @@ class DatasetHelper:
    The iter of DatasetHelper will give one epoch data.

    Args:
        dataset (DataSet): The dataset.
        dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
            Default: True.
        dataset (DataSet): The training dataset iterator.
        dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
        sink_size (int): Control the amount of data each sink.
                         If sink_size=-1, sink the complete dataset each epoch.
                         If sink_size>0, sink sink_size data each epoch. Default: -1.

    Examples:
        >>> dataset_helper = DatasetHelper(dataset)
        >>> for inputs in dataset_helper:
        >>>     outputs = network(*inputs)
    """

    def __init__(self, dataset, dataset_sink_mode=True):
    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
        check_bool(dataset_sink_mode)
        check_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

        if dataset_sink_mode:
            if context.get_context("enable_ge"):
...
...
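A usage sketch of the new sink_size argument (assuming `dataset` and a compiled `network` already exist on an Ascend or GPU target; both names are placeholders):

# Sink 100 steps of data per iteration instead of the whole dataset.
dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=100)
for inputs in dataset_helper:
    outputs = network(*inputs)
# sink_size=-1 (the default) keeps the old behaviour of sinking the complete dataset each epoch.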
@@ -68,9 +74,10 @@ class DatasetHelper:
                iterclass = _DatasetIterMS
            elif context.get_context("device_target") == "CPU":
                raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
            self.iter = iterclass(dataset, sink_size)
        else:
            iterclass = _DatasetIterFeed
            self.iter = iterclass(dataset)
            iterclass = _DatasetIterNormal
            self.iter = iterclass(dataset)

    def __iter__(self):
        return self.iter.__iter__()
...
...
@@ -80,21 +87,26 @@ class DatasetHelper:
        """Get the types and shapes from dataset on current config."""
        return self.iter.types_shapes()

    def loop_size(self):
        """Get loop_size for every iteration."""
        return self.iter.loop_size
    def sink_size(self):
        """Get sink_size for every iteration."""
        return self.iter.get_sink_size()

    def stop_send(self):
        """Free up resources about data sink."""
        self.iter.stop_send()


class _DatasetIter:
    """Base iter for dataset help"""
    def __init__(self, dataset):
        if not hasattr(dataset, '__loop_size__'):
            self.loop_size = dataset.get_dataset_size()
        else:
            self.loop_size = dataset.__loop_size__
    """Base iter for dataset helper"""
    def __init__(self, dataset, sink_size):
        self.dataset = dataset
        self.sink_size = sink_size
        self.sink_count = 1

        if not hasattr(dataset, '__ME_INITED__'):
            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
        if not hasattr(dataset, '__TRANSFER_DATASET__'):
            if hasattr(dataset, '__loop_size__'):
                self.sink_size = dataset.__loop_size__
            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
            dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

            if not hasattr(dataset, '__no_send__'):
...
...
@@ -102,43 +114,70 @@ class _DatasetIter:
            else:
                _send_data(dataset)

        self.ind = 0
        self.dataset = dataset
        dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
        self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
        self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

    def __iter__(self):
        self.ind = 0
        self.index = 0
        return self

    def __next__(self):
        if self.ind >= self.loop_count:
        if self.index >= self.sink_count:
            raise StopIteration()
        self.ind += 1
        self.index += 1
        return self.op()

    def types_shapes(self):
        return self.dataset_types, self.dataset_shapes

    def get_loop_count(self, dataset):
        loop_count = 1
    def get_sink_count(self, dataset):
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
            if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
                raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                                 f'loop_size {loop_size} are not matched.')
            loop_count = math.ceil(dataset.get_dataset_size() / loop_size)
        return loop_count
                                 f'sink_size {loop_size} are not matched.')
            sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
        return sink_count

    def get_sink_size(self):
        """get sink_size to device"""
        sink_size = 1
        if hasattr(self.dataset, '__loop_size__'):
            sink_size = self.dataset.__loop_size__
        else:
            if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
                if self.sink_size > 0:
                    sink_size = self.sink_size
                else:
                    sink_size = self.dataset.get_dataset_size()
        return sink_size
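A worked example of the sink_count arithmetic above, with illustrative numbers only:

import math

dataset_size = 1875   # batches per epoch (illustrative)
loop_size = 625       # dataset.__loop_size__ (illustrative)

# get_sink_count(): a loop size smaller than the dataset size must divide it evenly,
# otherwise the ValueError above is raised.
assert dataset_size % loop_size == 0
sink_count = math.ceil(dataset_size / loop_size)
print(sink_count)     # 3 sinks per epoch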
class _DatasetIterGE(_DatasetIter):
    """Iter for GE."""
    def __init__(self, dataset, sink_size):
        super().__init__(dataset, sink_size)
        self.sink_count = self.get_sink_count(dataset)
        batch_expand_num = 1
        if _need_to_full():
            batch_expand_num = _get_device_num()
        tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)

        def op():
            return tensor_list_run

        self.op = op


class _DatasetIterMSLoopSink(_DatasetIter):
    """Iter for context (device_target=Ascend)"""
    def __init__(self, dataset):
        super(_DatasetIterMSLoopSink, self).__init__(dataset)
        self.loop_count = self.get_loop_count(dataset)
    def __init__(self, dataset, sink_size):
        super().__init__(dataset, sink_size)
        self.sink_count = self.get_sink_count(dataset)
        ms_role = os.getenv("MS_ROLE")
        if ms_role in ("MS_PSERVER", "MS_SCHED"):
            self.loop_count = 1
            self.sink_count = 1
        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
        # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
        # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
...
...
@@ -153,66 +192,42 @@ class _DatasetIterMSLoopSink(_DatasetIter):
class _DatasetIterMS(_DatasetIter):
    """Iter for context (device_target=GPU)"""
    def __init__(self, dataset):
        super(_DatasetIterMS, self).__init__(dataset)
        self.loop_count = dataset.get_dataset_size()
        self.loop_size = 1
    """Iter for MS(enable_loop_sink=False)."""
    def __init__(self, dataset, sink_size):
        super().__init__(dataset, sink_size)
        if sink_size > 0:
            self.sink_count = sink_size
        else:
            self.sink_count = dataset.get_dataset_size()
        queue_name = dataset.__ME_INITED__
        self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)


class _DatasetIterPSLite(_DatasetIter):
    """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
    def __init__(self, dataset):
        super(_DatasetIterPSLite, self).__init__(dataset)
        self.loop_count = 1
        self.loop_size = 1
    def __init__(self, dataset, sink_size):
        super().__init__(dataset, sink_size)
        self.sink_count = 1
        self.sink_size = 1
        self.op = None
        def op():
            return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
        self.op = op


class _DatasetIterGE(_DatasetIter):
    """Iter for ge"""
    def __init__(self, dataset):
        super(_DatasetIterGE, self).__init__(dataset)
        self.loop_count = self.get_loop_count(dataset)
        batch_expand_num = 1
        if _need_to_full():
            batch_expand_num = _get_device_num()
        tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)

        def op():
            return tensor_list_run

        self.op = op


class _DatasetIterFeed:
class _DatasetIterNormal:
    """Iter for normal(non sink) mode, feed the data from host."""
    def __init__(self, dataset):
        self.dataset = dataset
        self.device_num = _get_device_num()
        self.global_rank = _get_global_rank()
        self.repeat_count = dataset.get_repeat_count()
        self.repeat_ind = 0
        self.loop_count = dataset.get_dataset_size()
        self.ind = 0

    def __iter__(self):
        if self.repeat_ind % self.repeat_count == 0:
            self.iter = self.dataset.__iter__()
        self.repeat_ind += 1
        self.ind = 0
        self.iter = self.dataset.create_tuple_iterator()
        return self

    def __next__(self):
        if self.ind >= self.loop_count:
            raise StopIteration()
        self.ind += 1
        data = self.iter.__next__()
        if _need_to_full():
            return _to_full_tensor(data, self.device_num, self.global_rank)
...
...
mindspore/train/model.py
...
...
@@ -21,7 +21,7 @@ import numpy as np
from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
...
...
@@ -225,7 +225,7 @@ class Model:
            scaling_sens /= self._device_number
        return scaling_sens

    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode):
    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
        """Initializes dataset."""
        need_wrap = False
        if dataset_sink_mode:
...
...
@@ -237,7 +237,7 @@ class Model:
            if not is_train:
                dataset.__loop_size__ = 1

        dataset_helper = DatasetHelper(dataset, dataset_sink_mode)
        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)

        # remove later to deal with loop sink
        if need_wrap:
...
...
@@ -317,7 +317,7 @@ class Model:
                self._eval_network.compile(*inputs)
                break

    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
        """
        Training.
...
...
@@ -332,6 +332,7 @@ class Model:
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
                                      Configure pynative mode, the training process will be performed with
                                      dataset not sink.
            sink_size (int): Control the amount of data each sink. Default: -1.
        """
        epoch = check_int_positive(epoch)
        self._train_network.set_train()
...
...
@@ -342,7 +343,10 @@ class Model:
        cb_params = _InternalCallbackParam()
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch
        cb_params.batch_num = train_dataset.get_dataset_size()
        if dataset_sink_mode and sink_size > 0:
            cb_params.batch_num = sink_size
        else:
            cb_params.batch_num = train_dataset.get_dataset_size()
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
...
...
@@ -364,7 +368,7 @@ class Model:
                               "So the training process will be performed with dataset not sink.")
            self._train_process(epoch, train_dataset, list_callback, cb_params)
        else:
            self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
            self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)

    @staticmethod
    def _transform_callbacks(callbacks):
...
...
@@ -377,7 +381,7 @@ class Model:
        return [callbacks]

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
        """
        Training process. The data would be passed to network through dataset channel.
...
...
@@ -390,17 +394,18 @@ class Model:
                function respectively.
            list_callback (Callback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
            sink_size (int): Control the amount of data each sink. Default: -1.
        """
        dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                              is_train=True,
                                                              phase='train',
                                                              dataset=train_dataset,
                                                              dataset_sink_mode=True)
                                                              dataset_sink_mode=True,
                                                              sink_size=sink_size)
        self._train_network = train_network
        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0

        loop_size = dataset_helper.loop_size()
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)
...
...
@@ -412,9 +417,9 @@ class Model:
        # for data sink dataset_helper only iter once, other wise iter epoch_size times.
        for inputs in dataset_helper:
            cb_params.cur_step_num += loop_size
            list_callback.step_begin(run_context)
            outputs = self._train_network(*inputs)
            cb_params.cur_step_num += dataset_helper.sink_size()
            cb_params.net_outputs = outputs
            list_callback.step_end(run_context)
...
...
@@ -422,6 +427,7 @@ class Model:
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

        dataset_helper.stop_send()
        list_callback.end(run_context)
...
...
@@ -490,7 +496,7 @@ class Model:
        list_callback.end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
        """
        Training API where the iteration is controlled by python front-end.
...
...
@@ -515,7 +521,10 @@ class Model:
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
                                      Configure pynative mode, the training process will be performed with
                                      dataset not sink.
            sink_size (int): Control the amount of data each sink.
                             If sink_size=-1, sink the complete dataset each epoch.
                             If sink_size>0, sink sink_size data each epoch.
                             If dataset_sink_mode is False, set sink_size invalid. Default: -1.

        Examples:
            >>> dataset = get_dataset()
...
...
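A brief usage sketch of sink_size (net, loss, optim and dataset are placeholders that mirror the Examples block above):

model = Model(net, loss_fn=loss, optimizer=optim)
# Sink 1000 steps per epoch; sink_size=-1 sinks the full dataset each epoch, and the
# argument is ignored when dataset_sink_mode is False.
model.train(2, dataset, dataset_sink_mode=True, sink_size=1000)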
@@ -526,17 +535,19 @@ class Model:
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
        """
        repeat_count = train_dataset.get_repeat_count()
        if epoch != repeat_count and dataset_sink_mode is True:
            logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
        check_bool(dataset_sink_mode)
        check_int(sink_size)
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode)
                    dataset_sink_mode=dataset_sink_mode,
                    sink_size=sink_size)

    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
        """
...
...
model_zoo/alexnet/train.py
...
...
@@ -43,7 +43,7 @@ if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size)
    ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)

    network = AlexNet(cfg.num_classes)
    loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
    lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
...
...
model_zoo/deepfm/train.py
...
...
@@ -57,7 +57,7 @@ if __name__ == '__main__':
    ds_train = create_dataset(args_opt.dataset_path,
                              train_mode=True,
                              epochs=train_config.train_epochs,
                              epochs=1,
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
...
...
@@ -82,7 +82,7 @@ if __name__ == '__main__':
    if args_opt.do_eval:
        ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
                                 epochs=train_config.train_epochs,
                                 epochs=1,
                                 batch_size=train_config.batch_size,
                                 data_type=DataType(data_config.data_format))
        eval_callback = EvalCallBack(model, ds_eval, auc_metric,
...
...
model_zoo/deeplabv3/train.py
...
...
@@ -66,7 +66,7 @@ if __name__ == "__main__":
        init()
    args_opt.base_size = config.crop_size
    args_opt.crop_size = config.crop_size
    train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train")
    train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train")
    dataset_size = train_dataset.get_dataset_size()
    time_cb = TimeMonitor(data_size=dataset_size)
    callback = [time_cb, LossCallBack()]
...
...
model_zoo/faster_rcnn/train.py
...
...
@@ -94,7 +94,7 @@ if __name__ == '__main__':
    loss_scale = float(config.loss_scale)

    # When create MindDataset, using the fitst mindrecord file, such as FasterRcnn.mindrecord0.
    dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=config.epoch_size,
    dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1,
                                        batch_size=config.batch_size, device_num=device_num, rank_id=rank)

    dataset_size = dataset.get_dataset_size()
...
...
model_zoo/googlenet/train.py
...
...
@@ -78,7 +78,7 @@ if __name__ == '__main__':
                          mirror_mean=True)
        init()

    dataset = create_dataset(cfg.data_path, cfg.epoch_size)
    dataset = create_dataset(cfg.data_path, 1)
    batch_num = dataset.get_dataset_size()

    net = GoogleNet(num_classes=cfg.num_classes)
...
...
model_zoo/lenet/train.py
...
...
@@ -45,8 +45,7 @@ if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    ds_train = create_dataset(os.path.join(args.data_path, "train"),
                              cfg.batch_size, cfg.epoch_size)
                              cfg.batch_size)
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
...
...
model_zoo/lenet_quant/train.py
...
...
@@ -44,7 +44,7 @@ args = parser.parse_args()
if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
    ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
    step_size = ds_train.get_dataset_size()

    # define fusion network
...
...
model_zoo/lstm/train.py
...
...
@@ -77,7 +77,7 @@ if __name__ == '__main__':
    model = Model(network, loss, opt, {'acc': Accuracy()})

    print("============== Starting Training ==============")
    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
    ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
...
...
model_zoo/mass/train.py
...
...
@@ -249,7 +249,7 @@ def train_parallel(config: TransformerConfig):
    pre_train_dataset = load_dataset(
        data_files=config.pre_train_dataset,
        batch_size=config.batch_size, epoch_count=config.epochs,
        batch_size=config.batch_size, epoch_count=1,
        sink_mode=config.dataset_sink_mode,
        sink_step=config.dataset_sink_step,
        rank_size=MultiAscend.get_group_size(),
...
...
@@ -257,7 +257,7 @@ def train_parallel(config: TransformerConfig):
    ) if config.pre_train_dataset else None
    fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
                                     batch_size=config.batch_size, epoch_count=config.epochs,
                                     batch_size=config.batch_size, epoch_count=1,
                                     sink_mode=config.dataset_sink_mode,
                                     sink_step=config.dataset_sink_step,
                                     rank_size=MultiAscend.get_group_size(),
...
...
@@ -265,7 +265,7 @@ def train_parallel(config: TransformerConfig):
    ) if config.fine_tune_dataset else None
    test_dataset = load_dataset(data_files=config.test_dataset,
                                batch_size=config.batch_size, epoch_count=config.epochs,
                                batch_size=config.batch_size, epoch_count=1,
                                sink_mode=config.dataset_sink_mode,
                                sink_step=config.dataset_sink_step,
                                rank_size=MultiAscend.get_group_size(),
...
...
@@ -288,17 +288,17 @@ def train_single(config: TransformerConfig):
    print(" | Starting training on single device.")

    pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
                                     batch_size=config.batch_size,
                                     epoch_count=config.epochs,
                                     epoch_count=1,
                                     sink_mode=config.dataset_sink_mode,
                                     sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
    fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
                                     batch_size=config.batch_size,
                                     epoch_count=config.epochs,
                                     epoch_count=1,
                                     sink_mode=config.dataset_sink_mode,
                                     sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
    test_dataset = load_dataset(data_files=config.test_dataset,
                                batch_size=config.batch_size,
                                epoch_count=config.epochs,
                                epoch_count=1,
                                sink_mode=config.dataset_sink_mode,
                                sink_step=config.dataset_sink_step) if config.test_dataset else None
...
...
model_zoo/mobilenetv2/train.py
...
...
@@ -180,7 +180,7 @@ if __name__ == '__main__':
                                 do_train=True,
                                 config=config_gpu,
                                 platform=args_opt.platform,
                                 repeat_num=epoch_size,
                                 repeat_num=1,
                                 batch_size=config_gpu.batch_size)
        step_size = dataset.get_dataset_size()
        # resume
...
...
@@ -239,7 +239,7 @@ if __name__ == '__main__':
                                 do_train=True,
                                 config=config_ascend,
                                 platform=args_opt.platform,
                                 repeat_num=epoch_size,
                                 repeat_num=1,
                                 batch_size=config_ascend.batch_size)
        step_size = dataset.get_dataset_size()
        if args_opt.pre_trained:
...
...
model_zoo/mobilenetv2_quant/train.py
...
...
@@ -86,7 +86,7 @@ if __name__ == '__main__':
                             do_train=True,
                             config=config,
                             device_target=args_opt.device_target,
                             repeat_num=epoch_size,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()
    # load pre trained ckpt
...
...
model_zoo/mobilenetv3/train.py
...
...
@@ -181,7 +181,7 @@ if __name__ == '__main__':
                                 do_train=True,
                                 config=config_gpu,
                                 platform=args_opt.platform,
                                 repeat_num=epoch_size,
                                 repeat_num=1,
                                 batch_size=config_gpu.batch_size)
        step_size = dataset.get_dataset_size()
        # resume
...
...
@@ -240,7 +240,7 @@ if __name__ == '__main__':
                                 do_train=True,
                                 config=config_ascend,
                                 platform=args_opt.platform,
                                 repeat_num=epoch_size,
                                 repeat_num=1,
                                 batch_size=config_ascend.batch_size)
        step_size = dataset.get_dataset_size()
        if args_opt.pre_trained:
...
...
model_zoo/official/nlp/bert/run_classifier.py
...
...
@@ -36,12 +36,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd()

def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
...
...
@@ -176,11 +175,11 @@ def run_classifier():
assessment_method
=
assessment_method
)
if
args_opt
.
do_train
.
lower
()
==
"true"
:
ds
=
create_classification_dataset
(
batch_size
=
bert_net_cfg
.
batch_size
,
repeat_count
=
epoch_num
,
ds
=
create_classification_dataset
(
batch_size
=
bert_net_cfg
.
batch_size
,
repeat_count
=
1
,
assessment_method
=
assessment_method
,
data_file_path
=
args_opt
.
train_data_file_path
,
schema_file_path
=
args_opt
.
schema_file_path
)
do_train
(
ds
,
netwithloss
,
load_pretrain_checkpoint_path
,
save_finetune_checkpoint_path
)
do_train
(
ds
,
netwithloss
,
load_pretrain_checkpoint_path
,
save_finetune_checkpoint_path
,
epoch_num
)
if
args_opt
.
do_eval
.
lower
()
==
"true"
:
if
save_finetune_checkpoint_path
==
""
:
...
...
@@ -191,7 +190,7 @@ def run_classifier():
ds
.
get_dataset_size
(),
epoch_num
,
"classifier"
)
if
args_opt
.
do_eval
.
lower
()
==
"true"
:
ds
=
create_classification_dataset
(
batch_size
=
bert_net_cfg
.
batch_size
,
repeat_count
=
epoch_num
,
ds
=
create_classification_dataset
(
batch_size
=
bert_net_cfg
.
batch_size
,
repeat_count
=
1
,
assessment_method
=
assessment_method
,
data_file_path
=
args_opt
.
eval_data_file_path
,
schema_file_path
=
args_opt
.
schema_file_path
)
...
...
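The same pattern recurs in run_ner.py and run_squad.py below: the fine-tune dataset is now built for a single pass and the epoch count is handed to do_train explicitly, instead of being read back from dataset.get_repeat_count(). A minimal sketch of the resulting call shape, reusing the names shown in this diff (the epoch_num value is illustrative, not taken from the repository):

# Hedged sketch of the post-change fine-tune flow; epoch_num value is illustrative.
epoch_num = 3  # chosen by the caller, no longer derived from the dataset
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   assessment_method=assessment_method,
                                   data_file_path=args_opt.train_data_file_path,
                                   schema_file_path=args_opt.schema_file_path)
# do_train now receives the epoch count directly instead of calling ds.get_repeat_count().
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)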
model_zoo/official/nlp/bert/run_ner.py
@@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 _cur_dir = os.getcwd()

-def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
+def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
     """ do train """
     if load_checkpoint_path == "":
         raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
     steps_per_epoch = dataset.get_dataset_size()
-    epoch_num = dataset.get_repeat_count()
     # optimizer
     if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
         optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
...
@@ -204,10 +203,10 @@ def run_ner():
                                   use_crf=(args_opt.use_crf.lower() == "true"),
                                   tag_to_index=tag_to_index, dropout_prob=0.1)
     if args_opt.do_train.lower() == "true":
-        ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
+        ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                 assessment_method=assessment_method,
                                 data_file_path=args_opt.train_data_file_path,
                                 schema_file_path=args_opt.schema_file_path)
-        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
+        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

     if args_opt.do_eval.lower() == "true":
         if save_finetune_checkpoint_path == "":
...
@@ -218,7 +217,7 @@ def run_ner():
                                                         ds.get_dataset_size(), epoch_num, "ner")
     if args_opt.do_eval.lower() == "true":
-        ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
+        ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                 assessment_method=assessment_method,
                                 data_file_path=args_opt.eval_data_file_path,
                                 schema_file_path=args_opt.schema_file_path)
         do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method,
                 args_opt.eval_data_file_path,
...
model_zoo/official/nlp/bert/run_pretrain.py
@@ -100,11 +100,12 @@ def run_pretrain():
         bert_net_cfg.compute_type = mstype.float32

-    ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
-                                               args_opt.enable_data_sink, args_opt.data_sink_steps,
-                                               args_opt.data_dir, args_opt.schema_dir)
+    ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle, args_opt.enable_data_sink,
+                             args_opt.data_sink_steps, args_opt.data_dir, args_opt.schema_dir)
+    new_repeat_count = args_opt.epoch_size
     if args_opt.train_steps > 0:
-        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
+        new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)

     if cfg.optimizer == 'Lamb':
...
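In words: pre-training no longer bakes args_opt.epoch_size into the dataset; the dataset is created for one epoch and the repeat count stays a plain Python value that can be capped by a step budget. A small worked example of the same arithmetic (the numbers are made up for illustration):

# Illustrative values only; mirrors the computation in the hunk above.
epoch_size = 40          # args_opt.epoch_size
train_steps = 3000       # args_opt.train_steps
data_sink_steps = 100    # args_opt.data_sink_steps

new_repeat_count = epoch_size
if train_steps > 0:
    new_repeat_count = min(epoch_size, train_steps // data_sink_steps)
print(new_repeat_count)  # 30: the step budget caps the 40 requested sink iterations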
model_zoo/official/nlp/bert/run_squad.py
@@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 _cur_dir = os.getcwd()

-def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
+def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
     """ do train """
     if load_checkpoint_path == "":
         raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
     steps_per_epoch = dataset.get_dataset_size()
-    epoch_num = dataset.get_repeat_count()
     # optimizer
     if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
         optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
...
@@ -181,10 +180,10 @@ def run_squad():
     netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
     if args_opt.do_train.lower() == "true":
-        ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
+        ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   data_file_path=args_opt.train_data_file_path,
                                   schema_file_path=args_opt.schema_file_path)
-        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
+        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)

     if args_opt.do_eval.lower() == "true":
         if save_finetune_checkpoint_path == "":
             load_finetune_checkpoint_dir = _cur_dir
...
@@ -194,7 +193,7 @@ def run_squad():
                                                         ds.get_dataset_size(), epoch_num, "squad")
     if args_opt.do_eval.lower() == "true":
-        ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
+        ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                   data_file_path=args_opt.eval_data_file_path,
                                   schema_file_path=args_opt.schema_file_path, is_training=False)
         do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
...
model_zoo/official/nlp/bert/src/dataset.py
@@ -54,7 +54,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
...
model_zoo/official/nlp/transformer/src/dataset.py
@@ -17,7 +17,6 @@
 import mindspore.common.dtype as mstype
 import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as deC
-from mindspore import log as logger
 from .config import transformer_net_cfg

 def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
...
@@ -42,7 +41,4 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
     ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(repeat_count)
     ds.channel_name = 'transformer'
-    logger.info("data size: {}".format(ds.get_dataset_size()))
-    logger.info("repeatcount: {}".format(ds.get_repeat_count()))
-    return ds, repeat_count
+    return ds
model_zoo/official/nlp/transformer/train.py
@@ -125,10 +125,10 @@ def run_transformer_train():
     else:
         device_num = 1
         rank_id = 0
-    dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num,
-                                                       rank_id=rank_id, do_shuffle=args.do_shuffle,
-                                                       enable_data_sink=args.enable_data_sink,
-                                                       dataset_path=args.data_path)
+    dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num, rank_id=rank_id,
+                                         do_shuffle=args.do_shuffle, enable_data_sink=args.enable_data_sink,
+                                         dataset_path=args.data_path)
     netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
...
@@ -165,7 +165,7 @@ def run_transformer_train():
     netwithgrads.set_train(True)
     model = Model(netwithgrads)
-    model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
+    model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))

 if __name__ == '__main__':
     run_transformer_train()
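Read together, these two hunks capture the whole migration in one place: the dataset is built once for a single epoch, and the number of passes is handed to Model.train. A hedged sketch of that shape, reusing only the names that appear in this diff (the surrounding setup of Model, callbacks and args is assumed):

# Sketch of the post-change control flow; treat it as illustrative, not a drop-in.
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num, rank_id=rank_id,
                                     do_shuffle=args.do_shuffle,
                                     enable_data_sink=args.enable_data_sink,
                                     dataset_path=args.data_path)
model = Model(netwithgrads)
# The epoch loop is now driven by Model.train rather than by a repeated dataset.
model.train(args.epoch_size, dataset, callbacks=callbacks,
            dataset_sink_mode=(args.enable_data_sink == "true"))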
model_zoo/resnet/train.py
@@ -88,10 +88,10 @@ if __name__ == '__main__':
     # create dataset
     if args_opt.net == "resnet50":
-        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size,
+        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
                                  batch_size=config.batch_size, target=target)
     else:
-        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size,
+        dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
                                  batch_size=config.batch_size)
     step_size = dataset.get_dataset_size()
...
model_zoo/resnet_thor/train.py
@@ -105,7 +105,7 @@ if __name__ == '__main__':
     loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
     if args_opt.do_train:
         dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
-                                 repeat_num=epoch_size, batch_size=config.batch_size)
+                                 batch_size=config.batch_size)
         step_size = dataset.get_dataset_size()

         loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
...
model_zoo/ssd/train.py
@@ -91,7 +91,7 @@ def main():
     loss_scale = float(args_opt.loss_scale)

     # When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0.
-    dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
+    dataset = create_ssd_dataset(mindrecord_file, repeat_num=1,
                                  batch_size=args_opt.batch_size, device_num=device_num, rank=rank)

     dataset_size = dataset.get_dataset_size()
...
model_zoo/vgg16/train.py
@@ -83,7 +83,7 @@ if __name__ == '__main__':
                           mirror_mean=True)
         init()

-    dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size)
+    dataset = vgg_create_dataset(args_opt.data_path, 1)
     batch_num = dataset.get_dataset_size()

     net = vgg16(num_classes=cfg.num_classes)
...
model_zoo/wide_and_deep/train.py
@@ -63,7 +63,7 @@ def test_train(configure):
     data_path = configure.data_path
     batch_size = configure.batch_size
     epochs = configure.epochs
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))

     net_builder = ModelBuilder()
...
model_zoo/wide_and_deep/train_and_eval.py
@@ -67,8 +67,8 @@ def test_train_eval(config):
     data_path = config.data_path
     batch_size = config.batch_size
     epochs = config.epochs
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size)
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
...
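The same wide&deep change appears in the parallel variants below. Conceptually, the eval pipeline used to be built with epochs + 1 so it would not run dry while the train pipeline was repeated; once both describe a single epoch, that workaround disappears. A short hedged sketch of the resulting shape (create_dataset and its keywords are the ones shown in the diff; the epochs value is illustrative):

# Hedged sketch; epochs is only used by the training loop now, not by the pipelines.
epochs = 5
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
# Both pipelines describe one epoch; the caller decides how many times they are traversed.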
model_zoo/wide_and_deep/train_and_eval_auto_parallel.py
@@ -85,14 +85,14 @@ def train_and_eval(config):
     if config.full_batch:
         context.set_auto_parallel_context(full_batch=True)
         de.config.set_seed(1)
-        ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                   batch_size=batch_size*get_group_size())
-        ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                  batch_size=batch_size*get_group_size())
     else:
-        ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                   batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
-        ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                                  batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
...
model_zoo/wide_and_deep/train_and_eval_distribute.py
@@ -74,9 +74,9 @@ def train_and_eval(config):
     batch_size = config.batch_size
     epochs = config.epochs
     print("epochs is {}".format(epochs))
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                               batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                              batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
...
model_zoo/yolov3_resnet18/train.py
@@ -121,7 +121,7 @@ def main():
         loss_scale = float(args_opt.loss_scale)

         # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
-        dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
+        dataset = create_yolo_dataset(mindrecord_file,
                                       batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
         dataset_size = dataset.get_dataset_size()
         print("Create dataset done!")
...
tests/dataset_mock.py
@@ -50,13 +50,20 @@ class MindData:
     def input_indexs(self):
         return self._input_indexs

-    def device_que(self):
+    def device_que(self, send_epoch_end=True):
         self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
+        self.send_epoch_end = send_epoch_end
         return self

     def create_tuple_iterator(self):
         return self.__iter__()

     def send(self):
         pass

+    def stop_send(self):
+        pass
+
     def __len__(self):
         return self._size
...
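The two additions to the mock track the pipeline API exercised elsewhere in this change set: device_que() now accepts send_epoch_end, and consumers call stop_send() when they are done. A hedged sketch of driving the mock that way (the import path is illustrative; the constructor arguments follow tests/ut/python/train/test_dataset_helper.py later in this diff):

# Hedged sketch: exercises the two methods added to the MindData mock.
import numpy as np
from tests.dataset_mock import MindData   # import path is illustrative

dataset = MindData(size=2, batch_size=32,
                   np_types=(np.int32,), output_shapes=((32, 128),),
                   input_indexs=(0, 1))
queue = dataset.device_que(send_epoch_end=False)  # new keyword added by this commit
queue.send()
queue.stop_send()                                  # new no-op hook on the mock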
tests/st/model_zoo_tests/deeplabv3/train_one_epoch_with_loss.py
@@ -73,7 +73,7 @@ if __name__ == "__main__":
     epoch_size = 3
     args_opt.base_size = config.crop_size
     args_opt.crop_size = config.crop_size
-    train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size,
+    train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size,
                                    usage="train", shuffle=False)
     dataset_size = train_dataset.get_dataset_size()
     callback = LossCallBack(dataset_size)
...
tests/st/model_zoo_tests/transformer/test_transformer.py
@@ -120,10 +120,10 @@ def test_transformer():
     batch_size = 96
     epoch_size = 3
     config = get_config(version=version, batch_size=batch_size)
-    dataset, repeat_count = create_transformer_dataset(epoch_count=epoch_size, do_shuffle="false",
-                                                       enable_data_sink="false", dataset_path=DATA_DIR)
+    dataset = create_transformer_dataset(epoch_count=1, do_shuffle="false",
+                                         enable_data_sink="false", dataset_path=DATA_DIR)

     netwithloss = TransformerNetworkWithLoss(config, True)
...
@@ -146,7 +146,7 @@ def test_transformer():
     netwithgrads.set_train(True)
     time_monitor_callback = TimeMonitor(dataset.get_dataset_size())
     model = Model(netwithgrads)
-    model.train(repeat_count, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)
+    model.train(epoch_size, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)

     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
...
tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/train_and_test_multinpu_ci.py
@@ -79,9 +79,9 @@ def test_train_eval():
     batch_size = config.batch_size
     epochs = config.epochs
     print("epochs is {}".format(epochs))
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size,
                               data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size,
                              data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
...
tests/st/model_zoo_tests/wide_and_deep/train_and_test_multinpu_ci_data_parallel.py
@@ -76,9 +76,9 @@ def test_train_eval():
     batch_size = config.batch_size
     epochs = config.epochs
     print("epochs is {}".format(epochs))
-    ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
+    ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                               batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
-    ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
+    ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
                              batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
     print("ds_train.size: {}".format(ds_train.get_dataset_size()))
     print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
...
tests/st/model_zoo_tests/yolov3/test_yolov3.py
@@ -113,7 +113,7 @@ def test_yolov3():
     loss_scale = float(loss_scale)

     # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0.
-    dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size,
+    dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
                                   batch_size=batch_size, device_num=device_num, rank=rank)
     dataset_size = dataset.get_dataset_size()
     print("Create dataset done!")
...
@@ -146,12 +146,12 @@ def test_yolov3():
     assert loss_value[2] < expect_loss_value[2]

     epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
-    expect_epoch_mseconds = 950
+    expect_epoch_mseconds = 2000
     print("epoch mseconds: {}".format(epoch_mseconds))
     assert epoch_mseconds <= expect_epoch_mseconds

     per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
-    expect_per_step_mseconds = 110
+    expect_per_step_mseconds = 220
     print("per step mseconds: {}".format(per_step_mseconds))
     assert per_step_mseconds <= expect_per_step_mseconds
     print("yolov3 test case passed.")
tests/st/networks/models/bert/test_bert_tdt_lossscale.py
@@ -91,6 +91,7 @@ def me_de_train_dataset(sink_mode=False):
     """test me de train dataset"""
     # apply repeat operations
     repeat_count = 1
+    sink_size = -1
     batch_size = 16
     ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
                                                                 "next_sentence_labels", "masked_lm_positions",
...
@@ -99,9 +100,9 @@ def me_de_train_dataset(sink_mode=False):
     new_repeat_count = repeat_count
     if sink_mode:
         repeat_count = 30
-        sink_steps = 100
+        sink_size = 100
         ori_dataaet_size = ds.get_dataset_size()
-        new_size = sink_steps * batch_size
+        new_size = sink_size * batch_size
         ds.set_dataset_size(new_size)
         new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size())
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
...
@@ -112,10 +113,9 @@ def me_de_train_dataset(sink_mode=False):
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
-    ds = ds.repeat(repeat_count)
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeat_count: {}".format(ds.get_repeat_count()))
-    return ds, new_repeat_count
+    return ds, new_repeat_count, sink_size

 def weight_variable(shape):
...
@@ -157,7 +157,7 @@ class TimeMonitor(Callback):
 def test_bert_percision():
     """test bert percision"""
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
-    ds, new_repeat_count = me_de_train_dataset()
+    ds, new_repeat_count, _ = me_de_train_dataset()
     version = os.getenv('VERSION', 'large')
     batch_size = 16
     config = get_config(version=version, batch_size=batch_size)
...
@@ -215,7 +215,7 @@ def test_bert_percision():
 def test_bert_performance():
     """test bert performance"""
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
-    ds, new_repeat_count = me_de_train_dataset(sink_mode=True)
+    ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True)
     version = os.getenv('VERSION', 'large')
     batch_size = 16
     config = get_config(version=version, batch_size=batch_size)
...
@@ -251,7 +251,7 @@ def test_bert_performance():
             param.default_input = weight_variable(value.asnumpy().shape)
     time_monitor_callback = TimeMonitor(ds.get_dataset_size())
     model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
-                dataset_sink_mode=True)
+                dataset_sink_mode=True, sink_size=sink_size)

     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
...
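The sink_size returned by me_de_train_dataset() is what decouples the number of sunk steps from the dataset length. A minimal sketch of how the two call sites above fit together (the value 100 is the one the test sets when sink_mode is true; treat the snippet as illustrative rather than a drop-in):

# Hedged sketch mirroring test_bert_performance() above.
ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True)   # sink_size == 100 here
model.train(new_repeat_count, ds,
            callbacks=[time_monitor_callback, callback],
            dataset_sink_mode=True,
            sink_size=sink_size)   # each "epoch" sinks sink_size steps instead of one full pass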
tests/st/networks/models/deeplabv3/test_deeplabv3.py
@@ -79,7 +79,7 @@ def test_deeplabv3_1p():
     args_opt.base_size = config.crop_size
     args_opt.crop_size = config.crop_size
     args_opt.batch_size = config.batch_size
-    train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size,
+    train_dataset = create_dataset(args_opt, data_url, 1, config.batch_size,
                                    usage="eval")
     dataset_size = train_dataset.get_dataset_size()
     callback = LossCallBack(dataset_size)
...
tests/st/networks/models/resnet50/test_resnet50_imagenet.py
@@ -155,7 +155,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
     # train dataset
     dataset = create_dataset(dataset_path=dataset_path, do_train=True,
-                             repeat_num=epoch_size, batch_size=config.batch_size)
+                             repeat_num=1, batch_size=config.batch_size)
     step_size = dataset.get_dataset_size()
     eval_interval = config.eval_interval
...
@@ -163,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
     # evalutation dataset
     eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
-                                  repeat_num=epoch_size, batch_size=config.eval_batch_size)
+                                  repeat_num=1, batch_size=config.eval_batch_size)

     # loss scale
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
...
@@ -260,14 +260,14 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
     # train dataset
     dataset = create_dataset(dataset_path=dataset_path, do_train=True,
-                             repeat_num=epoch_size, batch_size=thor_config.batch_size)
+                             repeat_num=1, batch_size=thor_config.batch_size)
     step_size = dataset.get_dataset_size()
     eval_interval = thor_config.eval_interval

     # evalutation dataset
     eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
-                                  repeat_num=epoch_size, batch_size=thor_config.eval_batch_size)
+                                  repeat_num=1, batch_size=thor_config.eval_batch_size)

     # loss scale
     loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)
...
tests/st/tbe_networks/resnet_cifar.py
@@ -136,7 +136,7 @@ if __name__ == '__main__':
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

     if args_opt.do_train:
-        dataset = create_dataset(epoch_size)
+        dataset = create_dataset(1)
         batch_num = dataset.get_dataset_size()
         config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10)
         ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
...
tests/st/tbe_networks/test_resnet_cifar_1p.py
@@ -140,7 +140,7 @@ def train_process(epoch_size, num_classes, batch_size):
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

-    dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
+    dataset = create_dataset(1, training=True, batch_size=batch_size)
     loss_cb = LossGet()
     model.train(epoch_size, dataset, callbacks=[loss_cb])
...
tests/st/tbe_networks/test_resnet_cifar_8p.py
@@ -164,7 +164,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

-    dataset = create_dataset(epoch_size, training=True,
+    dataset = create_dataset(1, training=True,
                              batch_size=batch_size, rank_id=device_id, rank_size=device_num,
                              enable_hccl=enable_hccl)
...
tests/ut/cpp/dataset/CMakeLists.txt
@@ -91,8 +91,9 @@ SET(DE_UT_SRCS
     cyclic_array_test.cc
     perf_data_test.cc
     c_api_test.cc
-    tensor_op_fusion_pass_test.cc
+    tensor_op_fusion_pass_test.cc
     sliding_window_op_test.cc
+    epoch_ctrl_op_test.cc
     )

 add_executable(de_ut_tests ${DE_UT_SRCS})
...
tests/ut/cpp/dataset/cache_op_test.cc
@@ -397,23 +397,21 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
   std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
-  std::shared_ptr<CacheMergeOp> myMergeOp;
-  rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build(&myMergeOp);
-  EXPECT_TRUE(rc.IsOk());
+  // In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
+  // Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
+  // adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
+  // replace it with the required tree structures for cache lookup op and cache merge op.
-  std::shared_ptr<CacheLookupOp> myLookupOp;
-  rc = CacheLookupOp::Builder()
-         .SetNumWorkers(3)
-         .SetOpConnectorSize(3)
+  std::shared_ptr<CacheOp> myCacheOp;
+  rc = CacheOp::Builder()
+         .SetNumWorkers(4)
          .SetClient(myClient)
          .SetSampler(seq_sampler)
-         .Build(&myLookupOp);
-  EXPECT_TRUE(rc.IsOk());
+         .SetRowsPerBuffer(3)
+         .Build(&myCacheOp);
   std::shared_ptr<ImageFolderOp> so;
   ImageFolderOp::Builder builder;
-  builder.SetSampler(myLookupOp)
+  builder.SetSampler(std::move(seq_sampler))
     .SetOpConnectorSize(3)
     .SetNumWorkers(3)
     .SetRowsPerBuffer(2)
...
@@ -432,20 +430,18 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
   auto myTree = std::make_shared<ExecutionTree>();
   rc = myTree->AssociateNode(so);
   EXPECT_TRUE(rc.IsOk());
-  rc = myTree->AssociateNode(myLookupOp);
-  EXPECT_TRUE(rc.IsOk());
-  rc = myTree->AssociateNode(myMergeOp);
+  rc = myTree->AssociateNode(myCacheOp);
   EXPECT_TRUE(rc.IsOk());
   rc = myTree->AssociateNode(myRepeatOp);
   EXPECT_TRUE(rc.IsOk());
   rc = myTree->AssignRoot(myRepeatOp);
   EXPECT_TRUE(rc.IsOk());
-  rc = myRepeatOp->AddChild(myMergeOp);
-  EXPECT_TRUE(rc.IsOk());
-  rc = myMergeOp->AddChild(myLookupOp);
+  rc = myRepeatOp->AddChild(myCacheOp);
   EXPECT_TRUE(rc.IsOk());
-  rc = myMergeOp->AddChild(so);
+  rc = myCacheOp->AddChild(so);
   EXPECT_TRUE(rc.IsOk());
   rc = myTree->Prepare();
...
tests/ut/cpp/dataset/epoch_ctrl_op_test.cc (new file; diff collapsed in this view)
tests/ut/cpp/dataset/repeat_op_test.cc
@@ -46,7 +46,8 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
   ASSERT_TRUE(rc.IsOk());
   rc = my_tree->AssociateNode(my_tfreader_op);
   ASSERT_TRUE(rc.IsOk());
-  my_tree->AssociateNode(parent_op);
+  rc = my_tree->AssociateNode(parent_op);
+  ASSERT_TRUE(rc.IsOk());
   ASSERT_NE(parent_op, nullptr);
   ASSERT_NE(my_tfreader_op, nullptr);
   parent_op->AddChild(std::move(my_tfreader_op));
...
tests/ut/python/dataset/test_cache_map.py
@@ -104,9 +104,11 @@ def test_cache_map_basic3():
     decode_op = c_vision.Decode()
     ds1 = ds1.repeat(4)
     ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
+    print("ds1.dataset_size is ", ds1.get_dataset_size())

     num_iter = 0
     for _ in ds1.create_dict_iterator():
+        print("get data from dataset")
         num_iter += 1

     logger.info("Number of data in ds1: {} ".format(num_iter))
...
@@ -152,6 +154,10 @@ def test_cache_map_failure1():

 if __name__ == '__main__':
     test_cache_map_basic1()
+    print("test_cache_map_basic1 success.")
     test_cache_map_basic2()
+    print("test_cache_map_basic2 success.")
     test_cache_map_basic3()
+    print("test_cache_map_basic3 success.")
     test_cache_map_failure1()
+    print("test_cache_map_failure1 success.")
tests/ut/python/dataset/test_datasets_tfrecord.py
@@ -238,7 +238,7 @@ def test_tfrecord_shard_equal_rows():
 def test_tfrecord_no_schema_columns_list():
     logger.info("test_tfrecord_no_schema_columns_list")
     data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
-    row = data.create_dict_iterator().get_next()
+    row = data.create_dict_iterator().__next__()
     assert row["col_sint16"] == [-32768]

     with pytest.raises(KeyError) as info:
...
@@ -258,7 +258,7 @@ def test_tfrecord_schema_columns_list():
     schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
     schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
     data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
-    row = data.create_dict_iterator().get_next()
+    row = data.create_dict_iterator().__next__()
     assert row["col_sint16"] == [-32768]

     with pytest.raises(KeyError) as info:
...
tests/ut/python/dataset/test_deviceop_cpu.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import time
+
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
...
@@ -35,6 +37,8 @@ def test_case_0():
     data = data.device_que()
     data.send()
+    time.sleep(0.1)
+    data.stop_send()

 def test_case_1():
...
@@ -58,6 +62,8 @@ def test_case_1():
     data = data.device_que()
     data.send()
+    time.sleep(0.1)
+    data.stop_send()

 def test_case_2():
...
@@ -84,6 +90,8 @@ def test_case_2():
     data = data.device_que()
     assert data.get_repeat_count() == 2
     data.send()
+    time.sleep(0.1)
+    data.stop_send()

 def test_case_3():
...
@@ -109,13 +117,17 @@ def test_case_3():
     data = data.device_que()
     data.send()
+    time.sleep(0.1)
+    data.stop_send()

 def test_case_tf_file():
     data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
-    data = data.to_device(num_batch=10)
+    data = data.to_device()
     data.send()
+    time.sleep(0.1)
+    data.stop_send()

 if __name__ == '__main__':
...
tests/ut/python/dataset/test_epoch_ctrl.py (new file)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Epoch Control op in DE
"""
import itertools
import cv2
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

def diff_mse(in1, in2):
    """diff_mse"""
    mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
    return mse * 100

def test_cifar10():
    """dataset parameter"""
    logger.info("Test dataset parameter")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
    data1 = data1.repeat(num_repeat)
    data1 = data1.batch(batch_size, True)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_tuple_iterator()
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            # in this example, each dictionary has keys "image" and "label"
            row_count += 1
        assert row_count == int(limit_dataset * num_repeat / batch_size)
        logger.debug("row_count: ", row_count)
        epoch_count += 1
        sample_count += row_count
    assert epoch_count == num_epoch
    logger.debug("total epochs: ", epoch_count)
    assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
    logger.debug("total sample: ", sample_count)

def test_decode_op():
    """Test Decode op"""
    logger.info("test_decode_op")
    # Decode with rgb format set to True
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    # Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_dict_iterator()
    # iter 2 will stop and shutdown pipeline after num_epoch
    iter2 = data2.create_dict_iterator(num_epoch)
    for _ in range(num_epoch):
        i = 0
        for item1, item2 in itertools.zip_longest(iter1, iter2):
            actual = item1["image"]
            expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
            expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
            assert actual.shape == expected.shape
            diff = actual - expected
            mse = np.sum(np.power(diff, 2))
            assert mse == 0
            i = i + 1
        assert i == 3
    # Users have the option to manually stop the iterator, or rely on garbage collector.
    iter1.stop()
    # Expect a AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)
    with pytest.raises(RuntimeError) as info:
        iter2.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

# Generate 1d int numpy array from 0 - 63
def generator_1d():
    """generator"""
    for i in range(64):
        yield (np.array([i]),)

def test_generator_dict_0():
    """test generator dict 0"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_dict_iterator():
        # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1

def test_generator_dict_1():
    """test generator dict 1"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    for _ in range(10):
        i = 0
        # BAD. Do not create iterator every time inside.
        # Create iterator outside the epoch for loop.
        for item in data1.create_dict_iterator():
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64

def test_generator_dict_2():
    """test generator dict 2"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1
    # rely on garbage collector to destroy iter1

def test_generator_dict_3():
    """test generator dict 3"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
    # optional
    iter1.stop()
    # Expect a AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)

def test_generator_dict_4():
    """test generator dict 4"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=10)
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

def test_generator_dict_4_1():
    """test generator dict 4_1"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)
    for _ in range(1):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

def test_generator_dict_4_2():
    """test generator dict 4_2"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # repeat will not be injected when num repeat is 1.
    data1 = data1.repeat(1)
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)
    for _ in range(1):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

def test_generator_dict_5():
    """test generator dict 5"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item["data"], golden)
            i = i + 1
        assert i == 64
    # still one more epoch left in the iter1.
    i = 0
    for item in iter1:
        # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1
    assert i == 64
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

# Test tuple iterator
def test_generator_tuple_0():
    """test generator tuple 0"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_tuple_iterator():
        # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        i = i + 1

def test_generator_tuple_1():
    """test generator tuple 1"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    for _ in range(10):
        i = 0
        # BAD. Do not create iterator every time inside.
        # Create iterator outside the epoch for loop.
        for item in data1.create_tuple_iterator():
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64

def test_generator_tuple_2():
    """test generator tuple 2"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64
    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1
    # rely on garbage collector to destroy iter1

def test_generator_tuple_3():
    """test generator tuple 3"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64
    # optional
    iter1.stop()
    # Expect a AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)

def test_generator_tuple_4():
    """test generator tuple 4"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=10)
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

def test_generator_tuple_5():
    """test generator tuple 5"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64
    # still one more epoch left in the iter1.
    i = 0
    for item in iter1:
        # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        i = i + 1
    assert i == 64
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

# Test with repeat
def test_generator_tuple_repeat_1():
    """test generator tuple repeat 1"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2
    # still one more epoch left in the iter1.
    i = 0
    for item in iter1:
        # each data is a dictionary
        golden = np.array([i % 64])
        assert np.array_equal(item[0], golden)
        i = i + 1
    assert i == 64 * 2
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

# Test with repeat
def test_generator_tuple_repeat_repeat_1():
    """test generator tuple repeat repeat 1"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3
    # still one more epoch left in the iter1.
    i = 0
    for item in iter1:
        # each data is a dictionary
        golden = np.array([i % 64])
        assert np.array_equal(item[0], golden)
        i = i + 1
    assert i == 64 * 2 * 3
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    with pytest.raises(RuntimeError) as info:
        iter1.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)

def test_generator_tuple_repeat_repeat_2():
    """test generator tuple repeat repeat 2"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3
    # optional
    iter1.stop()
    # Expect a AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)

def test_generator_tuple_repeat_repeat_3():
    """test generator tuple repeat repeat 3"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3
    for _ in range(5):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3
    # rely on garbage collector to destroy iter1

def test_generator_reusedataset():
    """test generator reusedataset"""
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(5):
        i = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([i % 64])
            assert np.array_equal(item[0], golden)
            i = i + 1
        assert i == 64 * 2 * 3
    data1 = data1.batch(2)
    iter1 = data1.create_dict_iterator()
    for _ in range(5):
        i = 0
        sample = 0
        for item in iter1:
            # each data is a dictionary
            golden = np.array([[i % 64], [(i + 1) % 64]])
            assert np.array_equal(item["data"], golden)
            i = i + 2
            sample = sample + 1
        assert sample == 64 * 3
    # rely on garbage collector to destroy iter1
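Distilled from the tests above, the intended usage is: create the iterator once, declare how many epochs will be consumed, and traverse it once per epoch; anything beyond that raises the RuntimeError checked above. A minimal sketch (it only restates the pattern the tests already exercise):

# Minimal sketch of the epoch-control iterator pattern shown in the tests above.
import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(64):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["data"])
num_epochs = 3
iter1 = data.create_dict_iterator(num_epochs)   # epoch-control op injected when num_epochs > 1
for _ in range(num_epochs):
    for item in iter1:
        _ = item["data"]
# A further next() would raise RuntimeError: the pipeline shuts down after num_epochs.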
tests/ut/python/dataset/test_five_crop.py
@@ -87,7 +87,7 @@ def test_five_crop_error_msg():
     data = data.map(input_columns=["image"], operations=transform())

     with pytest.raises(RuntimeError) as info:
-        data.create_tuple_iterator().get_next()
+        data.create_tuple_iterator().__next__()
     error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"

     # error msg comes from ToTensor()
...
tests/ut/python/dataset/test_get_size.py
@@ -41,18 +41,18 @@ def test_case1():
     assert data.get_batch_size() == 2
     assert data.get_repeat_count() == 1
     data = data.repeat(10)
-    assert data.get_dataset_size() == 6
+    assert data.get_dataset_size() == 60
     assert data.get_batch_size() == 2
     assert data.get_repeat_count() == 10
     data = data.project(["new_column"])
-    assert data.get_dataset_size() == 6
+    assert data.get_dataset_size() == 60
     assert data.get_batch_size() == 2
     assert data.get_repeat_count() == 10

     data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
     data1 = data.zip(data2)
-    assert data1.get_dataset_size() == 6
+    assert data1.get_dataset_size() == 60

 def test_case2():
...
@@ -65,14 +65,14 @@ def test_case2():
     data = data.rename("col_sint64", "new_column")
     assert data.get_dataset_size() == 3
     data = data.repeat(10)
-    assert data.get_dataset_size() == 3
+    assert data.get_dataset_size() == 30
     data = data.project(["new_column"])
-    assert data.get_dataset_size() == 3
+    assert data.get_dataset_size() == 30

     data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10)
     data1 = data.zip(data2)
-    assert data1.get_dataset_size() == 3
+    assert data1.get_dataset_size() == 30

 def test_case3():
...
@@ -94,11 +94,11 @@ def test_case4():
     data2 = data2.shuffle(100)
     assert data2.get_dataset_size() == 6
     data2 = data2.repeat(3)
-    assert data2.get_dataset_size() == 6
+    assert data2.get_dataset_size() == 18
     data3 = ds.zip((data1, data2))
-    assert data3.get_dataset_size() == 6
+    assert data3.get_dataset_size() == 18

 def test_case5():
...
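The updated assertions encode the new contract of this commit: get_dataset_size() reports the number of rows or batches produced by the whole pipeline, including repeat(). A worked example of the numbers asserted for data2 in test_case2 above (nothing new, just the same arithmetic spelled out):

# Worked example of the sizes asserted in test_case2 above.
rows = 6                 # TFRecordDataset(..., num_samples=6)
batch_size = 2
repeat = 10
per_epoch_batches = rows // batch_size          # 3  -> size before repeat()
total_batches = per_epoch_batches * repeat      # 30 -> size reported after repeat()
print(per_epoch_batches, total_batches)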
tests/ut/python/dataset/test_iterator.py
@@ -73,7 +73,7 @@ def test_iterator_weak_ref():
     _cleanup()
     with pytest.raises(AttributeError) as info:
-        itr2.get_next()
+        itr2.__next__()
     assert "object has no attribute 'depipeline'" in str(info.value)

     del itr1
...
tests/ut/python/dataset/test_repeat.py
@@ -251,6 +251,49 @@ def test_nested_repeat11():
     assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3

+def test_repeat_count1():
+    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
+    data1_size = data1.get_dataset_size()
+    logger.info("dataset size is {}".format(data1_size))
+    batch_size = 2
+    repeat_count = 4
+    resize_height, resize_width = 32, 32
+    decode_op = vision.Decode()
+    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=resize_op)
+    data1 = data1.repeat(repeat_count)
+    data1 = data1.batch(batch_size, drop_remainder=False)
+    dataset_size = data1.get_dataset_size()
+    logger.info("dataset repeat then batch's size is {}".format(dataset_size))
+    num1_iter = 0
+    for _ in data1.create_dict_iterator():
+        num1_iter += 1
+    assert data1_size == 3
+    assert dataset_size == num1_iter == 6
+
+def test_repeat_count2():
+    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
+    data1_size = data1.get_dataset_size()
+    logger.info("dataset size is {}".format(data1_size))
+    batch_size = 2
+    repeat_count = 4
+    resize_height, resize_width = 32, 32
+    decode_op = vision.Decode()
+    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=resize_op)
+    data1 = data1.batch(batch_size, drop_remainder=False)
+    data1 = data1.repeat(repeat_count)
+    dataset_size = data1.get_dataset_size()
+    logger.info("dataset batch then repeat's size is {}".format(dataset_size))
+    num1_iter = 0
+    for _ in data1.create_dict_iterator():
+        num1_iter += 1
+    assert data1_size == 3
+    assert dataset_size == num1_iter == 8
+
 if __name__ == "__main__":
     test_tf_repeat_01()
...
@@ -268,3 +311,5 @@ if __name__ == "__main__":
     test_nested_repeat9()
     test_nested_repeat10()
     test_nested_repeat11()
+    test_repeat_count1()
+    test_repeat_count2()
tests/ut/python/dataset/test_zip.py
@@ -252,14 +252,14 @@ def test_zip_exception_06():
 if __name__ == '__main__':
     test_zip_01()
-    test_zip_02()
-    test_zip_03()
-    test_zip_04()
-    test_zip_05()
-    test_zip_06()
-    test_zip_exception_01()
-    test_zip_exception_02()
-    test_zip_exception_03()
-    test_zip_exception_04()
-    test_zip_exception_05()
-    test_zip_exception_06()
+    # test_zip_02()
+    # test_zip_03()
+    # test_zip_04()
+    # test_zip_05()
+    # test_zip_06()
+    # test_zip_exception_01()
+    # test_zip_exception_02()
+    # test_zip_exception_03()
+    # test_zip_exception_04()
+    # test_zip_exception_05()
+    # test_zip_exception_06()
tests/ut/python/log (new file; diff collapsed in this view)
tests/ut/python/parallel/test_auto_parallel_resnet.py
@@ -274,6 +274,9 @@ class DatasetLenet():
     def get_repeat_count(self):
         return 1

+    def create_tuple_iterator(self):
+        return self
+
 def test_train_32k_8p(batch_size=32, num_classes=32768):
     dev_num = 8
...
tests/ut/python/parallel/test_bias_add.py
@@ -61,6 +61,9 @@ class DatasetLenet():
     def get_repeat_count(self):
         return 1

+    def create_tuple_iterator(self):
+        return self
+
 class Net(nn.Cell):
     def __init__(self):
...
tests/ut/python/parallel/test_gather_v2_primitive.py
...
...
@@ -58,6 +58,9 @@ class Dataset():
    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self):
        return self


class GatherV2(_Loss):
    def __init__(self, index_dim, strategy, index_size=16):
...
...
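Each of the three parallel-test mocks above now exposes create_tuple_iterator alongside get_repeat_count. A self-contained sketch of such a stub dataset (the class name, size, and returned values here are illustrative, not the repository's DatasetLenet):

class StubDataset:
    """Minimal iterable stub exposing the methods the training loop queries."""

    def __init__(self, size=2):
        self.size = size
        self.index = 0

    def get_dataset_size(self):
        return self.size

    def get_repeat_count(self):
        return 1

    def create_tuple_iterator(self):
        return self

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.size:
            raise StopIteration
        self.index += 1
        return ()          # a real mock would return a tuple of input tensors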
tests/ut/python/train/test_dataset_helper.py
0 → 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset helper."""
import pytest
import numpy as np

import mindspore.context as context
from mindspore.communication.management import init
from mindspore.train.dataset_helper import DatasetHelper
from ....dataset_mock import MindData


def get_dataset(batch_size=1):
    dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
    dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
                      (batch_size, 20), (batch_size, 20), (batch_size, 20))

    dataset = MindData(size=2, batch_size=batch_size, np_types=dataset_types,
                       output_shapes=dataset_shapes, input_indexs=(0, 1))
    return dataset


def test_dataset_helper_dataset_sink_mode_str():
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode="True")


def test_dataset_helper_dataset_sink_mode_int():
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode=1)


def test_dataset_helper_sink_size_bool():
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode=True, sink_size=True)


def test_dataset_helper_sink_size_float():
    dataset = get_dataset(32)
    with pytest.raises(TypeError):
        DatasetHelper(dataset, dataset_sink_mode=True, sink_size=1.0)


def test_dataset_helper_sink_size_negative():
    dataset = get_dataset(32)
    with pytest.raises(ValueError):
        DatasetHelper(dataset, dataset_sink_mode=True, sink_size=-2)


def test_dataset_iter_normal():
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=False)
    count = 0
    for _ in range(2):
        for _ in dataset_helper:
            count += 1
        dataset.reset()
    assert count == 6


@pytest.mark.skipif('not context.get_context("enable_ge")')
def test_dataset_iter_ge():
    init()
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
    count = 0
    for _ in range(2):
        for _ in dataset_helper:
            count += 1
    assert count == 2


@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms_loop_sink():
    init()
    context.set_context(enable_loop_sink=True)
    dataset = get_dataset(32)
    dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
    count = 0
    for _ in range(2):
        for inputs in dataset_helper:
            count += 1
            assert inputs == tuple()
    assert count == 2


@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms():
    init()
    context.set_context(enable_loop_sink=False)
    dataset = get_dataset(32)
    DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
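The new file covers both the argument validation and the iteration behaviour of DatasetHelper. A minimal non-sink usage sketch, mirroring test_dataset_iter_normal above (my_dataset is a placeholder for any dataset object exposing the same interface as the MindData mock; only the DatasetHelper constructor arguments shown in the tests are assumed):

from mindspore.train.dataset_helper import DatasetHelper

def run_one_pass(my_dataset):
    # dataset_sink_mode=False iterates the dataset on the host side
    helper = DatasetHelper(my_dataset, dataset_sink_mode=False)
    steps = 0
    for batch in helper:      # each item is a tuple of input tensors
        _ = batch             # a real training loop would feed this to the network
        steps += 1
    return steps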