Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
dd9bf09f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dd9bf09f
编写于
6月 11, 2020
作者:
N
nhussain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added FillOp for #119 - special Ops
上级
2005ecc2
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
540 addition
and
30 deletion
+540
-30
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+5
-0
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
+1
-1
mindspore/ccsrc/dataset/kernels/data/data_utils.cc
mindspore/ccsrc/dataset/kernels/data/data_utils.cc
+112
-0
mindspore/ccsrc/dataset/kernels/data/data_utils.h
mindspore/ccsrc/dataset/kernels/data/data_utils.h
+7
-0
mindspore/ccsrc/dataset/kernels/data/fill_op.cc
mindspore/ccsrc/dataset/kernels/data/fill_op.cc
+31
-0
mindspore/ccsrc/dataset/kernels/data/fill_op.h
mindspore/ccsrc/dataset/kernels/data/fill_op.h
+47
-0
mindspore/dataset/transforms/c_transforms.py
mindspore/dataset/transforms/c_transforms.py
+18
-2
mindspore/dataset/transforms/validators.py
mindspore/dataset/transforms/validators.py
+19
-1
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/fill_op_test.cc
tests/ut/cpp/dataset/fill_op_test.cc
+183
-0
tests/ut/cpp/dataset/queue_test.cc
tests/ut/cpp/dataset/queue_test.cc
+21
-26
tests/ut/python/dataset/test_fill_op.py
tests/ut/python/dataset/test_fill_op.py
+95
-0
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
dd9bf09f
...
...
@@ -38,6 +38,7 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
...
...
@@ -350,6 +351,10 @@ void bindTensorOps2(py::module *m) {
*
m
,
"OneHotOp"
,
"Tensor operation to apply one hot encoding. Takes number of classes."
)
.
def
(
py
::
init
<
int32_t
>
());
(
void
)
py
::
class_
<
FillOp
,
TensorOp
,
std
::
shared_ptr
<
FillOp
>>
(
*
m
,
"FillOp"
,
"Tensor operation to return tensor filled with same value as input fill value."
)
.
def
(
py
::
init
<
std
::
shared_ptr
<
Tensor
>>
());
(
void
)
py
::
class_
<
RandomRotationOp
,
TensorOp
,
std
::
shared_ptr
<
RandomRotationOp
>>
(
*
m
,
"RandomRotationOp"
,
"Tensor operation to apply RandomRotation."
...
...
mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
浏览文件 @
dd9bf09f
...
...
@@ -5,4 +5,4 @@ add_library(kernels-data OBJECT
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
)
fill_op.cc
)
mindspore/ccsrc/dataset/kernels/data/data_utils.cc
浏览文件 @
dd9bf09f
...
...
@@ -23,6 +23,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/core/data_type.h"
#include "dataset/core/pybind_support.h"
#include "dataset/kernels/data/type_cast_op.h"
namespace
mindspore
{
namespace
dataset
{
...
...
@@ -78,6 +79,7 @@ Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_pt
Status
OneHotEncoding
(
std
::
shared_ptr
<
Tensor
>
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
dsize_t
num_classes
)
{
input
->
Squeeze
();
if
(
input
->
Rank
()
>
1
)
{
// We expect the input to be int he first dimension
RETURN_STATUS_UNEXPECTED
(
"One hot only supports scalars or 1D shape Tensors."
);
}
...
...
@@ -106,11 +108,121 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
}
}
Status
Fill
(
const
std
::
shared_ptr
<
Tensor
>
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
std
::
shared_ptr
<
Tensor
>
fill_value
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
!
((
fill_value
->
type
()
==
DataType
::
DE_STRING
)
&&
(
input
->
type
()
!=
DataType
::
DE_STRING
)),
"Types do not match"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
fill_value
->
shape
()
==
TensorShape
({}),
"fill_value is not a scalar"
);
std
::
shared_ptr
<
Tensor
>
out
;
const
DataType
&
to
=
input
->
type
();
std
::
unique_ptr
<
TypeCastOp
>
op
(
new
TypeCastOp
(
to
));
std
::
shared_ptr
<
Tensor
>
fill_output
;
op
->
Compute
(
fill_value
,
&
fill_output
);
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
out
,
TensorImpl
::
kFlexible
,
input
->
shape
(),
input
->
type
()));
switch
(
input
->
type
().
value
())
{
case
DataType
::
DE_BOOL
:
{
bool
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
bool
>
(
value
);
break
;
}
case
DataType
::
DE_INT8
:
{
int8_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
int8_t
>
(
value
);
break
;
}
case
DataType
::
DE_UINT8
:
{
uint8_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
uint8_t
>
(
value
);
break
;
}
case
DataType
::
DE_UINT16
:
{
uint16_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
uint16_t
>
(
value
);
break
;
}
case
DataType
::
DE_INT16
:
{
int16_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
int16_t
>
(
value
);
break
;
}
case
DataType
::
DE_UINT32
:
{
uint32_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
uint32_t
>
(
value
);
break
;
}
case
DataType
::
DE_INT32
:
{
int32_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
int32_t
>
(
value
);
break
;
}
case
DataType
::
DE_UINT64
:
{
uint64_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
uint64_t
>
(
value
);
break
;
}
case
DataType
::
DE_INT64
:
{
int64_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
int64_t
>
(
value
);
break
;
}
case
DataType
::
DE_FLOAT16
:
{
int64_t
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
float
>
(
value
);
break
;
}
case
DataType
::
DE_FLOAT32
:
{
float
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
float
>
(
value
);
break
;
}
case
DataType
::
DE_FLOAT64
:
{
double
value
=
0
;
RETURN_IF_NOT_OK
(
fill_output
->
GetItemAt
(
&
value
,
{}));
out
->
Fill
<
double
>
(
value
);
break
;
}
case
DataType
::
DE_STRING
:
{
std
::
vector
<
std
::
string
>
strings
;
std
::
string_view
fill_string_view
;
RETURN_IF_NOT_OK
(
fill_value
->
GetItemAt
(
&
fill_string_view
,
{}));
std
::
string
fill_string
=
std
::
string
(
fill_string_view
);
for
(
int
i
=
0
;
i
<
input
->
shape
().
NumOfElements
();
i
++
)
{
strings
.
emplace_back
(
fill_string
);
}
RETURN_IF_NOT_OK
(
Tensor
::
CreateTensor
(
&
out
,
strings
,
input
->
shape
()));
break
;
}
case
DataType
::
DE_UNKNOWN
:
{
RETURN_STATUS_UNEXPECTED
(
"FillOp does not support input of this type."
);
break
;
}
}
*
output
=
out
;
return
Status
::
OK
();
}
template
<
typename
FROM
,
typename
TO
>
void
Cast
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
auto
in_itr
=
input
->
begin
<
FROM
>
();
auto
out_itr
=
(
*
output
)
->
begin
<
TO
>
();
auto
out_end
=
(
*
output
)
->
end
<
TO
>
();
for
(;
out_itr
!=
out_end
;
static_cast
<
void
>
(
in_itr
++
),
static_cast
<
void
>
(
out_itr
++
))
*
out_itr
=
static_cast
<
TO
>
(
*
in_itr
);
}
...
...
mindspore/ccsrc/dataset/kernels/data/data_utils.h
浏览文件 @
dd9bf09f
...
...
@@ -43,6 +43,13 @@ Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_
Status
OneHotEncodingSigned
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
dsize_t
num_classes
,
int64_t
index
);
// Returns a tensor of shape input filled with the passed fill_value
// @param input Tensor
// @param output Tensor. The shape and type of the output tensor is same as input
// @param fill_value Tensor. A scalar tensor used to fill the output tensor
Status
Fill
(
const
std
::
shared_ptr
<
Tensor
>
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
std
::
shared_ptr
<
Tensor
>
fill_value
);
// Returns a type changed input tensor.
// Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp
// @param input Tensor
...
...
mindspore/ccsrc/dataset/kernels/data/fill_op.cc
0 → 100644
浏览文件 @
dd9bf09f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/fill_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace
mindspore
{
namespace
dataset
{
Status
FillOp
::
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
IO_CHECK
(
input
,
output
);
Status
s
=
Fill
(
input
,
output
,
fill_value_
);
return
s
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/kernels/data/fill_op.h
0 → 100644
浏览文件 @
dd9bf09f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_FILL_OP_H_
#define DATASET_KERNELS_DATA_FILL_OP_H_
#include <string>
#include <vector>
#include <memory>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace
mindspore
{
namespace
dataset
{
class
FillOp
:
public
TensorOp
{
public:
explicit
FillOp
(
std
::
shared_ptr
<
Tensor
>
value
)
:
fill_value_
(
value
)
{}
~
FillOp
()
override
=
default
;
void
Print
(
std
::
ostream
&
out
)
const
override
{
out
<<
"FillOp"
;
}
Status
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
override
;
private:
std
::
shared_ptr
<
Tensor
>
fill_value_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_FILL_OP_H
mindspore/dataset/transforms/c_transforms.py
浏览文件 @
dd9bf09f
...
...
@@ -15,9 +15,9 @@
"""
This module c_transforms provides common operations, including OneHotOp and TypeCast.
"""
import
numpy
as
np
import
mindspore._c_dataengine
as
cde
from
.validators
import
check_num_classes
,
check_de_type
from
.validators
import
check_num_classes
,
check_de_type
,
check_fill_value
from
..core.datatypes
import
mstype_to_detype
...
...
@@ -35,6 +35,22 @@ class OneHot(cde.OneHotOp):
super
().
__init__
(
num_classes
)
class
Fill
(
cde
.
FillOp
):
"""
Tensor operation to create a tensor filled with passed scalar value.
The output tensor will have the same shape and type as the input tensor.
Args:
fill_value (python types (str, int, float, or bool)) : scalar value
to fill created tensor with.
"""
@
check_fill_value
def
__init__
(
self
,
fill_value
):
print
(
fill_value
)
super
().
__init__
(
cde
.
Tensor
(
np
.
array
(
fill_value
)))
class
TypeCast
(
cde
.
TypeCastOp
):
"""
Tensor operation to cast to a given MindSpore data type.
...
...
mindspore/dataset/transforms/validators.py
浏览文件 @
dd9bf09f
...
...
@@ -17,7 +17,6 @@
from
functools
import
wraps
from
mindspore._c_expression
import
typing
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN
=
1
UINT8_MAX
=
255
...
...
@@ -159,6 +158,25 @@ def check_num_classes(method):
return
new_method
def
check_fill_value
(
method
):
"""Wrapper method to check the parameters of fill value."""
@
wraps
(
method
)
def
new_method
(
self
,
*
args
,
**
kwargs
):
fill_value
=
(
list
(
args
)
+
[
None
])[
0
]
if
"fill_value"
in
kwargs
:
fill_value
=
kwargs
.
get
(
"fill_value"
)
if
fill_value
is
None
:
raise
ValueError
(
"fill_value is not provided."
)
if
not
isinstance
(
fill_value
,
(
str
,
float
,
bool
,
int
)):
raise
TypeError
(
"fill_value must be either a primitive python str, float, bool, or int"
)
kwargs
[
"fill_value"
]
=
fill_value
return
method
(
self
,
**
kwargs
)
return
new_method
def
check_de_type
(
method
):
"""Wrapper method to check the parameters of data type."""
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
dd9bf09f
...
...
@@ -72,6 +72,7 @@ SET(DE_UT_SRCS
tokenizer_op_test.cc
gnn_graph_test.cc
coco_op_test.cc
fill_op_test.cc
)
add_executable
(
de_ut_tests
${
DE_UT_SRCS
}
)
...
...
tests/ut/cpp/dataset/fill_op_test.cc
0 → 100644
浏览文件 @
dd9bf09f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "dataset/kernels/data/fill_op.h"
#include "utils/log_adapter.h"
using
namespace
mindspore
::
dataset
;
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
INFO
;
class
MindDataTestFillOp
:
public
UT
::
Common
{
protected:
MindDataTestFillOp
()
{}
};
TEST_F
(
MindDataTestFillOp
,
TestOp
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestFillOp-TestOp."
;
uint64_t
labels
[
3
]
=
{
1
,
1
,
2
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
=
std
::
make_shared
<
Tensor
>
(
shape
,
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
labels
));
TensorShape
fill_shape
({});
std
::
shared_ptr
<
Tensor
>
fill_tensor
=
std
::
make_shared
<
Tensor
>
(
fill_shape
,
DataType
(
DataType
::
DE_UINT64
));
fill_tensor
->
SetItemAt
<
uint64_t
>
({},
4
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
FillOp
>
op
(
new
FillOp
(
fill_tensor
));
Status
s
=
op
->
Compute
(
input
,
&
output
);
uint64_t
out
[
3
]
=
{
4
,
4
,
4
};
std
::
shared_ptr
<
Tensor
>
expected
=
std
::
make_shared
<
Tensor
>
(
TensorShape
{
3
},
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
out
));
EXPECT_TRUE
(
s
.
IsOk
());
ASSERT_TRUE
(
output
->
shape
()
==
expected
->
shape
());
ASSERT_TRUE
(
output
->
type
()
==
expected
->
type
());
MS_LOG
(
DEBUG
)
<<
*
output
<<
std
::
endl
;
MS_LOG
(
DEBUG
)
<<
*
expected
<<
std
::
endl
;
ASSERT_TRUE
(
*
output
==
*
expected
);
MS_LOG
(
INFO
)
<<
"MindDataTestFillOp-TestOp end."
;
}
TEST_F
(
MindDataTestFillOp
,
TestCasting
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestFillOp-TestCasting."
;
uint64_t
labels
[
3
]
=
{
0
,
1
,
2
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
=
std
::
make_shared
<
Tensor
>
(
shape
,
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
labels
));
TensorShape
fill_shape
({});
std
::
shared_ptr
<
Tensor
>
fill_tensor
=
std
::
make_shared
<
Tensor
>
(
fill_shape
,
DataType
(
DataType
::
DE_FLOAT32
));
fill_tensor
->
SetItemAt
<
float
>
({},
2.0
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
FillOp
>
op
(
new
FillOp
(
fill_tensor
));
Status
s
=
op
->
Compute
(
input
,
&
output
);
uint64_t
out
[
3
]
=
{
2
,
2
,
2
};
std
::
shared_ptr
<
Tensor
>
expected
=
std
::
make_shared
<
Tensor
>
(
TensorShape
{
3
},
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
out
));
ASSERT_TRUE
(
output
->
shape
()
==
expected
->
shape
());
ASSERT_TRUE
(
output
->
type
()
==
expected
->
type
());
EXPECT_TRUE
(
s
.
IsOk
());
MS_LOG
(
DEBUG
)
<<
*
output
<<
std
::
endl
;
MS_LOG
(
DEBUG
)
<<
*
expected
<<
std
::
endl
;
ASSERT_TRUE
(
*
output
==
*
expected
);
MS_LOG
(
INFO
)
<<
"MindDataTestFillOp-TestCasting end."
;
}
TEST_F
(
MindDataTestFillOp
,
ScalarFill
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestFillOp-ScalarFill."
;
uint64_t
labels
[
3
]
=
{
0
,
1
,
2
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
=
std
::
make_shared
<
Tensor
>
(
shape
,
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
labels
));
TensorShape
fill_shape
({
2
});
uint64_t
fill_labels
[
3
]
=
{
0
,
1
};
std
::
shared_ptr
<
Tensor
>
fill_tensor
=
std
::
make_shared
<
Tensor
>
(
fill_shape
,
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
fill_labels
));
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
FillOp
>
op
(
new
FillOp
(
fill_tensor
));
Status
s
=
op
->
Compute
(
input
,
&
output
);
EXPECT_TRUE
(
s
.
IsError
());
ASSERT_TRUE
(
s
.
get_code
()
==
StatusCode
::
kUnexpectedError
);
MS_LOG
(
INFO
)
<<
"MindDataTestFillOp-ScalarFill end."
;
}
TEST_F
(
MindDataTestFillOp
,
StringFill
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestFillOp-StringFill."
;
std
::
vector
<
std
::
string
>
strings
=
{
"xyzzy"
,
"plugh"
,
"abracadabra"
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
=
std
::
make_shared
<
Tensor
>
(
strings
,
shape
);
TensorShape
fill_shape
({});
std
::
string
fill_string
=
"hello"
;
std
::
shared_ptr
<
Tensor
>
fill_tensor
=
std
::
make_shared
<
Tensor
>
(
fill_string
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
FillOp
>
op
(
new
FillOp
(
fill_tensor
));
Status
s
=
op
->
Compute
(
input
,
&
output
);
std
::
vector
<
std
::
string
>
expected_strings
=
{
"hello"
,
"hello"
,
"hello"
};
TensorShape
expected_shape
({
3
});
std
::
shared_ptr
<
Tensor
>
expected
=
std
::
make_shared
<
Tensor
>
(
expected_strings
,
expected_shape
);
EXPECT_TRUE
(
s
.
IsOk
());
ASSERT_TRUE
(
output
->
shape
()
==
expected
->
shape
());
ASSERT_TRUE
(
output
->
type
()
==
expected
->
type
());
MS_LOG
(
DEBUG
)
<<
*
output
<<
std
::
endl
;
MS_LOG
(
DEBUG
)
<<
*
expected
<<
std
::
endl
;
ASSERT_TRUE
(
*
output
==
*
expected
);
MS_LOG
(
INFO
)
<<
"MindDataTestFillOp-StringFill end."
;
}
TEST_F
(
MindDataTestFillOp
,
NumericToString
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestFillOp-NumericToString."
;
std
::
vector
<
std
::
string
>
strings
=
{
"xyzzy"
,
"plugh"
,
"abracadabra"
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
=
std
::
make_shared
<
Tensor
>
(
strings
,
shape
);
TensorShape
fill_shape
({});
std
::
shared_ptr
<
Tensor
>
fill_tensor
=
std
::
make_shared
<
Tensor
>
(
fill_shape
,
DataType
(
DataType
::
DE_FLOAT32
));
fill_tensor
->
SetItemAt
<
float
>
({},
2.0
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
FillOp
>
op
(
new
FillOp
(
fill_tensor
));
Status
s
=
op
->
Compute
(
input
,
&
output
);
EXPECT_TRUE
(
s
.
IsError
());
ASSERT_TRUE
(
s
.
get_code
()
==
StatusCode
::
kUnexpectedError
);
MS_LOG
(
INFO
)
<<
"MindDataTestFillOp-NumericToString end."
;
}
TEST_F
(
MindDataTestFillOp
,
StringToNumeric
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestFillOp-StringToNumeric."
;
uint64_t
labels
[
3
]
=
{
0
,
1
,
2
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
=
std
::
make_shared
<
Tensor
>
(
shape
,
DataType
(
DataType
::
DE_UINT64
),
reinterpret_cast
<
unsigned
char
*>
(
labels
));
TensorShape
fill_shape
({});
std
::
string
fill_string
=
"hello"
;
std
::
shared_ptr
<
Tensor
>
fill_tensor
=
std
::
make_shared
<
Tensor
>
(
fill_string
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
FillOp
>
op
(
new
FillOp
(
fill_tensor
));
Status
s
=
op
->
Compute
(
input
,
&
output
);
EXPECT_TRUE
(
s
.
IsError
());
ASSERT_TRUE
(
s
.
get_code
()
==
StatusCode
::
kUnexpectedError
);
MS_LOG
(
INFO
)
<<
"MindDataTestFillOp-StringToNumeric end."
;
}
\ No newline at end of file
tests/ut/cpp/dataset/queue_test.cc
浏览文件 @
dd9bf09f
...
...
@@ -13,9 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Created by jesse on 10/3/19.
//
#include "common/common.h"
#include "gtest/gtest.h"
...
...
@@ -25,32 +22,32 @@
#include "utils/log_adapter.h"
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
MsLogLevel
::
INFO
;
class
MindDataTestQueue
:
public
UT
::
Common
{
public:
MindDataTestQueue
()
{}
MindDataTestQueue
()
{}
void
SetUp
()
{}
void
SetUp
()
{}
};
int
gRefCountDestructorCalled
;
class
RefCount
{
public:
RefCount
()
:
v_
(
nullptr
)
{}
explicit
RefCount
(
int
x
)
:
v_
(
std
::
make_shared
<
int
>
(
x
))
{}
explicit
RefCount
(
const
RefCount
&
o
)
:
v_
(
o
.
v_
)
{}
~
RefCount
()
{
MS_LOG
(
DEBUG
)
<<
"Destructor of RefCount called"
<<
std
::
endl
;
gRefCountDestructorCalled
++
;
}
RefCount
&
operator
=
(
const
RefCount
&
o
)
{
v_
=
o
.
v_
;
return
*
this
;
}
RefCount
()
:
v_
(
nullptr
)
{}
explicit
RefCount
(
int
x
)
:
v_
(
std
::
make_shared
<
int
>
(
x
))
{}
explicit
RefCount
(
const
RefCount
&
o
)
:
v_
(
o
.
v_
)
{}
~
RefCount
()
{
MS_LOG
(
DEBUG
)
<<
"Destructor of RefCount called"
<<
std
::
endl
;
gRefCountDestructorCalled
++
;
}
RefCount
&
operator
=
(
const
RefCount
&
o
)
{
v_
=
o
.
v_
;
return
*
this
;
}
std
::
shared_ptr
<
int
>
v_
;
};
...
...
@@ -70,22 +67,22 @@ TEST_F(MindDataTestQueue, Test1) {
// Use count should remain 2. a and b. No copy in the queue.
ASSERT_EQ
(
a
.
use_count
(),
2
);
a
.
reset
(
new
int
(
5
));
ASSERT_EQ
(
a
.
use_count
(),
1
);
ASSERT_EQ
(
a
.
use_count
(),
1
);
// Push again but expect a is nullptr after push
rc
=
que
.
Add
(
std
::
move
(
a
));
ASSERT_TRUE
(
rc
.
IsOk
());
ASSERT_EQ
(
a
.
use_count
(),
0
);
ASSERT_EQ
(
a
.
use_count
(),
0
);
rc
=
que
.
PopFront
(
&
b
);
ASSERT_TRUE
(
rc
.
IsOk
());
ASSERT_EQ
(
*
b
,
5
);
ASSERT_EQ
(
b
.
use_count
(),
1
);
ASSERT_EQ
(
b
.
use_count
(),
1
);
// Test construct in place
rc
=
que
.
EmplaceBack
(
std
::
make_shared
<
int
>
(
100
));
ASSERT_TRUE
(
rc
.
IsOk
());
rc
=
que
.
PopFront
(
&
b
);
ASSERT_TRUE
(
rc
.
IsOk
());
ASSERT_EQ
(
*
b
,
100
);
ASSERT_EQ
(
b
.
use_count
(),
1
);
ASSERT_EQ
(
b
.
use_count
(),
1
);
// Test the destructor of the Queue by add an element in the queue without popping it and let the queue go
// out of scope.
rc
=
que
.
EmplaceBack
(
std
::
make_shared
<
int
>
(
2000
));
...
...
@@ -127,7 +124,7 @@ TEST_F(MindDataTestQueue, Test3) {
ASSERT_EQ
(
*
b
,
40
);
}
void
test4
(){
void
test4
()
{
gRefCountDestructorCalled
=
0
;
// Pass a structure along the queue.
Queue
<
RefCount
>
que
(
3
);
...
...
@@ -144,9 +141,7 @@ void test4(){
ASSERT_TRUE
(
rc
.
IsOk
());
}
TEST_F
(
MindDataTestQueue
,
Test4
)
{
test4
();
}
TEST_F
(
MindDataTestQueue
,
Test4
)
{
test4
();
}
TEST_F
(
MindDataTestQueue
,
Test5
)
{
test4
();
...
...
tests/ut/python/dataset/test_fill_op.py
0 → 100644
浏览文件 @
dd9bf09f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing fill op
"""
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.c_transforms
as
data_trans
def
test_fillop_basic
():
def
gen
():
yield
(
np
.
array
([
4
,
5
,
6
,
7
],
dtype
=
np
.
uint8
),)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
fill_op
=
data_trans
.
Fill
(
3
)
data
=
data
.
map
(
input_columns
=
[
"col"
],
operations
=
fill_op
)
expected
=
np
.
array
([
3
,
3
,
3
,
3
],
dtype
=
np
.
uint8
)
for
data_row
in
data
:
np
.
testing
.
assert_array_equal
(
data_row
[
0
],
expected
)
def
test_fillop_down_type_cast
():
def
gen
():
yield
(
np
.
array
([
4
,
5
,
6
,
7
],
dtype
=
np
.
uint8
),)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
fill_op
=
data_trans
.
Fill
(
-
3
)
data
=
data
.
map
(
input_columns
=
[
"col"
],
operations
=
fill_op
)
expected
=
np
.
array
([
253
,
253
,
253
,
253
],
dtype
=
np
.
uint8
)
for
data_row
in
data
:
np
.
testing
.
assert_array_equal
(
data_row
[
0
],
expected
)
def
test_fillop_up_type_cast
():
def
gen
():
yield
(
np
.
array
([
4
,
5
,
6
,
7
],
dtype
=
np
.
float
),)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
fill_op
=
data_trans
.
Fill
(
3
)
data
=
data
.
map
(
input_columns
=
[
"col"
],
operations
=
fill_op
)
expected
=
np
.
array
([
3.
,
3.
,
3.
,
3.
],
dtype
=
np
.
float
)
for
data_row
in
data
:
np
.
testing
.
assert_array_equal
(
data_row
[
0
],
expected
)
def
test_fillop_string
():
def
gen
():
yield
(
np
.
array
([
"45555"
,
"45555"
],
dtype
=
'S'
),)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
fill_op
=
data_trans
.
Fill
(
"error"
)
data
=
data
.
map
(
input_columns
=
[
"col"
],
operations
=
fill_op
)
expected
=
np
.
array
([
'error'
,
'error'
],
dtype
=
'S'
)
for
data_row
in
data
:
np
.
testing
.
assert_array_equal
(
data_row
[
0
],
expected
)
def
test_fillop_error_handling
():
def
gen
():
yield
(
np
.
array
([
4
,
4
,
4
,
4
]),)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
fill_op
=
data_trans
.
Fill
(
"words"
)
data
=
data
.
map
(
input_columns
=
[
"col"
],
operations
=
fill_op
)
with
pytest
.
raises
(
RuntimeError
)
as
error_info
:
for
data_row
in
data
:
print
(
data_row
)
assert
"Types do not match"
in
repr
(
error_info
.
value
)
if
__name__
==
"__main__"
:
test_fillop_basic
()
test_fillop_up_type_cast
()
test_fillop_down_type_cast
()
test_fillop_string
()
test_fillop_error_handling
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录