magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit
b57d4ea2
Authored June 30, 2020 by mindspore-ci-bot; committed via Gitee on June 30, 2020
!2602 Stage 2 of CacheOp delivery
Merge pull request !2602 from JesseKLee/cache_op_stage2
Parents: 9377e432, a0a863f2
Showing 43 changed files with 351 additions and 99 deletions (+351, −99)
mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc: +1 −1
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc: +61 −5
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h: +22 −1
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc: +1 −1
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc: +2 −2
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h: +2 −1
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc: +2 −1
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h: +2 −1
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc: +3 −3
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc: +1 −2
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h: +0 −1
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc: +1 −2
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h: +0 −1
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc: +1 −2
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h: +0 −1
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc: +1 −2
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h: +0 −1
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc: +1 −2
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h: +0 −1
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc: +19 −6
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h: +18 −2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc: +4 −5
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc: +8 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h: +5 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc: +9 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h: +5 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc: +4 −5
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc: +5 −4
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc: +9 −2
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h: +3 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc: +9 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h: +5 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc: +9 −0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h: +5 −0
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc: +9 −4
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h: +12 −1
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc: +44 −9
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h: +17 −1
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc: +1 −2
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h: +0 −1
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc: +1 −1
mindspore/ccsrc/dataset/engine/execution_tree.cc: +26 −13
mindspore/ccsrc/dataset/engine/execution_tree.h: +23 −12
mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc
@@ -128,7 +128,7 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
 Status ConcatOp::PrepareNodePostAction() {
   RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
@@ -18,23 +18,26 @@
 #include <iomanip>
 #include <iostream>
 #include <memory>
+#include <regex>
 #include <utility>
 #include <string>
 #include <algorithm>
 #include "dataset/engine/execution_tree.h"
 #include "dataset/engine/datasetops/device_queue_op.h"
+#include "dataset/engine/datasetops/source/sampler/sampler.h"
 #include "dataset/engine/data_buffer.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/opt/pass.h"
+#include "utils/system/crc32c.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
 namespace dataset {
 // Constructor
-DatasetOp::DatasetOp(int32_t op_connector_size)
+DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
     : oc_queue_size_(op_connector_size),
+      sampler_(sampler),
       operator_id_(kInvalidOperatorId),
       tree_(nullptr),
       state_(OpState::kDeOpIdle),
@@ -150,6 +153,9 @@ void DatasetOp::Print(std::ostream &out, bool show_all) const {
   }
   out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex << std::setw(8)
       << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
+  if (sampler_) {
+    sampler_->Print(out, show_all);
+  }
 }
@@ -222,11 +228,10 @@ Status DatasetOp::PrepareNodePreAction() {
 Status DatasetOp::PrepareNodePostAction() {
   // If this op does not have any children and it is in a repeat path of the tree...
   if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
-    // push ourselves onto the tree repeat stack. Later, the repeat operator
+    // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
     // above us will consume them.
-    tree_->AddToRepeatStack(shared_from_this());
+    tree_->AddToEOEOpStack(shared_from_this());
   }
   // Creating Connector object for each op.
   // The consumer of the root node is assumed to be one thread.
   // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
@@ -289,5 +294,56 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
   // This method will only be called if its derived class does not implement one.
   return p->RunOnNode(shared_from_this(), modified);
 }
+
+// A helper function with some common code that leaf nodes can use during
+// prepare phase for checking if they need to assign a sampler to the cache.
+// @return - Status
+Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
+  // If we are a descendant under a cache op and we have a sampler, then save this sampler
+  // to a stack so that the cache can pick it up during its processing above us.
+  if (sampler_) {
+    if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
+      // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
+      // useless to a random data op. It was only being used as a temporary holding until the cache can
+      // be created
+      tree_->AddToSamplerStack(sampler_);
+      MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
+    } else if (!random_access_op) {
+      // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
+      // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                    "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it "
+                    "in the tree");
+    }
+  }
+  if (!random_access_op) {
+    // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
+    // we can remove it now from the base.
+    sampler_.reset();
+  }
+  return Status::OK();
+}
+
+uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
+  std::stringstream ss;
+  op->tree_->Print(ss, op);
+  std::string ss_str = ss.str();
+
+  // Filter out the Operator control flags field when generating the check sum
+  ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
+
+  // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline
+  ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
+
+  // The Cache crc and Server cache id fields differ when creating a new cache_client and re-using the same
+  // cache_client later. So we filter out these two fields to allow cache sharing.
+  ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
+
+  uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
+  return cache_crc;
+}
 }  // namespace dataset
 }  // namespace mindspore
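The point of the regex filtering in GenerateCRC is that two runs of the same pipeline should hash to the same cache key even when device ids and other volatile fields differ. A standalone sketch of that filter-then-checksum idea, with std::hash standing in for the masked CRC32C call and purely illustrative sample strings:

#include <functional>
#include <iostream>
#include <regex>
#include <string>

// Strip fields that vary between runs/devices, then checksum what remains so
// that identical pipelines map to the same cache key.
size_t PipelineFingerprint(std::string desc) {
  const char *volatile_fields[] = {"Operator control flags.*\n", "Device id.*\n", "device_id.*\n",
                                   "Cache crc.*\n", "Server cache id.*\n"};
  for (const char *pattern : volatile_fields) {
    desc = std::regex_replace(desc, std::regex(pattern), "");
  }
  // std::hash stands in for the masked CRC32C used by the real GenerateCRC.
  return std::hash<std::string>{}(desc);
}

int main() {
  // The same pipeline printed on two different devices...
  std::string run_a = "TFReaderOp\nDevice id: 0\nRepeatOp\n";
  std::string run_b = "TFReaderOp\nDevice id: 3\nRepeatOp\n";
  // ...yields one fingerprint, so both runs can share one cache.
  std::cout << (PipelineFingerprint(run_a) == PipelineFingerprint(run_b)) << '\n';  // prints 1
}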
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
@@ -34,6 +34,8 @@ class DataBuffer;

 class NodePass;

+class Sampler;
+
 // The base class DatasetOp is the main tree node. It is an abstract class, so
 // the actual implementation of the operators will be derived from here.
 class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
@@ -55,7 +57,8 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // Constructor
   // @param op_connector_size - The size for the output connector of this operator.
-  explicit DatasetOp(int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);

   // Destructor
   virtual ~DatasetOp() { tree_ = nullptr; }
@@ -204,6 +207,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return Sets the control flags
   void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }

+  // Setter function
+  // @return Clears the control flags
+  void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
+
   // Register the internal worker connectors. No op unless it is a parallel op
   // @return Status
   virtual Status RegisterWorkerConnectors() { return Status::OK(); }
@@ -271,6 +278,13 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return Pointer to the ExecutionTree the current op belongs to, no ownership
   ExecutionTree *Tree() { return tree_; }

+  // Getter for the sampler
+  // @return Shared pointer to the sampler (may return nullptr)
+  std::shared_ptr<Sampler> sampler() { return sampler_; }
+
+  // Computes a CRC value for the operator
+  static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
+
  protected:
   // Adds a parent operator to this operator
   // @notes External callers do not have access to this function.
@@ -289,8 +303,15 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return - Status
   virtual Status ComputeColMap();

+  // A helper function with some common code that leaf nodes can use during
+  // prepare phase for checking if they need to assign a sampler to the cache.
+  // @param random_access_op - indicate if this is a mappable random access leaf or not
+  // @return - Status
+  Status SaveSamplerForCache(bool random_access_op);
+
   std::vector<std::shared_ptr<DatasetOp>> child_;  // Child nodes
   std::vector<DatasetOp *> parent_;                // Parent nodes. No ownership
+  std::shared_ptr<Sampler> sampler_;               // Some leaf ops might have a sampler
   int32_t oc_queue_size_;                          // Capacity for each out_connector_
   int32_t operator_id_;                            // Generated id for the node
   ExecutionTree *tree_;                            // Back pointer to our tree.
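set_control_flag and the new ClearControlFlag delegate to BitSet/BitClear helpers that this diff does not show. A minimal standalone sketch of what those helpers plausibly amount to (the implementations here are assumptions, not the MindSpore definitions):

#include <cstdint>
#include <iostream>

// Assumed implementations of the bit helpers the header relies on.
inline void BitSet(uint64_t *bits, uint64_t flag) { *bits |= flag; }
inline void BitClear(uint64_t *bits, uint64_t flag) { *bits &= ~flag; }
inline bool BitTest(uint64_t bits, uint64_t flag) { return (bits & flag) == flag; }

int main() {
  constexpr uint64_t kDeOpRepeated = 1;
  uint64_t op_ctrl_flags = 0;
  BitSet(&op_ctrl_flags, kDeOpRepeated);                       // what set_control_flag does
  std::cout << BitTest(op_ctrl_flags, kDeOpRepeated) << '\n';  // 1
  BitClear(&op_ctrl_flags, kDeOpRepeated);                     // what the new ClearControlFlag does
  std::cout << BitTest(op_ctrl_flags, kDeOpRepeated) << '\n';  // 0
}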
mindspore/ccsrc/dataset/engine/datasetops/map_op.cc
@@ -100,7 +100,7 @@ void MapOp::Print(std::ostream &out, bool show_all) const {
   }
   out << "\nTensorOps:";
   for (size_t i = 0; i < tfuncs_.size(); i++) {
-    out << " " << tfuncs_[i];
+    out << " " << *(tfuncs_[i].get());
   }
   out << "\n\n";
 }
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc
@@ -26,8 +26,8 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size)
-    : DatasetOp(op_connector_size),
+ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : DatasetOp(op_connector_size, sampler),
       num_workers_(num_workers),
       num_producers_(num_workers),
       worker_connector_size_(1),
mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h
@@ -38,7 +38,8 @@ class ParallelOp : public DatasetOp {
   // Constructor
   // @param num_workers
   // @param op_connector_size - size of the output connector for this operator
-  ParallelOp(int32_t num_workers, int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);

   // Destructor
   ~ParallelOp() = default;
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc
@@ -20,7 +20,8 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-PipelineOp::PipelineOp(int32_t op_connector_size) : DatasetOp(op_connector_size) {}
+PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : DatasetOp(op_connector_size, sampler) {}

 // A print method typically used for debugging
 void PipelineOp::Print(std::ostream &out, bool show_all) const {
mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h
@@ -32,7 +32,8 @@ class PipelineOp : public DatasetOp {
   // Constructor
   // @param op_connector_size - size of the output connector
   // @return Builder setter method returns reference to the builder.
-  explicit PipelineOp(int32_t op_connector_size);
+  // @param sampler - The sampler for the op
+  explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);

   // Destructor
   ~PipelineOp() = default;
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
@@ -82,14 +82,14 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
 Status RepeatOp::PrepareNodePostAction() {
   // Run any common code from super class first before adding our own specific logic
   RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromRepeatStack();
+  std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
   while (leaf_op != nullptr) {
     // Track the leaf operators that are under this repeat op.
     eoe_ops_.push_back(leaf_op);
-    leaf_op = tree_->PopFromRepeatStack();
+    leaf_op = tree_->PopFromEOEOpStack();
   }
   // Push ourselves to the stack in case one of our ascendants is repeat too.
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }
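The loop above is one half of a handshake: leaf ops (and ops such as ConcatOp and TakeOp) push themselves onto the EOE-op stack during their post-action, and the first repeat or epoch-control op above them drains it. A standalone sketch with stand-in types, not the MindSpore classes:

#include <iostream>
#include <memory>
#include <stack>
#include <string>
#include <vector>

struct Op {
  std::string name;
};

int main() {
  std::stack<std::shared_ptr<Op>> eoe_stack;
  // Two leaf ops under a repeat push themselves during their post-action.
  eoe_stack.push(std::make_shared<Op>(Op{"TFReaderOp"}));
  eoe_stack.push(std::make_shared<Op>(Op{"RandomDataOp"}));
  // The repeat op's post-action drains the stack to learn which operators
  // will feed it end-of-epoch buffers.
  std::vector<std::shared_ptr<Op>> eoe_ops;
  while (!eoe_stack.empty()) {
    eoe_ops.push_back(eoe_stack.top());
    eoe_stack.pop();
  }
  for (const auto &op : eoe_ops) {
    std::cout << op->name << '\n';
  }
}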
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
@@ -70,13 +70,12 @@ Status CelebAOp::Builder::SanityCheck() {
 CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
                    bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
                    std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(dir),
       decode_(decode),
       extensions_(exts),
       data_schema_(std::move(schema)),
-      sampler_(std::move(sampler)),
       num_rows_in_attr_file_(0),
       dataset_type_(dataset_type) {
   attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
@@ -221,7 +221,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   bool decode_;
   std::set<std::string> extensions_;  // extensions allowed
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
   int64_t num_rows_in_attr_file_;  // rows number specified in attr file
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
@@ -79,12 +79,11 @@ Status CifarOp::Builder::SanityCheck() {
 CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
                  int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_works, queue_size),
+    : ParallelOp(num_works, queue_size, std::move(sampler)),
       cifar_type_(type),
       rows_per_buffer_(rows_per_buf),
       folder_path_(file_dir),
       data_schema_(std::move(data_schema)),
-      sampler_(std::move(sampler)),
       row_cnt_(0),
       buf_cnt_(0) {
   constexpr uint64_t kUtilQueueSize = 512;
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
@@ -216,7 +216,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   int32_t rows_per_buffer_;
   std::string folder_path_;
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
@@ -65,7 +65,7 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str
                              bool recursive, bool do_decode, const std::set<std::string> &exts,
                              const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
                              std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_wkrs, queue_size),
+    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(file_dir),
       recursive_(recursive),
@@ -73,7 +73,6 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::str
       extensions_(exts),
       class_index_(map),
       data_schema_(std::move(data_schema)),
-      sampler_(std::move(sampler)),
       row_cnt_(0),
       buf_cnt_(0),
       sampler_ind_(0),
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
@@ -259,7 +259,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   std::set<std::string> extensions_;  // extensions allowed
   std::map<std::string, int32_t> class_index_;
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
   int64_t sampler_ind_;
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
@@ -64,7 +64,7 @@ Status ManifestOp::Builder::SanityCheck() {
 ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
                        const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
                        std::shared_ptr<Sampler> sampler, std::string usage)
-    : ParallelOp(num_works, queue_size),
+    : ParallelOp(num_works, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       io_block_pushed_(0),
       row_cnt_(0),
@@ -72,7 +72,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f
       data_schema_(std::move(data_schema)),
       file_(file),
       class_index_(class_index),
-      sampler_(std::move(sampler)),
       decode_(decode),
       usage_(usage),
       buf_cnt_(0) {
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
@@ -230,7 +230,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   std::unique_ptr<DataSchema> data_schema_;
   std::string file_;  // file that store the information of images
   std::map<std::string, int32_t> class_index_;
-  std::shared_ptr<Sampler> sampler_;
   bool decode_;
   std::string usage_;
   int64_t buf_cnt_;
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
@@ -66,12 +66,11 @@ Status MnistOp::Builder::SanityCheck() {
 MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
                  std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       buf_cnt_(0),
       row_cnt_(0),
       folder_path_(folder_path),
       rows_per_buffer_(rows_per_buffer),
-      sampler_(std::move(sampler)),
       data_schema_(std::move(data_schema)) {
   io_block_queues_.Init(num_workers, queue_size);
 }
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h
@@ -235,7 +235,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   WaitPost wp_;
   std::string folder_path_;  // directory of image folder
   int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<DataSchema> data_schema_;
   std::vector<MnistLabelPair> image_label_pairs_;
   std::vector<std::string> image_names_;
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc
@@ -21,6 +21,7 @@
 #include "dataset/core/config_manager.h"
 #include "dataset/util/random.h"
 #include "dataset/util/wait_post.h"
+#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"

 namespace mindspore {
 namespace dataset {
@@ -30,7 +31,8 @@ RandomDataOp::Builder::Builder()
     : builder_num_workers_(0),
       builder_op_connector_size_(0),
       builder_rows_per_buffer_(0),
-      builder_total_rows_(0) {
+      builder_total_rows_(0),
+      builder_sampler_(nullptr) {
   // Some arguments to the RandomDataOp have a default argument that is taken from the config.
   // The user may override these defaults by using the builder set methods.
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
@@ -43,8 +45,9 @@ RandomDataOp::Builder::Builder()
 Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
   RETURN_IF_NOT_OK(SanityCheck());

-  *out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_,
-                                           builder_rows_per_buffer_, builder_total_rows_,
-                                           std::move(builder_data_schema_));
+  *out_op =
+    std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
+                                   builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_));

   // If the user did not provide a schema, then we will ask the op to generate a pseudo-random
   // schema.
@@ -66,8 +69,8 @@ Status RandomDataOp::Builder::SanityCheck() const {
 // Constructor for RandomDataOp
 RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
-                           std::unique_ptr<DataSchema> data_schema)
-    : ParallelOp(num_workers, op_connector_size),
+                           std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       buffer_id_(0),
       rows_per_buffer_(rows_per_buffer),
       total_rows_(total_rows),
@@ -124,7 +127,7 @@ Status RandomDataOp::GenerateSchema() {
   // For each column:
   // - choose a datatype
   // - generate a shape that randomly chooses the number of dimensions and the dimension values.
-  DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(0, DataType::NUM_OF_TYPES - 2));
+  DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - 2));
   int32_t rank = GenRandomInt(1, kMaxRank);
   std::vector<dsize_t> dims;
   for (int32_t d = 0; d < rank; d++) {
@@ -412,5 +415,15 @@ Status RandomDataOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status RandomDataOp::PrepareNodePostAction() {
+  // Run common code from super class before adding RandomDataOp specific handling
+  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
+
+  // Specific handling for this op, we need to do cache op work to assign the sampler to the cache.
+  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false));
+
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
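The GenerateSchema change narrows the random type range to start at 1; reading the diff, this looks like it skips the type at index 0 (presumably an unknown/placeholder type), though that reason is an inference, not stated in the commit. A standalone sketch of the column-generation step with stand-in constants and an assumed GenRandomInt signature:

#include <iostream>
#include <random>
#include <vector>

// Stand-in for the GenRandomInt utility used by RandomDataOp (assumed signature).
int GenRandomInt(int lo, int hi) {
  static std::mt19937 rng{42};
  return std::uniform_int_distribution<int>(lo, hi)(rng);
}

int main() {
  constexpr int kNumTypes = 12;     // stand-in for DataType::NUM_OF_TYPES
  constexpr int kMaxRank = 4;
  constexpr int kMaxDimValue = 32;  // the diff lowers this limit from 2048
  // Pick a type index; starting the range at 1 instead of 0 skips index 0.
  int type_index = GenRandomInt(1, kNumTypes - 2);
  // Pick a rank, then a value for each dimension.
  int rank = GenRandomInt(1, kMaxRank);
  std::vector<int> dims;
  for (int d = 0; d < rank; d++) {
    dims.push_back(GenRandomInt(1, kMaxDimValue));
  }
  std::cout << "type=" << type_index << " rank=" << rank << '\n';
}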
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h
@@ -42,7 +42,7 @@ class RandomDataOp : public ParallelOp {
   // Some constants to provide limits to random generation.
   static constexpr int32_t kMaxNumColumns = 4;
   static constexpr int32_t kMaxRank = 4;
-  static constexpr int32_t kMaxDimValue = 2048;
+  static constexpr int32_t kMaxDimValue = 32;
   static constexpr int32_t kMaxTotalRows = 1024;

   // A nested builder class to aid in the construction of a RandomDataOp
@@ -117,6 +117,14 @@ class RandomDataOp : public ParallelOp {
       return *this;
     }

+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
    /**
     * Check if the required parameters are set by the builder.
@@ -125,6 +133,7 @@ class RandomDataOp : public ParallelOp {
    Status SanityCheck() const;

    std::unique_ptr<DataSchema> builder_data_schema_;
+   std::shared_ptr<Sampler> builder_sampler_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
    int64_t builder_rows_per_buffer_;
@@ -139,10 +148,11 @@ class RandomDataOp : public ParallelOp {
   * @param rows_per_buffer - The number of rows in each DataBuffer
   * @param data_schema - A user-provided schema
   * @param total_rows - The total number of rows in the dataset
+  * @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   * @return Builder - The modified builder by reference
   */
-  RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
-               std::unique_ptr<DataSchema> data_schema);
+  RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
+               std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);

   /**
    * Destructor
@@ -193,6 +203,12 @@ class RandomDataOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "RandomDataOp"; }

+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override;
+
  private:
   /**
    * The entry point code for when workers are launched
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
@@ -107,12 +107,11 @@ Status DistributedSampler::ResetSampler() {
 }

 void DistributedSampler::Print(std::ostream &out, bool show_all) const {
-  out << "(sampler): DistributedSampler\n";
+  out << "\nSampler: DistributedSampler";
   if (show_all) {
-    out << "seed_: " << seed_ << '\n';
-    out << "device_id_: " << device_id_ << '\n';
-    out << "num_devices_: " << num_devices_ << '\n';
-    out << "shuffle_: " << shuffle_ << '\n';
+    Sampler::Print(out, show_all);
+    out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
+        << "\nshuffle: " << shuffle_;
   }
 }
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
@@ -113,5 +113,13 @@ Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
   return Status::OK();
 }
+
+void PKSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: PKSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h
@@ -56,6 +56,11 @@ class PKSampler : public Sampler {  // NOT YET FINISHED
   // @return - The error code return
   Status ResetSampler() override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   bool shuffle_;
   uint32_t seed_;
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
@@ -103,5 +103,14 @@ Status PythonSampler::ResetSampler() {
   return Status::OK();
 }
+
+void PythonSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: PythonSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
@@ -50,6 +50,11 @@ class PythonSampler : public Sampler {
   // @return - The error code return
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   bool need_to_reset_;  // Whether Reset() should be called before calling GetNextBuffer()
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc
@@ -113,13 +113,12 @@ Status RandomSampler::ResetSampler() {
 }

 void RandomSampler::Print(std::ostream &out, bool show_all) const {
-  out << "(sampler): RandomSampler\n";
+  out << "\nSampler: RandomSampler";
   if (show_all) {
-    out << "num_samples_: " << num_samples_ << '\n';
-    out << "next_id_: " << next_id_ << '\n';
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
   }
 }
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
@@ -80,11 +80,12 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
 }

 void Sampler::Print(std::ostream &out, bool show_all) const {
-  out << "(sampler): base\n";
-
+  // Sampler printing is usually only called in the show_all mode.
+  // Derived classes will display the name, then call back to this base
+  // for common info.
+  // No-op in the summary mode.
   if (show_all) {
-    out << "num_rows_: " << num_rows_ << '\n';
-    out << "num_samples_: " << num_samples_ << '\n';
+    out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
   }
 }
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
@@ -89,7 +89,14 @@ Status SequentialSampler::ResetSampler() {
   return Status::OK();
 }

-void SequentialSampler::Print(std::ostream &out, bool show_all) const { out << "(sampler): SequentialSampler\n"; }
+void SequentialSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: SequentialSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info
+    out << "\nStart index: " << start_index_;
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
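All the sampler Print changes in this commit converge on one convention: the derived class prints its name, then calls the base for the common fields, then appends its own details. A standalone sketch of that convention (stand-in classes with made-up field values, not the MindSpore headers):

#include <iostream>

class Sampler {
 public:
  virtual ~Sampler() = default;
  // Base prints only the common fields, and only in show_all mode.
  virtual void Print(std::ostream &out, bool show_all) const {
    if (show_all) {
      out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
    }
  }

 protected:
  long num_rows_ = 100;
  long num_samples_ = 10;
};

class SequentialSampler : public Sampler {
 public:
  void Print(std::ostream &out, bool show_all) const override {
    out << "\nSampler: SequentialSampler";  // always print the name
    if (show_all) {
      Sampler::Print(out, show_all);             // common detail from the base
      out << "\nStart index: " << start_index_;  // then our own info
    }
  }

 private:
  long start_index_ = 0;
};

int main() {
  SequentialSampler sampler;
  sampler.Print(std::cout, true);
  std::cout << '\n';
}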
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h
@@ -49,6 +49,9 @@ class SequentialSampler : public Sampler {
   // @return - The error code return
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
@@ -119,5 +119,14 @@ Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffe
   return Status::OK();
 }
+
+void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: SubsetRandomSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h
@@ -51,6 +51,11 @@ class SubsetRandomSampler : public Sampler {
   // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   // A list of indices (already randomized in constructor).
   std::vector<int64_t> indices_;
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
@@ -156,5 +156,14 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf
   return Status::OK();
 }
+
+void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: WeightedRandomSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
@@ -53,6 +53,11 @@ class WeightedRandomSampler : public Sampler {
   // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
   Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
  private:
   // A list of weights for each sample.
   std::vector<double> weights_;
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc
@@ -33,7 +33,11 @@
 namespace mindspore {
 namespace dataset {
 TextFileOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_total_rows_(0),
+      builder_shuffle_files_(false),
+      builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
@@ -64,7 +68,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
     std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
-    builder_num_devices_, builder_device_id_);
+    builder_num_devices_, builder_device_id_, std::move(builder_sampler_));
   RETURN_IF_NOT_OK(text_file_op->Init());
   *op = std::move(text_file_op);
@@ -73,8 +77,9 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
 TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
                        std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
-                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
-    : ParallelOp(num_workers, op_connector_size),
+                       int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
+                       std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       device_id_(device_id),
       num_devices_(num_device),
       rows_per_buffer_(rows_per_buffer),
mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h
@@ -20,6 +20,7 @@
 #include <map>
 #include <mutex>
 #include <string>
+#include <utility>
 #include <vector>

 #include "dataset/util/status.h"
@@ -112,6 +113,14 @@ class TextFileOp : public ParallelOp {
       return *this;
     }

+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     int32_t builder_device_id_;
     int32_t builder_num_devices_;
@@ -123,6 +132,7 @@ class TextFileOp : public ParallelOp {
     std::vector<std::string> builder_text_files_list_;
     bool builder_shuffle_files_;
     std::unique_ptr<DataSchema> builder_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;
   };

   // Constructor of TextFileOp
@@ -136,9 +146,10 @@ class TextFileOp : public ParallelOp {
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
+  // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
              std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
-             bool shuffle_files, int32_t num_devices, int32_t device_id);
+             bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);

   // Default destructor
   ~TextFileOp() = default;
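SetSampler follows the same fluent-builder pattern as the other setters in these ops: store the value, return *this so calls chain. A standalone sketch of that pattern (stand-in Builder class with illustrative fields, not the MindSpore one):

#include <iostream>
#include <memory>
#include <string>
#include <utility>

struct Sampler {
  std::string name;
};

class Builder {
 public:
  // Store the sampler and return *this so calls chain fluently.
  Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
    builder_sampler_ = std::move(sampler);
    return *this;
  }
  Builder &SetShuffleFiles(bool shuffle) {
    builder_shuffle_files_ = shuffle;
    return *this;
  }
  void Build() const {
    std::cout << "sampler=" << (builder_sampler_ ? builder_sampler_->name : "none")
              << " shuffle=" << builder_shuffle_files_ << '\n';
  }

 private:
  std::shared_ptr<Sampler> builder_sampler_;
  bool builder_shuffle_files_ = false;
};

int main() {
  Builder().SetSampler(std::make_shared<Sampler>(Sampler{"sequential"})).SetShuffleFiles(true).Build();
}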
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
@@ -48,7 +48,11 @@
 namespace mindspore {
 namespace dataset {
 TFReaderOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_total_rows_(0),
+      builder_equal_rows_per_shard_(false),
+      builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_worker_connector_size_ = config_manager->worker_connector_size();
@@ -87,11 +91,6 @@ Status TFReaderOp::Builder::ValidateInputs() const {
     err_msg += "Number of parallel workers is smaller or equal to 0\n";
   }

-  if (!builder_equal_rows_per_shard_ &&
-      builder_dataset_files_list_.size() < static_cast<uint32_t>(builder_num_devices_)) {
-    err_msg += "Not enough tfrecord files provided\n";
-  }
-
   if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) {
     err_msg += "Wrong sharding configs\n";
   }
@@ -125,7 +124,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
   std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
     builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
     builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
-    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
+    builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_,
+    std::move(builder_sampler_));

   RETURN_IF_NOT_OK(new_tf_reader_op->Init());
   *out_tf_reader_op = std::move(new_tf_reader_op);
@@ -136,8 +136,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
                        int64_t total_num_rows, std::vector<std::string> dataset_files_list,
                        std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
                        std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
-                       int32_t device_id, bool equal_rows_per_shard)
-    : ParallelOp(num_workers, op_connector_size),
+                       int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       device_id_(device_id),
       num_devices_(num_device),
       rows_per_buffer_(rows_per_buffer),
@@ -1018,5 +1018,40 @@ Status TFReaderOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status TFReaderOp::PrepareNodePostAction() {
+  // Run common code from super class before adding TFReaderOp specific handling
+  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
+
+  // Specific handling for this op, we need to do cache op work so assign the sampler to the cache.
+  // TF is a special case because it can support file-based sharding/shuffling, or, if there
+  // is a cache, then it can also do row-based sampling using the sampler on the cache.
+  // Thus, pass true for random access op flag when saving the sampler. This is a special case,
+  // since usually a non-mappable dataset would pass false here.
+  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
+
+  // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
+  // a simpler producer of all data (no shuffling or sharding or anything)
+  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
+    device_id_ = 0;
+    num_devices_ = 1;
+    total_rows_ = 0;
+    shuffle_files_ = false;
+    equal_rows_per_shard_ = false;
+    sampler_.reset();  // Normally SaveSampler code did this for us, but we passed in true above (See comment)
+  } else {
+    // This sanity check had been delayed until now in the prepare loop.
+    // If we are not in a cache path, then we can validate the file-based sharding config.
+    // If we are in a cache path, there is no file-based sharding so the check is not correct in that
+    // situation.
+    if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {
+      RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n");
+    }
+  }
+
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
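A standalone sketch of the decision TFReaderOp::PrepareNodePostAction makes: under a cache the reader drops sharding and shuffling (the cache samples rows instead), otherwise the file-based sharding config must be validated on its own. Field names and the ReaderConfig type are stand-ins, not MindSpore types:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct ReaderConfig {
  int32_t device_id = 3;
  int32_t num_devices = 8;
  bool shuffle_files = true;
  bool equal_rows_per_shard = false;
  std::vector<std::string> files{"a.tfrecord"};
};

bool PrepareForCachePath(ReaderConfig *cfg, bool under_cache) {
  if (under_cache) {
    // A cache above us samples rows itself, so drop file-level sharding/shuffling.
    cfg->device_id = 0;
    cfg->num_devices = 1;
    cfg->shuffle_files = false;
    cfg->equal_rows_per_shard = false;
    return true;
  }
  // No cache: the file-based sharding config must be valid on its own.
  return cfg->equal_rows_per_shard || cfg->files.size() >= static_cast<size_t>(cfg->num_devices);
}

int main() {
  ReaderConfig under_cache_cfg;
  std::cout << PrepareForCachePath(&under_cache_cfg, true) << '\n';  // 1: cache path is always fine
  ReaderConfig sharded_cfg;
  std::cout << PrepareForCachePath(&sharded_cfg, false) << '\n';     // 0: 1 file < 8 devices
}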
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
@@ -153,8 +153,17 @@ class TFReaderOp : public ParallelOp {
       return *this;
     }

+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     std::unique_ptr<DataSchema> builder_data_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;

     int32_t builder_device_id_;
     int32_t builder_num_devices_;
     int32_t builder_num_workers_;
@@ -180,10 +189,11 @@ class TFReaderOp : public ParallelOp {
   // @param columns_to_load - the names of the columns to load data from.
   // @param shuffle_files - whether or not to shuffle the files before reading data.
   // @param equal_rows_per_shard - whether or not to get equal rows for each process.
+  // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
   TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
              std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
              int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
-             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
+             int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler);

   // Default destructor
   ~TFReaderOp() = default;
@@ -236,6 +246,12 @@ class TFReaderOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return dataset_files_list_; }

+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override;
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
@@ -88,7 +88,7 @@ Status VOCOp::Builder::SanityCheck() {
 VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
              const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
              int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
              std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       decode_(decode),
       row_cnt_(0),
       buf_cnt_(0),
@@ -97,7 +97,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
       folder_path_(folder_path),
       class_index_(class_index),
       rows_per_buffer_(rows_per_buffer),
-      sampler_(std::move(sampler)),
       data_schema_(std::move(data_schema)) {
   io_block_queues_.Init(num_workers_, queue_size);
 }
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
@@ -274,7 +274,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   TaskType task_type_;
   std::string task_mode_;
   int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<DataSchema> data_schema_;
   WaitPost wp_;
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
@@ -129,7 +129,7 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
 Status TakeOp::PrepareNodePostAction() {
   RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToRepeatStack(shared_from_this());
+  tree_->AddToEOEOpStack(shared_from_this());
   return Status::OK();
 }
mindspore/ccsrc/dataset/engine/execution_tree.cc
@@ -88,13 +88,13 @@ Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) {
 }

 // A print method typically used for debugging
-void ExecutionTree::Print(std::ostream &out) const {
+void ExecutionTree::Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op) const {
   out << "Execution tree summary:\n"
       << "-----------------------\n";
-  this->PrintNode(out, root_, "", true, false);
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, false);
   out << "\nExecution tree operator details:\n"
       << "--------------------------------\n";
-  this->PrintNode(out, root_, "", true, true);
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, true);
 }

 // A helper functions for doing the recursive printing
@@ -269,27 +269,40 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
     RETURN_IF_NOT_OK(this->PrepareNode(i));
   }

-  // Then clear the flags from this op now that we have prepared it.
-  BitClear(&prepare_flags_, op_prep_flags);
-
   // No more children, now we execute any prepare actions before going back up the
   // the tree on recursive function
   RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());

+  // Then clear the flags from this op now that we have prepared it.
+  BitClear(&prepare_flags_, op_prep_flags);
+
   return Status::OK();
 }

-// Adds an operator to the repeat stack during prepare phase.
-void ExecutionTree::AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op) { repeat_stack_.push(dataset_op); }
+// Adds an operator to the eoe operator stack during prepare phase.
+void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }

-// Pops an operator from the repeat stack during prepare phase.
-std::shared_ptr<DatasetOp> ExecutionTree::PopFromRepeatStack() {
+// Pops an operator from the eoe operator stack during prepare phase.
+std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
   std::shared_ptr<DatasetOp> top_op = nullptr;
-  if (!repeat_stack_.empty()) {
-    top_op = repeat_stack_.top();
-    repeat_stack_.pop();
+  if (!eoe_stack_.empty()) {
+    top_op = eoe_stack_.top();
+    eoe_stack_.pop();
   }
   return top_op;
 }
+
+// Adds a sampler to the sampler stack during prepare phase.
+void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
+
+// Pops an operator from the sampler stack during prepare phase.
+std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
+  std::shared_ptr<Sampler> top_sampler = nullptr;
+  if (!sampler_stack_.empty()) {
+    top_sampler = sampler_stack_.top();
+    sampler_stack_.pop();
+  }
+  return top_sampler;
+}
 }  // namespace dataset
 }  // namespace mindspore
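The sampler stack added here is the hand-off channel between a leaf under a cache and the cache op itself: the leaf pushes its sampler and clears its own copy, and the cache later pops it. A standalone sketch with stand-in types, not the MindSpore classes:

#include <cassert>
#include <memory>
#include <stack>
#include <utility>

struct Sampler {};

struct Tree {
  std::stack<std::shared_ptr<Sampler>> sampler_stack_;
  void AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(std::move(sampler)); }
  std::shared_ptr<Sampler> PopFromSamplerStack() {
    std::shared_ptr<Sampler> top;
    if (!sampler_stack_.empty()) {
      top = sampler_stack_.top();
      sampler_stack_.pop();
    }
    return top;  // nullptr when the stack is empty
  }
};

int main() {
  Tree tree;
  auto leaf_sampler = std::make_shared<Sampler>();
  tree.AddToSamplerStack(leaf_sampler);           // leaf's PrepareNodePostAction
  leaf_sampler.reset();                           // the leaf no longer needs it
  assert(tree.PopFromSamplerStack() != nullptr);  // the cache op picks it up
  assert(tree.PopFromSamplerStack() == nullptr);  // stack is now empty
}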
mindspore/ccsrc/dataset/engine/execution_tree.h
@@ -37,7 +37,8 @@ class ExecutionTree {
   // Prepare flags used during tree prepare phase
   enum PrepareFlags {
     kDePrepNone = 0,
-    kDePrepRepeat = 1  // Processing a repeat operation
+    kDePrepRepeat = 1,  // Processing a repeat operation
+    kDePrepCache = 2    // Processing a cache operation
   };

   // State flags for the lifecycle of the tree
@@ -118,9 +119,9 @@ class ExecutionTree {
   // @return Status - The error code return
   Status Launch();

-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  void Print(std::ostream &out) const;
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  void Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op = nullptr) const;

   // Returns an iterator positioned at the start
   // @return Iterator - The iterator
@@ -199,14 +200,23 @@ class ExecutionTree {
   // @return Status - The error code return
   Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);

-  // Adds an operator to the repeat stack during prepare phase.
-  // @param op - The dataset op to work add to repeat stack
-  // @return Status - The error code return
-  void AddToRepeatStack(std::shared_ptr<DatasetOp> dataset_op);
+  /// Adds an operator to the eoe operator stack during prepare phase.
+  /// \param op - The dataset op to add to the eoe stack
+  /// \return Status - The error code return
+  void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
+
+  /// Pops an operator from the eoe operator stack during prepare phase.
+  /// \return shared_ptr to the popped operator
+  std::shared_ptr<DatasetOp> PopFromEOEOpStack();
+
+  /// Adds a sampler to the sampler stack during prepare phase.
+  /// \param sampler - The sampler to add to the sampler stack
+  /// \return Status - The error code return
+  void AddToSamplerStack(std::shared_ptr<Sampler> sampler);

-  // Pops an operator from the repeat stack during prepare phase.
-  // @return shared_ptr to the popped operator
-  std::shared_ptr<DatasetOp> PopFromRepeatStack();
+  /// Pops a sampler from the sampler stack during prepare phase.
+  /// \return shared_ptr to the popped sampler
+  std::shared_ptr<Sampler> PopFromSamplerStack();

   // Return the pointer to the TaskGroup
   // @return raw pointer to the TaskGroup
@@ -236,9 +246,10 @@ class ExecutionTree {
   int32_t id_count_;                                      // Counter for generating operator id's
   uint32_t prepare_flags_;                                // Flags used during tree prepare
   TreeState tree_state_;                                  // Tracking the current tree state
-  std::stack<std::shared_ptr<DatasetOp>> repeat_stack_;   // A stack used during prepare phase
   std::unique_ptr<Monitor> perf_monitor_;                 // Performance Monitor
   std::unique_ptr<ProfilingManager> profiling_manager_;   // Profiling manager
+  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;      // A stack used during prepare phase
+  std::stack<std::shared_ptr<Sampler>> sampler_stack_;    // A stack used during prepare phase
 };
 }  // namespace dataset
 }  // namespace mindspore