Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
2301_77200941
mindspore
提交
6ae88c39
M
mindspore
项目概览
2301_77200941
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6ae88c39
编写于
6月 10, 2020
作者:
L
Lixia Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Create TensorRow class that supports a row id.
上级
4aed2bf9
变更
30
隐藏空白更改
内联
并排
Showing
30 changed file
with
277 addition
and
60 deletion
+277
-60
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+1
-1
mindspore/ccsrc/dataset/core/CMakeLists.txt
mindspore/ccsrc/dataset/core/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/core/constants.h
mindspore/ccsrc/dataset/core/constants.h
+3
-0
mindspore/ccsrc/dataset/core/tensor.h
mindspore/ccsrc/dataset/core/tensor.h
+0
-3
mindspore/ccsrc/dataset/core/tensor_row.cc
mindspore/ccsrc/dataset/core/tensor_row.cc
+75
-0
mindspore/ccsrc/dataset/core/tensor_row.h
mindspore/ccsrc/dataset/core/tensor_row.h
+131
-0
mindspore/ccsrc/dataset/engine/data_buffer.h
mindspore/ccsrc/dataset/engine/data_buffer.h
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
...spore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
+4
-3
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
+3
-1
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
+2
-2
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
+13
-12
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h
+8
-4
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
...ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
.../ccsrc/dataset/engine/datasetops/source/image_folder_op.h
+2
-1
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
...ore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
+4
-3
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
...pore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
+3
-1
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
+3
-3
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h
+2
-1
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+4
-4
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
+2
-1
mindspore/ccsrc/dataset/engine/gnn/graph.h
mindspore/ccsrc/dataset/engine/gnn/graph.h
+1
-0
mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h
mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h
+1
-0
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
+1
-2
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h
+1
-2
mindspore/ccsrc/dataset/kernels/py_func_op.cc
mindspore/ccsrc/dataset/kernels/py_func_op.cc
+1
-2
mindspore/ccsrc/dataset/kernels/py_func_op.h
mindspore/ccsrc/dataset/kernels/py_func_op.h
+1
-2
mindspore/ccsrc/dataset/kernels/tensor_op.cc
mindspore/ccsrc/dataset/kernels/tensor_op.cc
+1
-2
mindspore/ccsrc/dataset/kernels/tensor_op.h
mindspore/ccsrc/dataset/kernels/tensor_op.h
+2
-2
tests/ut/cpp/dataset/map_op_test.cc
tests/ut/cpp/dataset/map_op_test.cc
+2
-4
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
6ae88c39
...
...
@@ -584,7 +584,7 @@ void bindGraphData(py::module *m) {
[](
gnn
::
Graph
&
g
,
std
::
shared_ptr
<
Tensor
>
node_list
,
std
::
vector
<
gnn
::
FeatureType
>
feature_types
)
{
TensorRow
out
;
THROW_IF_ERROR
(
g
.
GetNodeFeature
(
node_list
,
feature_types
,
&
out
));
return
out
;
return
out
.
getRow
()
;
})
.
def
(
"graph_info"
,
[](
gnn
::
Graph
&
g
)
{
py
::
dict
out
;
...
...
mindspore/ccsrc/dataset/core/CMakeLists.txt
浏览文件 @
6ae88c39
...
...
@@ -11,6 +11,7 @@ add_library(core OBJECT
data_type.cc
global_context.cc
tensor.cc
tensor_row.cc
tensor_shape.cc
)
add_dependencies
(
core mindspore::protobuf
)
...
...
mindspore/ccsrc/dataset/core/constants.h
浏览文件 @
6ae88c39
...
...
@@ -51,6 +51,9 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr
uint8_t
kCVInvalidType
=
255
;
using
connection_id_type
=
int64_t
;
using
row_id_type
=
int64_t
;
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/dataset/core/tensor.h
浏览文件 @
6ae88c39
...
...
@@ -44,9 +44,6 @@ class Tensor;
using
CharAllocPtr
=
std
::
unique_ptr
<
Allocator
<
unsigned
char
>>
;
using
TensorAllocPtr
=
std
::
shared_ptr
<
Allocator
<
Tensor
>>
;
// An allocator shared_ptr for Tensors
using
TensorRow
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
;
// A row is a set of Tensor pointers
using
TensorTable
=
std
::
vector
<
TensorRow
>
;
// The table of tensors is a vector of rows
using
TensorQTable
=
std
::
deque
<
TensorRow
>
;
// A different flavour of tensor table, this one has queue functionality
class
Tensor
{
public:
...
...
mindspore/ccsrc/dataset/core/tensor_row.cc
0 → 100644
浏览文件 @
6ae88c39
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "dataset/core/tensor_row.h"
namespace
py
=
pybind11
;
namespace
mindspore
{
namespace
dataset
{
TensorRow
::
TensorRow
()
noexcept
:
id_
(
kDefaultRowId
)
{}
TensorRow
::
TensorRow
(
size_type
n
,
TensorRow
::
value_type
t
)
noexcept
:
id_
(
kDefaultRowId
),
row_
(
n
,
t
)
{}
TensorRow
::
TensorRow
(
const
TensorRow
::
vector_type
&
v
)
:
id_
(
kDefaultRowId
),
row_
(
v
)
{}
TensorRow
::
TensorRow
(
row_id_type
id
,
const
std
::
initializer_list
<
value_type
>
&
lst
)
:
id_
(
id
),
row_
(
lst
)
{}
TensorRow
::
TensorRow
(
const
TensorRow
&
tr
)
:
id_
(
tr
.
id_
),
row_
(
tr
.
row_
)
{}
TensorRow
&
TensorRow
::
operator
=
(
const
TensorRow
&
tr
)
{
if
(
this
==
&
tr
)
{
return
*
this
;
}
row_
=
tr
.
row_
;
id_
=
tr
.
id_
;
return
*
this
;
}
TensorRow
&
TensorRow
::
operator
=
(
const
std
::
initializer_list
<
TensorRow
::
value_type
>
&
lst
)
{
row_
=
lst
;
return
*
this
;
}
TensorRow
::
TensorRow
(
TensorRow
::
vector_type
&&
v
)
noexcept
:
id_
(
kDefaultRowId
),
row_
(
std
::
move
(
v
))
{}
TensorRow
::
TensorRow
(
row_id_type
id
,
std
::
initializer_list
<
value_type
>
&&
lst
)
noexcept
:
id_
(
id
),
row_
(
std
::
move
(
lst
))
{}
TensorRow
::
TensorRow
(
TensorRow
&&
tr
)
noexcept
{
id_
=
tr
.
id_
;
row_
=
std
::
move
(
tr
.
row_
);
}
TensorRow
&
TensorRow
::
operator
=
(
TensorRow
&&
tr
)
noexcept
{
if
(
this
==
&
tr
)
{
return
*
this
;
}
row_
=
std
::
move
(
tr
.
row_
);
id_
=
tr
.
id_
;
tr
.
id_
=
kDefaultRowId
;
return
*
this
;
}
TensorRow
&
TensorRow
::
operator
=
(
std
::
initializer_list
<
TensorRow
::
value_type
>
&&
lst
)
noexcept
{
row_
=
std
::
move
(
lst
);
return
*
this
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/core/tensor_row.h
0 → 100644
浏览文件 @
6ae88c39
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_CORE_TENSOR_ROW_H_
#define DATASET_CORE_TENSOR_ROW_H_
#include <deque>
#include <memory>
#include <vector>
#include "dataset/core/tensor.h"
namespace
mindspore
{
namespace
dataset
{
class
TensorRow
;
// A set of Tensor pointers with an id
using
TensorTable
=
std
::
vector
<
TensorRow
>
;
// The table of tensors is a vector of rows
using
TensorQTable
=
std
::
deque
<
TensorRow
>
;
// A different flavour of tensor table, this one has queue functionality
class
TensorRow
{
public:
static
constexpr
row_id_type
kDefaultRowId
=
-
1
;
// Default row id
// Type definitions
typedef
dsize_t
size_type
;
typedef
std
::
shared_ptr
<
Tensor
>
value_type
;
typedef
std
::
shared_ptr
<
Tensor
>
&
reference
;
typedef
const
std
::
shared_ptr
<
Tensor
>
&
const_reference
;
typedef
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
vector_type
;
typedef
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>::
iterator
iterator
;
typedef
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>::
const_iterator
const_iterator
;
TensorRow
()
noexcept
;
TensorRow
(
size_type
n
,
value_type
t
)
noexcept
;
// Copy Constructors
explicit
TensorRow
(
const
vector_type
&
v
);
TensorRow
(
row_id_type
id
,
const
std
::
initializer_list
<
value_type
>
&
lst
);
TensorRow
(
const
TensorRow
&
tr
);
TensorRow
&
operator
=
(
const
TensorRow
&
tr
);
TensorRow
&
operator
=
(
const
std
::
initializer_list
<
value_type
>
&
lst
);
// Move Constructors
explicit
TensorRow
(
vector_type
&&
v
)
noexcept
;
TensorRow
(
row_id_type
id
,
std
::
initializer_list
<
value_type
>
&&
lst
)
noexcept
;
TensorRow
(
TensorRow
&&
tr
)
noexcept
;
TensorRow
&
operator
=
(
TensorRow
&&
tr
)
noexcept
;
TensorRow
&
operator
=
(
std
::
initializer_list
<
value_type
>
&&
lst
)
noexcept
;
// Destructor
~
TensorRow
()
=
default
;
// Functions to fetch/set id/vector
row_id_type
getId
()
const
{
return
id_
;
}
void
setId
(
row_id_type
id
)
{
id_
=
id
;
}
const
vector_type
&
getRow
()
const
{
return
row_
;
}
// Wrapper functions to support vector operations
void
emplace_back
(
value_type
t
)
{
row_
.
emplace_back
(
t
);
}
void
push_back
(
value_type
t
)
{
row_
.
push_back
(
t
);
}
void
clear
()
noexcept
{
row_
.
clear
();
}
size_type
size
()
const
noexcept
{
return
row_
.
size
();
}
void
reserve
(
size_type
size
)
{
row_
.
reserve
(
size
);
}
void
resize
(
size_type
size
)
{
row_
.
resize
(
size
);
}
bool
empty
()
{
return
row_
.
empty
();
}
void
insert
(
iterator
position
,
iterator
first
,
iterator
last
)
{
row_
.
insert
(
position
,
first
,
last
);
}
// Wrapper functions to support vector element access
reference
at
(
size_type
index
)
{
return
row_
.
at
(
index
);
}
const_reference
at
(
size_type
index
)
const
{
return
row_
.
at
(
index
);
}
reference
front
()
{
return
row_
.
front
();
}
const_reference
front
()
const
{
return
row_
.
front
();
}
reference
back
()
{
return
row_
.
back
();
}
const_reference
back
()
const
{
return
row_
.
back
();
}
reference
operator
[](
size_type
index
)
{
return
row_
[
index
];
}
const_reference
operator
[](
size_type
index
)
const
{
return
row_
[
index
];
}
// Wrapper functions to support vector iteration
iterator
begin
()
{
return
row_
.
begin
();
}
const_iterator
begin
()
const
{
return
row_
.
begin
();
}
iterator
end
()
{
return
row_
.
end
();
}
const_iterator
end
()
const
{
return
row_
.
end
();
}
protected:
row_id_type
id_
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
row_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_CORE_TENSOR_ROW_H_
mindspore/ccsrc/dataset/engine/data_buffer.h
浏览文件 @
6ae88c39
...
...
@@ -25,6 +25,7 @@
#include "dataset/util/status.h"
#include "dataset/core/constants.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_row.h"
namespace
mindspore
{
namespace
dataset
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/batch_op.h
浏览文件 @
6ae88c39
...
...
@@ -36,7 +36,7 @@ namespace mindspore {
namespace
dataset
{
class
DataBuffer
;
using
TensorBatch
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
;
using
TensorBatch
=
TensorRow
;
using
TensorBatchTable
=
std
::
vector
<
TensorBatch
>
;
using
PadInfo
=
std
::
map
<
std
::
string
,
std
::
pair
<
TensorShape
,
std
::
shared_ptr
<
Tensor
>>>
;
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
浏览文件 @
6ae88c39
...
...
@@ -349,7 +349,7 @@ Status CelebAOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<Da
std
::
unique_ptr
<
TensorQTable
>
deq
=
std
::
make_unique
<
TensorQTable
>
();
for
(
const
auto
&
key
:
keys
)
{
TensorRow
row
;
RETURN_IF_NOT_OK
(
LoadTensorRow
(
image_labels_vec_
[
key
],
&
row
));
RETURN_IF_NOT_OK
(
LoadTensorRow
(
key
,
image_labels_vec_
[
key
],
&
row
));
deq
->
push_back
(
std
::
move
(
row
));
}
...
...
@@ -357,7 +357,8 @@ Status CelebAOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<Da
return
Status
::
OK
();
}
Status
CelebAOp
::
LoadTensorRow
(
const
std
::
pair
<
std
::
string
,
std
::
vector
<
int32_t
>>
&
image_label
,
TensorRow
*
row
)
{
Status
CelebAOp
::
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
pair
<
std
::
string
,
std
::
vector
<
int32_t
>>
&
image_label
,
TensorRow
*
row
)
{
std
::
shared_ptr
<
Tensor
>
image
;
std
::
shared_ptr
<
Tensor
>
label
;
...
...
@@ -398,7 +399,7 @@ Status CelebAOp::LoadTensorRow(const std::pair<std::string, std::vector<int32_t>
}
label
->
Squeeze
();
(
*
row
)
=
{
std
::
move
(
image
),
std
::
move
(
label
)}
;
(
*
row
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
label
)})
;
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
浏览文件 @
6ae88c39
...
...
@@ -197,10 +197,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
Status
LoadBuffer
(
const
std
::
vector
<
int64_t
>
&
keys
,
std
::
unique_ptr
<
DataBuffer
>
*
db
);
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param std::pair - <image_file,<label>>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
Status
LoadTensorRow
(
const
std
::
pair
<
std
::
string
,
std
::
vector
<
int32_t
>>
&
image_label
,
TensorRow
*
row
);
Status
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
pair
<
std
::
string
,
std
::
vector
<
int32_t
>>
&
image_label
,
TensorRow
*
row
);
// Check if need read according to dataset type
// @return bool - if need read
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
浏览文件 @
6ae88c39
...
...
@@ -203,9 +203,9 @@ Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) {
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
fine_label
,
data_schema_
->
column
(
2
).
tensorImpl
(),
data_schema_
->
column
(
2
).
shape
(),
data_schema_
->
column
(
2
).
type
(),
reinterpret_cast
<
unsigned
char
*>
(
&
cifar_image_label_pairs_
[
index
].
second
[
1
])));
(
*
trow
)
=
{
copy_image
,
std
::
move
(
label
),
std
::
move
(
fine_label
)}
;
(
*
trow
)
=
TensorRow
(
index
,
{
copy_image
,
std
::
move
(
label
),
std
::
move
(
fine_label
)})
;
}
else
{
(
*
trow
)
=
{
copy_image
,
std
::
move
(
label
)}
;
(
*
trow
)
=
TensorRow
(
index
,
{
copy_image
,
std
::
move
(
label
)})
;
}
return
Status
::
OK
();
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
浏览文件 @
6ae88c39
...
...
@@ -213,7 +213,7 @@ Status CocoOp::Reset() {
return
Status
::
OK
();
}
Status
CocoOp
::
LoadTensorRow
(
const
std
::
string
&
image_id
,
TensorRow
*
trow
)
{
Status
CocoOp
::
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
image
,
coordinate
;
auto
itr
=
coordinate_map_
.
find
(
image_id
);
if
(
itr
==
coordinate_map_
.
end
())
RETURN_STATUS_UNEXPECTED
(
"Invalid image_id found :"
+
image_id
);
...
...
@@ -246,11 +246,11 @@ Status CocoOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
data_schema_
->
column
(
1
).
type
(),
reinterpret_cast
<
unsigned
char
*>
(
&
bbox_row
[
0
])));
if
(
task_type_
==
TaskType
::
Detection
)
{
RETURN_IF_NOT_OK
(
LoadDetectionTensorRow
(
image_id
,
image
,
coordinate
,
trow
));
RETURN_IF_NOT_OK
(
LoadDetectionTensorRow
(
row_id
,
image_id
,
image
,
coordinate
,
trow
));
}
else
if
(
task_type_
==
TaskType
::
Stuff
||
task_type_
==
TaskType
::
Keypoint
)
{
RETURN_IF_NOT_OK
(
LoadSimpleTensorRow
(
image_id
,
image
,
coordinate
,
trow
));
RETURN_IF_NOT_OK
(
LoadSimpleTensorRow
(
row_id
,
image_id
,
image
,
coordinate
,
trow
));
}
else
if
(
task_type_
==
TaskType
::
Panoptic
)
{
RETURN_IF_NOT_OK
(
LoadMixTensorRow
(
image_id
,
image
,
coordinate
,
trow
));
RETURN_IF_NOT_OK
(
LoadMixTensorRow
(
row_id
,
image_id
,
image
,
coordinate
,
trow
));
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Invalid task type."
);
}
...
...
@@ -265,7 +265,7 @@ Status CocoOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
// column ["iscrowd"] with datatype=uint32
// By the way, column ["iscrowd"] is used for some testcases, like fasterRcnn.
// If "iscrowd" is not existed, user will get default value 0.
Status
CocoOp
::
LoadDetectionTensorRow
(
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
Status
CocoOp
::
LoadDetectionTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
std
::
shared_ptr
<
Tensor
>
coordinate
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
category_id
,
iscrowd
;
std
::
vector
<
uint32_t
>
category_id_row
;
...
...
@@ -288,7 +288,7 @@ Status CocoOp::LoadDetectionTensorRow(const std::string &image_id, std::shared_p
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
iscrowd
,
data_schema_
->
column
(
3
).
tensorImpl
(),
TensorShape
({
static_cast
<
dsize_t
>
(
iscrowd_row
.
size
()),
1
}),
data_schema_
->
column
(
3
).
type
(),
reinterpret_cast
<
unsigned
char
*>
(
&
iscrowd_row
[
0
])));
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
coordinate
),
std
::
move
(
category_id
),
std
::
move
(
iscrowd
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
coordinate
),
std
::
move
(
category_id
),
std
::
move
(
iscrowd
)})
;
return
Status
::
OK
();
}
...
...
@@ -296,7 +296,7 @@ Status CocoOp::LoadDetectionTensorRow(const std::string &image_id, std::shared_p
// column ["image"] with datatype=uint8
// column ["segmentation"]/["keypoints"] with datatype=float32
// column ["iscrowd"]/["num_keypoints"] with datatype=uint32
Status
CocoOp
::
LoadSimpleTensorRow
(
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
Status
CocoOp
::
LoadSimpleTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
std
::
shared_ptr
<
Tensor
>
coordinate
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
item
;
std
::
vector
<
uint32_t
>
item_queue
;
...
...
@@ -308,7 +308,7 @@ Status CocoOp::LoadSimpleTensorRow(const std::string &image_id, std::shared_ptr<
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
item
,
data_schema_
->
column
(
2
).
tensorImpl
(),
TensorShape
(
bbox_dim
),
data_schema_
->
column
(
2
).
type
(),
reinterpret_cast
<
unsigned
char
*>
(
&
item_queue
[
0
])));
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
coordinate
),
std
::
move
(
item
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
coordinate
),
std
::
move
(
item
)})
;
return
Status
::
OK
();
}
...
...
@@ -318,7 +318,7 @@ Status CocoOp::LoadSimpleTensorRow(const std::string &image_id, std::shared_ptr<
// column ["category_id"] with datatype=uint32
// column ["iscrowd"] with datatype=uint32
// column ["area"] with datattype=uint32
Status
CocoOp
::
LoadMixTensorRow
(
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
Status
CocoOp
::
LoadMixTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
std
::
shared_ptr
<
Tensor
>
coordinate
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
category_id
,
iscrowd
,
area
;
std
::
vector
<
uint32_t
>
category_id_row
;
...
...
@@ -349,15 +349,16 @@ Status CocoOp::LoadMixTensorRow(const std::string &image_id, std::shared_ptr<Ten
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
area
,
data_schema_
->
column
(
4
).
tensorImpl
(),
TensorShape
({
static_cast
<
dsize_t
>
(
area_row
.
size
()),
1
}),
data_schema_
->
column
(
4
).
type
(),
reinterpret_cast
<
unsigned
char
*>
(
&
area_row
[
0
])));
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
coordinate
),
std
::
move
(
category_id
),
std
::
move
(
iscrowd
),
std
::
move
(
area
)};
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
coordinate
),
std
::
move
(
category_id
),
std
::
move
(
iscrowd
),
std
::
move
(
area
)});
return
Status
::
OK
();
}
Status
CocoOp
::
LoadBuffer
(
const
std
::
vector
<
int64_t
>
&
keys
,
std
::
unique_ptr
<
DataBuffer
>
*
db
)
{
std
::
unique_ptr
<
TensorQTable
>
deq
=
std
::
make_unique
<
TensorQTable
>
();
TensorRow
trow
;
for
(
const
u
int64_t
&
key
:
keys
)
{
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
image_ids_
[
key
],
&
trow
));
for
(
const
int64_t
&
key
:
keys
)
{
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
key
,
image_ids_
[
key
],
&
trow
));
deq
->
push_back
(
std
::
move
(
trow
));
}
(
*
db
)
->
set_tensor_table
(
std
::
move
(
deq
));
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h
浏览文件 @
6ae88c39
...
...
@@ -205,36 +205,40 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
Status
InitSampler
();
// Load a tensor row according to image id
// @param row_id_type row_id - id for this tensor row
// @param std::string image_id - image id
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
Status
LoadTensorRow
(
const
std
::
string
&
image_id
,
TensorRow
*
row
);
Status
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
TensorRow
*
row
);
// Load a tensor row with vector which a vector to a tensor
// @param row_id_type row_id - id for this tensor row
// @param const std::string &image_id - image is
// @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
Status
LoadDetectionTensorRow
(
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
Status
LoadDetectionTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
std
::
shared_ptr
<
Tensor
>
coordinate
,
TensorRow
*
trow
);
// Load a tensor row with vector which a vector to a tensor
// @param row_id_type row_id - id for this tensor row
// @param const std::string &image_id - image is
// @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
Status
LoadSimpleTensorRow
(
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
Status
LoadSimpleTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
std
::
shared_ptr
<
Tensor
>
coordinate
,
TensorRow
*
trow
);
// Load a tensor row with vector which a vector to multi-tensor
// @param row_id_type row_id - id for this tensor row
// @param const std::string &image_id - image is
// @param std::shared_ptr<Tensor> image - image tensor
// @param std::shared_ptr<Tensor> coordinate - coordinate tensor
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
Status
LoadMixTensorRow
(
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
Status
LoadMixTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
std
::
shared_ptr
<
Tensor
>
image
,
std
::
shared_ptr
<
Tensor
>
coordinate
,
TensorRow
*
trow
);
// @param const std::string &path - path to the image file
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
浏览文件 @
6ae88c39
...
...
@@ -199,7 +199,7 @@ Status ImageFolderOp::WorkerEntry(int32_t worker_id) {
}
// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer
Status
ImageFolderOp
::
LoadTensorRow
(
ImageLabelPair
pairPtr
,
TensorRow
*
trow
)
{
Status
ImageFolderOp
::
LoadTensorRow
(
row_id_type
row_id
,
ImageLabelPair
pairPtr
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
image
,
label
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
label
,
data_schema_
->
column
(
1
).
tensorImpl
(),
data_schema_
->
column
(
1
).
shape
(),
data_schema_
->
column
(
1
).
type
(),
...
...
@@ -223,7 +223,7 @@ Status ImageFolderOp::LoadTensorRow(ImageLabelPair pairPtr, TensorRow *trow) {
RETURN_STATUS_UNEXPECTED
(
err
);
}
}
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
label
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
label
)})
;
return
Status
::
OK
();
}
...
...
@@ -232,7 +232,7 @@ Status ImageFolderOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_p
std
::
unique_ptr
<
TensorQTable
>
deq
=
std
::
make_unique
<
TensorQTable
>
();
TensorRow
trow
;
for
(
const
int64_t
&
key
:
keys
)
{
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
image_label_pairs_
[
key
],
&
trow
));
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
key
,
image_label_pairs_
[
key
],
&
trow
));
deq
->
push_back
(
std
::
move
(
trow
));
}
(
*
db
)
->
set_tensor_table
(
std
::
move
(
deq
));
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
浏览文件 @
6ae88c39
...
...
@@ -220,10 +220,11 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
Status
InitSampler
();
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
Status
LoadTensorRow
(
ImageLabelPair
pair
,
TensorRow
*
row
);
Status
LoadTensorRow
(
row_id_type
row_id
,
ImageLabelPair
pair
,
TensorRow
*
row
);
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
浏览文件 @
6ae88c39
...
...
@@ -182,7 +182,8 @@ Status ManifestOp::WorkerEntry(int32_t worker_id) {
}
// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer
Status
ManifestOp
::
LoadTensorRow
(
const
std
::
pair
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
data
,
TensorRow
*
trow
)
{
Status
ManifestOp
::
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
pair
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
data
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
image
;
std
::
shared_ptr
<
Tensor
>
label
;
std
::
vector
<
int32_t
>
label_index
(
data
.
second
.
size
());
...
...
@@ -222,7 +223,7 @@ Status ManifestOp::LoadTensorRow(const std::pair<std::string, std::vector<std::s
RETURN_STATUS_UNEXPECTED
(
err
);
}
}
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
label
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
label
)})
;
return
Status
::
OK
();
}
...
...
@@ -231,7 +232,7 @@ Status ManifestOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<
std
::
unique_ptr
<
TensorQTable
>
deq
=
std
::
make_unique
<
TensorQTable
>
();
for
(
const
auto
&
key
:
keys
)
{
TensorRow
trow
;
RETURN_IF_NOT_OK
(
LoadTensorRow
(
image_labelname_
[
static_cast
<
size_t
>
(
key
)],
&
trow
));
RETURN_IF_NOT_OK
(
LoadTensorRow
(
key
,
image_labelname_
[
static_cast
<
size_t
>
(
key
)],
&
trow
));
deq
->
push_back
(
std
::
move
(
trow
));
}
(
*
db
)
->
set_tensor_table
(
std
::
move
(
deq
));
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
浏览文件 @
6ae88c39
...
...
@@ -187,10 +187,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
Status
AddIoBlock
(
std
::
unique_ptr
<
DataBuffer
>
*
sampler_buffer
);
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param std::pair<std::string, std::vector<std::string>> - <imagefile, <label1, label2...>>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
Status
LoadTensorRow
(
const
std
::
pair
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
data
,
TensorRow
*
row
);
Status
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
pair
<
std
::
string
,
std
::
vector
<
std
::
string
>>
&
data
,
TensorRow
*
row
);
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
浏览文件 @
6ae88c39
...
...
@@ -162,7 +162,7 @@ Status MnistOp::WorkerEntry(int32_t worker_id) {
}
// Load 1 TensorRow (image,label) using 1 MnistLabelPair.
Status
MnistOp
::
LoadTensorRow
(
const
MnistLabelPair
&
mnist_pair
,
TensorRow
*
trow
)
{
Status
MnistOp
::
LoadTensorRow
(
row_id_type
row_id
,
const
MnistLabelPair
&
mnist_pair
,
TensorRow
*
trow
)
{
std
::
shared_ptr
<
Tensor
>
image
,
label
;
int32_t
l
=
mnist_pair
.
second
;
// make a copy of cached tensor
...
...
@@ -170,7 +170,7 @@ Status MnistOp::LoadTensorRow(const MnistLabelPair &mnist_pair, TensorRow *trow)
mnist_pair
.
first
->
type
(),
mnist_pair
.
first
->
GetMutableBuffer
()));
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
label
,
data_schema_
->
column
(
1
).
tensorImpl
(),
data_schema_
->
column
(
1
).
shape
(),
data_schema_
->
column
(
1
).
type
(),
reinterpret_cast
<
unsigned
char
*>
(
&
l
)));
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
label
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
label
)})
;
return
Status
::
OK
();
}
...
...
@@ -179,7 +179,7 @@ Status MnistOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<Dat
std
::
unique_ptr
<
TensorQTable
>
deq
=
std
::
make_unique
<
TensorQTable
>
();
TensorRow
trow
;
for
(
const
int64_t
&
key
:
keys
)
{
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
image_label_pairs_
[
key
],
&
trow
));
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
key
,
image_label_pairs_
[
key
],
&
trow
));
deq
->
push_back
(
std
::
move
(
trow
));
}
(
*
db
)
->
set_tensor_table
(
std
::
move
(
deq
));
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h
浏览文件 @
6ae88c39
...
...
@@ -162,10 +162,11 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
Status
InitSampler
();
// Load a tensor row according to a pair
// @param row_id_type row_id - id for this tensor row
// @param ImageLabelPair pair - <imagefile,label>
// @param TensorRow row - image & label read into this tensor row
// @return Status - The error code return
Status
LoadTensorRow
(
const
MnistLabelPair
&
mnist_pair
,
TensorRow
*
row
);
Status
LoadTensorRow
(
row_id_type
row_id
,
const
MnistLabelPair
&
mnist_pair
,
TensorRow
*
row
);
// @param const std::vector<int64_t> &keys - keys in ioblock
// @param std::unique_ptr<DataBuffer> db
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
浏览文件 @
6ae88c39
...
...
@@ -183,7 +183,7 @@ Status VOCOp::Reset() {
return
Status
::
OK
();
}
Status
VOCOp
::
LoadTensorRow
(
const
std
::
string
&
image_id
,
TensorRow
*
trow
)
{
Status
VOCOp
::
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
TensorRow
*
trow
)
{
if
(
task_type_
==
TaskType
::
Segmentation
)
{
std
::
shared_ptr
<
Tensor
>
image
,
target
;
const
std
::
string
kImageFile
=
...
...
@@ -192,7 +192,7 @@ Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
folder_path_
+
std
::
string
(
kSegmentationClassFolder
)
+
image_id
+
std
::
string
(
kSegmentationExtension
);
RETURN_IF_NOT_OK
(
ReadImageToTensor
(
kImageFile
,
data_schema_
->
column
(
0
),
&
image
));
RETURN_IF_NOT_OK
(
ReadImageToTensor
(
kTargetFile
,
data_schema_
->
column
(
1
),
&
target
));
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
target
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
target
)})
;
}
else
if
(
task_type_
==
TaskType
::
Detection
)
{
std
::
shared_ptr
<
Tensor
>
image
,
annotation
;
const
std
::
string
kImageFile
=
...
...
@@ -201,7 +201,7 @@ Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) {
folder_path_
+
std
::
string
(
kAnnotationsFolder
)
+
image_id
+
std
::
string
(
kAnnotationExtension
);
RETURN_IF_NOT_OK
(
ReadImageToTensor
(
kImageFile
,
data_schema_
->
column
(
0
),
&
image
));
RETURN_IF_NOT_OK
(
ReadAnnotationToTensor
(
kAnnotationFile
,
data_schema_
->
column
(
1
),
&
annotation
));
(
*
trow
)
=
{
std
::
move
(
image
),
std
::
move
(
annotation
)}
;
(
*
trow
)
=
TensorRow
(
row_id
,
{
std
::
move
(
image
),
std
::
move
(
annotation
)})
;
}
return
Status
::
OK
();
}
...
...
@@ -210,7 +210,7 @@ Status VOCOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataB
std
::
unique_ptr
<
TensorQTable
>
deq
=
std
::
make_unique
<
TensorQTable
>
();
TensorRow
trow
;
for
(
const
uint64_t
&
key
:
keys
)
{
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
image_ids_
[
key
],
&
trow
));
RETURN_IF_NOT_OK
(
this
->
LoadTensorRow
(
key
,
image_ids_
[
key
],
&
trow
));
deq
->
push_back
(
std
::
move
(
trow
));
}
(
*
db
)
->
set_tensor_table
(
std
::
move
(
deq
));
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
浏览文件 @
6ae88c39
...
...
@@ -215,10 +215,11 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
Status
InitSampler
();
// Load a tensor row according to image id
// @param row_id_type row_id - id for this tensor row
// @param std::string image_id - image id
// @param TensorRow row - image & target read into this tensor row
// @return Status - The error code return
Status
LoadTensorRow
(
const
std
::
string
&
image_id
,
TensorRow
*
row
);
Status
LoadTensorRow
(
row_id_type
row_id
,
const
std
::
string
&
image_id
,
TensorRow
*
row
);
// @param const std::string &path - path to the image file
// @param const ColDescriptor &col - contains tensor implementation and datatype
...
...
mindspore/ccsrc/dataset/engine/gnn/graph.h
浏览文件 @
6ae88c39
...
...
@@ -24,6 +24,7 @@
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_row.h"
#include "dataset/engine/gnn/graph_loader.h"
#include "dataset/engine/gnn/feature.h"
#include "dataset/engine/gnn/node.h"
...
...
mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h
浏览文件 @
6ae88c39
...
...
@@ -26,6 +26,7 @@
#include "dataset/core/data_type.h"
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_row.h"
namespace
mindspore
{
namespace
dataset
{
...
...
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
浏览文件 @
6ae88c39
...
...
@@ -27,8 +27,7 @@ UniformAugOp::UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32
}
// compute method to apply uniformly random selected augmentations from a list
Status
UniformAugOp
::
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
{
Status
UniformAugOp
::
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
{
IO_CHECK_VECTOR
(
input
,
output
);
// randomly select ops to be applied
...
...
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h
浏览文件 @
6ae88c39
...
...
@@ -44,8 +44,7 @@ class UniformAugOp : public TensorOp {
// Overrides the base class compute function
// @return Status - The error code return
Status
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
override
;
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
override
;
private:
int32_t
num_ops_
;
...
...
mindspore/ccsrc/dataset/kernels/py_func_op.cc
浏览文件 @
6ae88c39
...
...
@@ -24,8 +24,7 @@
namespace
mindspore
{
namespace
dataset
{
Status
PyFuncOp
::
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
{
Status
PyFuncOp
::
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
{
IO_CHECK_VECTOR
(
input
,
output
);
Status
ret
=
Status
(
StatusCode
::
kOK
,
"PyFunc Call Succeed"
);
{
...
...
mindspore/ccsrc/dataset/kernels/py_func_op.h
浏览文件 @
6ae88c39
...
...
@@ -36,8 +36,7 @@ class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp {
uint32_t
NumOutput
()
override
{
return
0
;
}
// Compute function for n-n mapping.
Status
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
override
;
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
override
;
private:
py
::
function
py_func_ptr_
;
...
...
mindspore/ccsrc/dataset/kernels/tensor_op.cc
浏览文件 @
6ae88c39
...
...
@@ -37,8 +37,7 @@ Status TensorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
// Name: Compute()
// Description: This Compute() take multiple Tensors from different columns and produce multiple Tensors too.
// The derived class should override this function otherwise error.
Status
TensorOp
::
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
{
Status
TensorOp
::
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
{
IO_CHECK_VECTOR
(
input
,
output
);
if
(
OneToOne
())
{
output
->
resize
(
1
);
...
...
mindspore/ccsrc/dataset/kernels/tensor_op.h
浏览文件 @
6ae88c39
...
...
@@ -21,6 +21,7 @@
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/core/tensor_row.h"
#include "dataset/util/status.h"
#define IO_CHECK(input, output) \
...
...
@@ -75,8 +76,7 @@ class TensorOp {
// @param input is a vector of shared_ptr to Tensor (pass by const reference).
// @param output is the address to an empty vector of shared_ptr to Tensor.
// @return Status
virtual
Status
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
);
virtual
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
);
// Returns true oif the TensorOp takes one input and returns one output.
// @return true/false
...
...
tests/ut/cpp/dataset/map_op_test.cc
浏览文件 @
6ae88c39
...
...
@@ -55,8 +55,7 @@ class ThreeToOneOp : public TensorOp {
uint32_t
NumInput
()
override
{
return
3
;
}
// Compute function that holds the actual implementation of the operation.
Status
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
override
{
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
override
{
output
->
push_back
(
input
[
0
]);
return
Status
::
OK
();
};
...
...
@@ -74,8 +73,7 @@ class OneToThreeOp : public TensorOp {
// Compute function that holds the actual implementation of the operation.
// Simply pushing the same shared pointer of the first element of input vector three times.
Status
Compute
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
&
input
,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
override
{
Status
Compute
(
const
TensorRow
&
input
,
TensorRow
*
output
)
override
{
output
->
push_back
(
input
[
0
]);
output
->
push_back
(
input
[
0
]);
output
->
push_back
(
input
[
0
]);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录