Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ea829e89
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ea829e89
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2198 Add MaskOp to dataset
Merge pull request !2198 from h.farahat/mask_op
上级
e94ad524
f2462bb0
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
560 addition
and
25 deletion
+560
-25
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+15
-1
mindspore/ccsrc/dataset/core/tensor.cc
mindspore/ccsrc/dataset/core/tensor.cc
+6
-6
mindspore/ccsrc/dataset/core/tensor.h
mindspore/ccsrc/dataset/core/tensor.h
+1
-1
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
+7
-6
mindspore/ccsrc/dataset/kernels/data/data_utils.cc
mindspore/ccsrc/dataset/kernels/data/data_utils.cc
+100
-1
mindspore/ccsrc/dataset/kernels/data/data_utils.h
mindspore/ccsrc/dataset/kernels/data/data_utils.h
+29
-0
mindspore/ccsrc/dataset/kernels/data/mask_op.cc
mindspore/ccsrc/dataset/kernels/data/mask_op.cc
+49
-0
mindspore/ccsrc/dataset/kernels/data/mask_op.h
mindspore/ccsrc/dataset/kernels/data/mask_op.h
+54
-0
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
+1
-2
mindspore/ccsrc/dataset/kernels/data/slice_op.h
mindspore/ccsrc/dataset/kernels/data/slice_op.h
+3
-3
mindspore/dataset/transforms/c_transforms.py
mindspore/dataset/transforms/c_transforms.py
+53
-3
mindspore/dataset/transforms/validators.py
mindspore/dataset/transforms/validators.py
+37
-0
tests/ut/cpp/dataset/mask_test.cc
tests/ut/cpp/dataset/mask_test.cc
+63
-0
tests/ut/python/dataset/test_mask_op.py
tests/ut/python/dataset/test_mask_op.py
+132
-0
tests/ut/python/dataset/test_slice_op.py
tests/ut/python/dataset/test_slice_op.py
+10
-2
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
ea829e89
...
...
@@ -38,6 +38,7 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/slice_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
...
...
@@ -383,7 +384,7 @@ void bindTensorOps2(py::module *m) {
*
m
,
"FillOp"
,
"Tensor operation to return tensor filled with same value as input fill value."
)
.
def
(
py
::
init
<
std
::
shared_ptr
<
Tensor
>>
());
(
void
)
py
::
class_
<
SliceOp
,
TensorOp
,
std
::
shared_ptr
<
SliceOp
>>
(
*
m
,
"SliceOp"
,
""
)
(
void
)
py
::
class_
<
SliceOp
,
TensorOp
,
std
::
shared_ptr
<
SliceOp
>>
(
*
m
,
"SliceOp"
,
"
Tensor Slice operation.
"
)
.
def
(
py
::
init
<
bool
>
())
.
def
(
py
::
init
([](
const
py
::
list
&
py_list
)
{
std
::
vector
<
dsize_t
>
c_list
;
...
...
@@ -414,6 +415,19 @@ void bindTensorOps2(py::module *m) {
return
std
::
make_shared
<
SliceOp
>
(
c_slice
);
}));
(
void
)
py
::
enum_
<
RelationalOp
>
(
*
m
,
"RelationalOp"
,
py
::
arithmetic
())
.
value
(
"EQ"
,
RelationalOp
::
kEqual
)
.
value
(
"NE"
,
RelationalOp
::
kNotEqual
)
.
value
(
"LT"
,
RelationalOp
::
kLess
)
.
value
(
"LE"
,
RelationalOp
::
kLessEqual
)
.
value
(
"GT"
,
RelationalOp
::
kGreater
)
.
value
(
"GE"
,
RelationalOp
::
kGreaterEqual
)
.
export_values
();
(
void
)
py
::
class_
<
MaskOp
,
TensorOp
,
std
::
shared_ptr
<
MaskOp
>>
(
*
m
,
"MaskOp"
,
"Tensor operation mask using relational comparator"
)
.
def
(
py
::
init
<
RelationalOp
,
std
::
shared_ptr
<
Tensor
>
,
DataType
>
());
(
void
)
py
::
class_
<
RandomRotationOp
,
TensorOp
,
std
::
shared_ptr
<
RandomRotationOp
>>
(
*
m
,
"RandomRotationOp"
,
"Tensor operation to apply RandomRotation."
...
...
mindspore/ccsrc/dataset/core/tensor.cc
浏览文件 @
ea829e89
...
...
@@ -699,7 +699,7 @@ Status Tensor::GetItemAt(T *o, const std::vector<dsize_t> &index) const {
Status
Tensor
::
GetItemAt
(
std
::
string_view
*
o
,
const
std
::
vector
<
dsize_t
>
&
index
)
const
{
RETURN_UNEXPECTED_IF_NULL
(
data_
);
RETURN_UNEXPECTED_IF_NULL
(
o
);
CHECK_FAIL_RETURN_UNEXPECTED
(
type_
==
DataType
::
DE_STRING
,
"T
ype is not DE_STRING
"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
type_
==
DataType
::
DE_STRING
,
"T
ensor type is not a string
"
);
uchar
*
start
=
nullptr
;
offset_t
length
=
0
;
...
...
@@ -932,17 +932,17 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
dsize_t
out_index
=
0
;
dsize_t
dim_length
=
shape_
[
0
];
dsize_t
type_size
=
type_
.
SizeInBytes
();
dsize_t
src_start
=
h
andleNeg
(
indices
[
0
],
dim_length
);
dsize_t
src_start
=
H
andleNeg
(
indices
[
0
],
dim_length
);
uchar
*
dst_addr
=
(
*
out
)
->
data_
;
dsize_t
count
=
1
;
for
(
dsize_t
i
=
0
;
i
<
indices
.
size
();
i
++
)
{
dsize_t
cur_index
=
h
andleNeg
(
indices
[
i
],
dim_length
);
dsize_t
cur_index
=
H
andleNeg
(
indices
[
i
],
dim_length
);
CHECK_FAIL_RETURN_UNEXPECTED
(
cur_index
>=
0
&&
cur_index
<
dim_length
,
"Index "
+
std
::
to_string
(
indices
[
i
])
+
" is out of bounds [0,"
+
std
::
to_string
(
dim_length
)
+
")"
);
if
(
i
<
indices
.
size
()
-
1
)
{
dsize_t
next_index
=
h
andleNeg
(
indices
[
i
+
1
],
dim_length
);
dsize_t
next_index
=
H
andleNeg
(
indices
[
i
+
1
],
dim_length
);
if
(
next_index
==
cur_index
+
1
)
{
count
++
;
continue
;
...
...
@@ -951,7 +951,7 @@ Status Tensor::SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsiz
memcpy_s
(
dst_addr
+
out_index
*
type_size
,
(
*
out
)
->
SizeInBytes
(),
data_
+
src_start
*
type_size
,
count
*
type_size
);
out_index
+=
count
;
if
(
i
<
indices
.
size
()
-
1
)
{
src_start
=
h
andleNeg
(
indices
[
i
+
1
],
dim_length
);
// next index
src_start
=
H
andleNeg
(
indices
[
i
+
1
],
dim_length
);
// next index
}
count
=
1
;
}
...
...
@@ -961,7 +961,7 @@ Status Tensor::SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize
dsize_t
dim_length
=
shape_
[
0
];
std
::
vector
<
std
::
string
>
strings
;
for
(
dsize_t
index
:
indices
)
{
dsize_t
cur_index
=
h
andleNeg
(
index
,
dim_length
);
dsize_t
cur_index
=
H
andleNeg
(
index
,
dim_length
);
CHECK_FAIL_RETURN_UNEXPECTED
(
cur_index
>=
0
&&
cur_index
<
dim_length
,
"Index "
+
std
::
to_string
(
index
)
+
" is out of bounds [0,"
+
std
::
to_string
(
dim_length
)
+
")"
);
...
...
mindspore/ccsrc/dataset/core/tensor.h
浏览文件 @
ea829e89
...
...
@@ -348,7 +348,7 @@ class Tensor {
}
// Handle negative indices.
static
inline
dsize_t
h
andleNeg
(
dsize_t
index
,
dsize_t
length
)
{
return
(
index
<
0
)
?
(
index
+
length
)
:
index
;
}
static
inline
dsize_t
H
andleNeg
(
dsize_t
index
,
dsize_t
length
)
{
return
(
index
<
0
)
?
(
index
+
length
)
:
index
;
}
// Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported.
// Based on the type of tensor, SliceNumeric or SliceString will be called
...
...
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
浏览文件 @
ea829e89
file
(
GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
"*.cc"
)
set_property
(
SOURCE
${
_CURRENT_SRC_FILES
}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD
)
add_library
(
kernels-data OBJECT
data_utils.cc
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
slice_op.cc
)
data_utils.cc
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
fill_op.cc
slice_op.cc
mask_op.cc
)
mindspore/ccsrc/dataset/kernels/data/data_utils.cc
浏览文件 @
ea829e89
...
...
@@ -120,7 +120,7 @@ Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output
std
::
unique_ptr
<
TypeCastOp
>
op
(
new
TypeCastOp
(
to
));
std
::
shared_ptr
<
Tensor
>
fill_output
;
op
->
Compute
(
fill_value
,
&
fill_output
);
RETURN_IF_NOT_OK
(
op
->
Compute
(
fill_value
,
&
fill_output
)
);
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
out
,
TensorImpl
::
kFlexible
,
input
->
shape
(),
input
->
type
()));
...
...
@@ -344,6 +344,8 @@ Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
return
PadEndString
(
src
,
dst
,
pad_shape
,
""
);
}
}
CHECK_FAIL_RETURN_UNEXPECTED
(
src
->
type
().
IsNumeric
()
==
pad_val
->
type
().
IsNumeric
(),
"Source and pad_value tensors are not of the same type."
);
if
(
pad_val
->
type
().
IsNumeric
())
{
float
val
=
0
;
RETURN_IF_NOT_OK
(
pad_val
->
GetItemAt
<
float
>
(
&
val
,
{}));
...
...
@@ -454,5 +456,102 @@ Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::s
}
return
Status
::
OK
();
}
template
<
typename
T
>
Status
MaskHelper
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
const
std
::
shared_ptr
<
Tensor
>
&
output
,
const
std
::
shared_ptr
<
Tensor
>
&
value_tensor
,
RelationalOp
op
)
{
T
value
;
RETURN_IF_NOT_OK
(
value_tensor
->
GetItemAt
(
&
value
,
{}));
auto
in_itr
=
input
->
begin
<
T
>
();
auto
out_itr
=
output
->
begin
<
bool
>
();
for
(;
in_itr
!=
input
->
end
<
T
>
();
in_itr
++
,
out_itr
++
)
{
switch
(
op
)
{
case
RelationalOp
::
kEqual
:
*
out_itr
=
(
*
in_itr
==
value
);
break
;
case
RelationalOp
::
kNotEqual
:
*
out_itr
=
(
*
in_itr
!=
value
);
break
;
case
RelationalOp
::
kGreater
:
*
out_itr
=
(
*
in_itr
>
value
);
break
;
case
RelationalOp
::
kGreaterEqual
:
*
out_itr
=
(
*
in_itr
>=
value
);
break
;
case
RelationalOp
::
kLess
:
*
out_itr
=
(
*
in_itr
<
value
);
break
;
case
RelationalOp
::
kLessEqual
:
*
out_itr
=
(
*
in_itr
<=
value
);
break
;
default:
RETURN_STATUS_UNEXPECTED
(
"Unknown relational operator."
);
}
}
return
Status
::
OK
();
}
Status
Mask
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
const
std
::
shared_ptr
<
Tensor
>
&
value
,
RelationalOp
op
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
input
->
type
().
IsNumeric
()
==
value
->
type
().
IsNumeric
(),
"Cannot convert constant value to the type of the input tensor."
);
CHECK_FAIL_RETURN_UNEXPECTED
(
value
->
shape
()
==
TensorShape
::
CreateScalar
(),
"Value is not a scalar"
);
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
output
,
TensorImpl
::
kFlexible
,
input
->
shape
(),
DataType
(
DataType
::
DE_BOOL
)));
std
::
unique_ptr
<
TypeCastOp
>
value_cast_op
(
new
TypeCastOp
(
input
->
type
()));
std
::
shared_ptr
<
Tensor
>
casted_value
;
if
(
input
->
type
().
IsNumeric
())
{
RETURN_IF_NOT_OK
(
value_cast_op
->
Compute
(
value
,
&
casted_value
));
}
else
{
casted_value
=
value
;
}
switch
(
input
->
type
().
value
())
{
case
DataType
::
DE_BOOL
:
RETURN_IF_NOT_OK
(
MaskHelper
<
bool
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_INT8
:
RETURN_IF_NOT_OK
(
MaskHelper
<
int8_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_UINT8
:
RETURN_IF_NOT_OK
(
MaskHelper
<
uint8_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_UINT16
:
RETURN_IF_NOT_OK
(
MaskHelper
<
uint16_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_INT16
:
RETURN_IF_NOT_OK
(
MaskHelper
<
int16_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_UINT32
:
RETURN_IF_NOT_OK
(
MaskHelper
<
uint32_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_INT32
:
RETURN_IF_NOT_OK
(
MaskHelper
<
int32_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_UINT64
:
RETURN_IF_NOT_OK
(
MaskHelper
<
uint64_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_INT64
:
RETURN_IF_NOT_OK
(
MaskHelper
<
int64_t
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_FLOAT16
:
RETURN_IF_NOT_OK
(
MaskHelper
<
float16
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_FLOAT32
:
RETURN_IF_NOT_OK
(
MaskHelper
<
float
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_FLOAT64
:
RETURN_IF_NOT_OK
(
MaskHelper
<
double
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_STRING
:
RETURN_IF_NOT_OK
(
MaskHelper
<
std
::
string_view
>
(
input
,
*
output
,
casted_value
,
op
));
break
;
case
DataType
::
DE_UNKNOWN
:
RETURN_STATUS_UNEXPECTED
(
"Unsupported input type."
);
break
;
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/kernels/data/data_utils.h
浏览文件 @
ea829e89
...
...
@@ -119,6 +119,35 @@ Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor>
Status
PadEndStringHelper
(
const
std
::
shared_ptr
<
Tensor
>
&
src
,
std
::
vector
<
std
::
string
>
*
dst
,
const
TensorShape
&
dst_shape
,
std
::
vector
<
dsize_t
>
cur_ind
,
size_t
cur_dim
,
const
std
::
string
&
pad_value
);
enum
class
RelationalOp
{
kEqual
=
0
,
// ==
kNotEqual
,
// !=
kLess
,
// <
kLessEqual
,
// <=
kGreater
,
// >
kGreaterEqual
,
// >=
};
/// Helper method that masks the input tensor
/// @tparam T type of the tensor
/// @param input[in] input tensor
/// @param output[out] output tensor
/// @param value_tensor[in] scalar tensor value to compared with
/// @param op[in] RelationalOp enum
/// @return Status ok/error
template
<
typename
T
>
Status
MaskHelper
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
const
std
::
shared_ptr
<
Tensor
>
&
output
,
const
std
::
shared_ptr
<
Tensor
>
&
value_tensor
,
RelationalOp
op
);
/// Mask the input tensor
/// @param input[in] input tensor
/// @param output[out] output tensor
/// @param value[in] scalar tensor value to compared with
/// @param op[in] RelationalOp enum
/// @return Status ok/error
Status
Mask
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
const
std
::
shared_ptr
<
Tensor
>
&
value
,
RelationalOp
op
);
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/dataset/kernels/data/mask_op.cc
0 → 100644
浏览文件 @
ea829e89
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/mask_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace
mindspore
{
namespace
dataset
{
Status
MaskOp
::
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
IO_CHECK
(
input
,
output
);
std
::
shared_ptr
<
Tensor
>
temp_output
;
CHECK_FAIL_RETURN_UNEXPECTED
(
type_
.
IsNumeric
(),
"Cannot generate a string mask. Type should be numeric."
);
RETURN_IF_NOT_OK
(
Mask
(
input
,
&
temp_output
,
value_
,
op_
));
// cast the output to the the required type. Skip casting if type_ is bool.
if
(
type_
!=
DataType
::
DE_BOOL
)
{
RETURN_IF_NOT_OK
(
cast_
->
Compute
(
temp_output
,
output
));
}
else
{
*
output
=
temp_output
;
}
return
Status
::
OK
();
}
Status
MaskOp
::
OutputType
(
const
std
::
vector
<
DataType
>
&
inputs
,
std
::
vector
<
DataType
>
&
outputs
)
{
RETURN_IF_NOT_OK
(
TensorOp
::
OutputType
(
inputs
,
outputs
));
outputs
[
0
]
=
type_
;
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/kernels/data/mask_op.h
0 → 100644
浏览文件 @
ea829e89
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_MASK_OP_H_
#define DATASET_KERNELS_DATA_MASK_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/data_utils.h"
namespace
mindspore
{
namespace
dataset
{
class
MaskOp
:
public
TensorOp
{
public:
MaskOp
(
RelationalOp
op
,
std
::
shared_ptr
<
Tensor
>
value
,
DataType
type
=
DataType
(
DataType
::
DE_BOOL
))
:
op_
(
op
),
value_
(
std
::
move
(
value
)),
type_
(
type
),
cast_
(
new
TypeCastOp
(
type
))
{}
~
MaskOp
()
override
=
default
;
void
Print
(
std
::
ostream
&
out
)
const
override
{
out
<<
"MaskOp"
;
}
Status
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
override
;
Status
OutputType
(
const
std
::
vector
<
DataType
>
&
inputs
,
std
::
vector
<
DataType
>
&
outputs
)
override
;
private:
RelationalOp
op_
;
std
::
shared_ptr
<
Tensor
>
value_
;
DataType
type_
;
std
::
unique_ptr
<
TypeCastOp
>
cast_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_KERNELS_DATA_MASK_OP_H_
mindspore/ccsrc/dataset/kernels/data/slice_op.cc
浏览文件 @
ea829e89
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
@@ -16,7 +16,6 @@
#include "dataset/kernels/data/slice_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace
mindspore
{
...
...
mindspore/ccsrc/dataset/kernels/data/slice_op.h
浏览文件 @
ea829e89
...
...
@@ -36,8 +36,8 @@ class Slice {
std
::
vector
<
dsize_t
>
Indices
(
dsize_t
length
)
{
std
::
vector
<
dsize_t
>
indices
;
dsize_t
index
=
std
::
min
(
Tensor
::
h
andleNeg
(
start_
,
length
),
length
);
dsize_t
end_index
=
std
::
min
(
Tensor
::
h
andleNeg
(
stop_
,
length
),
length
);
dsize_t
index
=
std
::
min
(
Tensor
::
H
andleNeg
(
start_
,
length
),
length
);
dsize_t
end_index
=
std
::
min
(
Tensor
::
H
andleNeg
(
stop_
,
length
),
length
);
if
(
step_
>
0
)
{
for
(;
index
<
end_index
;
index
+=
step_
)
{
indices
.
push_back
(
index
);
...
...
@@ -80,4 +80,4 @@ class SliceOp : public TensorOp {
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_KERNELS_DATA_
ONE_HOT
_OP_H_
#endif // DATASET_KERNELS_DATA_
SLICE
_OP_H_
mindspore/dataset/transforms/c_transforms.py
浏览文件 @
ea829e89
...
...
@@ -15,10 +15,14 @@
"""
This module c_transforms provides common operations, including OneHotOp and TypeCast.
"""
import
numpy
as
np
from
enum
import
IntEnum
import
mindspore.common.dtype
as
mstype
import
mindspore._c_dataengine
as
cde
from
.validators
import
check_num_classes
,
check_de_type
,
check_fill_value
,
check_slice_op
import
numpy
as
np
from
.validators
import
check_num_classes
,
check_de_type
,
check_fill_value
,
check_slice_op
,
check_mask_op
from
..core.datatypes
import
mstype_to_detype
...
...
@@ -48,7 +52,6 @@ class Fill(cde.FillOp):
@
check_fill_value
def
__init__
(
self
,
fill_value
):
print
(
fill_value
)
super
().
__init__
(
cde
.
Tensor
(
np
.
array
(
fill_value
)))
...
...
@@ -108,3 +111,50 @@ class Slice(cde.SliceOp):
elif
dim0
is
Ellipsis
:
dim0
=
True
super
().
__init__
(
dim0
)
class
Relational
(
IntEnum
):
EQ
=
0
NE
=
1
GT
=
2
GE
=
3
LT
=
4
LE
=
5
DE_C_RELATIONAL
=
{
Relational
.
EQ
:
cde
.
RelationalOp
.
EQ
,
Relational
.
NE
:
cde
.
RelationalOp
.
NE
,
Relational
.
GT
:
cde
.
RelationalOp
.
GT
,
Relational
.
GE
:
cde
.
RelationalOp
.
GE
,
Relational
.
LT
:
cde
.
RelationalOp
.
LT
,
Relational
.
LE
:
cde
.
RelationalOp
.
LE
}
class
Mask
(
cde
.
MaskOp
):
"""
Mask content of the input tensor with the given predicate.
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
Args:
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
constant (python types (str, int, float, or bool): constant to be compared to.
Constant will be casted to the type of the input tensor
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
Examples:
>>> # Data before
>>> # | col1 |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------+
>>> data = data.map(operations=Mask(Relational.EQ, 2))
>>> # Data after
>>> # | col1 |
>>> # +--------------------+
>>> # | [False,True,False] |
>>> # +--------------------+
"""
@
check_mask_op
def
__init__
(
self
,
operator
,
constant
,
dtype
=
mstype
.
bool_
):
dtype
=
mstype_to_detype
(
dtype
)
constant
=
cde
.
Tensor
(
np
.
array
(
constant
))
super
().
__init__
(
DE_C_RELATIONAL
[
operator
],
constant
,
dtype
)
mindspore/dataset/transforms/validators.py
浏览文件 @
ea829e89
...
...
@@ -213,3 +213,40 @@ def check_slice_op(method):
return
method
(
self
,
*
args
)
return
new_method
def
check_mask_op
(
method
):
"""Wrapper method to check the parameters of slice."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
operator
,
constant
,
dtype
=
(
list
(
args
)
+
3
*
[
None
])[:
3
]
if
"operator"
in
kwargs
:
operator
=
kwargs
.
get
(
"operator"
)
if
"constant"
in
kwargs
:
constant
=
kwargs
.
get
(
"constant"
)
if
"dtype"
in
kwargs
:
dtype
=
kwargs
.
get
(
"dtype"
)
if
operator
is
None
:
raise
ValueError
(
"operator is not provided."
)
if
constant
is
None
:
raise
ValueError
(
"constant is not provided."
)
from
.c_transforms
import
Relational
if
not
isinstance
(
operator
,
Relational
):
raise
TypeError
(
"operator is not a Relational operator enum."
)
if
not
isinstance
(
constant
,
(
str
,
float
,
bool
,
int
)):
raise
TypeError
(
"constant must be either a primitive python str, float, bool, or int"
)
if
not
isinstance
(
dtype
,
typing
.
Type
):
raise
TypeError
(
"dtype is not a MindSpore data type."
)
kwargs
[
"operator"
]
=
operator
kwargs
[
"constant"
]
=
constant
kwargs
[
"dtype"
]
=
dtype
return
method
(
self
,
**
kwargs
)
return
new_method
tests/ut/cpp/dataset/mask_test.cc
0 → 100644
浏览文件 @
ea829e89
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "securec.h"
#include "dataset/core/tensor.h"
#include "dataset/core/cv_tensor.h"
#include "dataset/core/data_type.h"
#include "dataset/util/de_error.h"
#include "dataset/kernels/data/mask_op.h"
#include "dataset/kernels/data/data_utils.h"
using
namespace
mindspore
::
dataset
;
namespace
py
=
pybind11
;
class
MindDataTestMaskOp
:
public
UT
::
Common
{
public:
MindDataTestMaskOp
()
{}
void
SetUp
()
{
GlobalInit
();
}
};
TEST_F
(
MindDataTestMaskOp
,
Basics
)
{
std
::
shared_ptr
<
Tensor
>
t
;
Tensor
::
CreateTensor
(
&
t
,
std
::
vector
<
uint32_t
>
({
1
,
2
,
3
,
4
,
5
,
6
}));
std
::
shared_ptr
<
Tensor
>
v
;
Tensor
::
CreateTensor
(
&
v
,
std
::
vector
<
uint32_t
>
({
3
}),
TensorShape
::
CreateScalar
());
std
::
shared_ptr
<
MaskOp
>
op
=
std
::
make_shared
<
MaskOp
>
(
RelationalOp
::
kEqual
,
v
,
DataType
(
DataType
::
DE_UINT16
));
std
::
shared_ptr
<
Tensor
>
out
;
ASSERT_TRUE
(
op
->
Compute
(
t
,
&
out
).
IsOk
());
op
=
std
::
make_shared
<
MaskOp
>
(
RelationalOp
::
kNotEqual
,
v
,
DataType
(
DataType
::
DE_UINT16
));
ASSERT_TRUE
(
op
->
Compute
(
t
,
&
out
).
IsOk
());
op
=
std
::
make_shared
<
MaskOp
>
(
RelationalOp
::
kLessEqual
,
v
,
DataType
(
DataType
::
DE_UINT16
));
ASSERT_TRUE
(
op
->
Compute
(
t
,
&
out
).
IsOk
());
op
=
std
::
make_shared
<
MaskOp
>
(
RelationalOp
::
kLess
,
v
,
DataType
(
DataType
::
DE_UINT16
));
ASSERT_TRUE
(
op
->
Compute
(
t
,
&
out
).
IsOk
());
op
=
std
::
make_shared
<
MaskOp
>
(
RelationalOp
::
kGreaterEqual
,
v
,
DataType
(
DataType
::
DE_UINT16
));
ASSERT_TRUE
(
op
->
Compute
(
t
,
&
out
).
IsOk
());
op
=
std
::
make_shared
<
MaskOp
>
(
RelationalOp
::
kGreater
,
v
,
DataType
(
DataType
::
DE_UINT16
));
ASSERT_TRUE
(
op
->
Compute
(
t
,
&
out
).
IsOk
());
}
tests/ut/python/dataset/test_mask_op.py
0 → 100644
浏览文件 @
ea829e89
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Mask op in DE
"""
import
numpy
as
np
import
pytest
import
mindspore.common.dtype
as
mstype
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
ops
mstype_to_np_type
=
{
mstype
.
bool_
:
np
.
bool
,
mstype
.
int8
:
np
.
int8
,
mstype
.
uint8
:
np
.
uint8
,
mstype
.
int16
:
np
.
int16
,
mstype
.
uint16
:
np
.
uint16
,
mstype
.
int32
:
np
.
int32
,
mstype
.
uint32
:
np
.
uint32
,
mstype
.
int64
:
np
.
int64
,
mstype
.
uint64
:
np
.
uint64
,
mstype
.
float16
:
np
.
float16
,
mstype
.
float32
:
np
.
float32
,
mstype
.
float64
:
np
.
float64
,
mstype
.
string
:
np
.
str
}
def
mask_compare
(
array
,
op
,
constant
,
dtype
=
mstype
.
bool_
):
data
=
ds
.
NumpySlicesDataset
([
array
])
array
=
np
.
array
(
array
)
data
=
data
.
map
(
operations
=
ops
.
Mask
(
op
,
constant
,
dtype
))
for
d
in
data
:
if
op
==
ops
.
Relational
.
EQ
:
array
=
array
==
np
.
array
(
constant
,
dtype
=
array
.
dtype
)
elif
op
==
ops
.
Relational
.
NE
:
array
=
array
!=
np
.
array
(
constant
,
dtype
=
array
.
dtype
)
elif
op
==
ops
.
Relational
.
GT
:
array
=
array
>
np
.
array
(
constant
,
dtype
=
array
.
dtype
)
elif
op
==
ops
.
Relational
.
GE
:
array
=
array
>=
np
.
array
(
constant
,
dtype
=
array
.
dtype
)
elif
op
==
ops
.
Relational
.
LT
:
array
=
array
<
np
.
array
(
constant
,
dtype
=
array
.
dtype
)
elif
op
==
ops
.
Relational
.
LE
:
array
=
array
<=
np
.
array
(
constant
,
dtype
=
array
.
dtype
)
array
=
array
.
astype
(
dtype
=
mstype_to_np_type
[
dtype
])
np
.
testing
.
assert_array_equal
(
array
,
d
[
0
])
def
test_int_comparison
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
EQ
,
3
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
NE
,
3
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
LT
,
3
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
LE
,
3
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
GT
,
3
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
GE
,
3
,
k
)
def
test_float_comparison
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
EQ
,
3
,
k
)
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
NE
,
3
,
k
)
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
LT
,
3
,
k
)
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
LE
,
3
,
k
)
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
GT
,
3
,
k
)
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
GE
,
3
,
k
)
def
test_float_comparison2
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
EQ
,
3.5
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
NE
,
3.5
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
LT
,
3.5
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
LE
,
3.5
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
GT
,
3.5
,
k
)
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
GE
,
3.5
,
k
)
def
test_string_comparison
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
mask_compare
([
"1.5"
,
"2.5"
,
"3."
,
"4.5"
,
"5.5"
],
ops
.
Relational
.
EQ
,
"3."
,
k
)
mask_compare
([
"1.5"
,
"2.5"
,
"3."
,
"4.5"
,
"5.5"
],
ops
.
Relational
.
NE
,
"3."
,
k
)
mask_compare
([
"1.5"
,
"2.5"
,
"3."
,
"4.5"
,
"5.5"
],
ops
.
Relational
.
LT
,
"3."
,
k
)
mask_compare
([
"1.5"
,
"2.5"
,
"3."
,
"4.5"
,
"5.5"
],
ops
.
Relational
.
LE
,
"3."
,
k
)
mask_compare
([
"1.5"
,
"2.5"
,
"3."
,
"4.5"
,
"5.5"
],
ops
.
Relational
.
GT
,
"3."
,
k
)
mask_compare
([
"1.5"
,
"2.5"
,
"3."
,
"4.5"
,
"5.5"
],
ops
.
Relational
.
GE
,
"3."
,
k
)
def
test_mask_exceptions_str
():
with
pytest
.
raises
(
RuntimeError
)
as
info
:
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
EQ
,
"3.5"
)
assert
"Cannot convert constant value to the type of the input tensor."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
mask_compare
([
"1"
,
"2"
,
"3"
,
"4"
,
"5"
],
ops
.
Relational
.
EQ
,
3.5
)
assert
"Cannot convert constant value to the type of the input tensor."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
mask_compare
([
"1"
,
"2"
,
"3"
,
"4"
,
"5"
],
ops
.
Relational
.
EQ
,
"3.5"
,
mstype
.
string
)
assert
"Cannot generate a string mask. Type should be numeric."
in
str
(
info
.
value
)
if
__name__
==
"__main__"
:
test_int_comparison
()
test_float_comparison
()
test_float_comparison2
()
test_string_comparison
()
test_mask_exceptions_str
()
tests/ut/python/dataset/test_slice_op.py
浏览文件 @
ea829e89
# Copyright 20
19
Huawei Technologies Co., Ltd
# Copyright 20
20
Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
Testing
TypeCast
op in DE
Testing
Slice
op in DE
"""
import
numpy
as
np
import
pytest
...
...
@@ -109,6 +109,10 @@ def test_slice_exceptions():
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
0
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
3
,
1
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
1
,
2
,
3
,
4
,
5
],
slice
(
5
,
10
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
...
...
@@ -182,6 +186,10 @@ def test_slice_exceptions_str():
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
0
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
3
,
1
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
with
pytest
.
raises
(
RuntimeError
)
as
info
:
slice_compare
([
b
"1"
,
b
"2"
,
b
"3"
,
b
"4"
,
b
"5"
],
slice
(
5
,
10
,
1
))
assert
"Indices are empty, generated tensor would be empty."
in
str
(
info
.
value
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录