Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d137cefa
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d137cefa
编写于
6月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2372 Cleanup work for BERT special ops
Merge pull request !2372 from h.farahat/cleanup_0619
上级
cb3bbf3c
674415f7
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
47 addition
and
45 deletion
+47
-45
mindspore/dataset/text/validators.py
mindspore/dataset/text/validators.py
+1
-1
mindspore/dataset/transforms/c_transforms.py
mindspore/dataset/transforms/c_transforms.py
+14
-5
mindspore/dataset/transforms/validators.py
mindspore/dataset/transforms/validators.py
+24
-31
tests/ut/python/dataset/test_mask_op.py
tests/ut/python/dataset/test_mask_op.py
+8
-8
未找到文件。
mindspore/dataset/text/validators.py
浏览文件 @
d137cefa
...
...
@@ -403,7 +403,7 @@ def check_to_number(method):
if
not
isinstance
(
data_type
,
typing
.
Type
):
raise
TypeError
(
"data_type is not a MindSpore data type."
)
if
not
data_type
in
mstype
.
number_type
:
if
data_type
not
in
mstype
.
number_type
:
raise
TypeError
(
"data_type is not numeric data type."
)
kwargs
[
"data_type"
]
=
data_type
...
...
mindspore/dataset/transforms/c_transforms.py
浏览文件 @
d137cefa
...
...
@@ -79,12 +79,13 @@ class Slice(cde.SliceOp):
(Currently only rank 1 Tensors are supported)
Args:
*slices: Maximum n number of objects to slice a tensor of rank n
.
*slices(Variable length argument list): Maximum `n` number of arguments to slice a tensor of rank `n`
.
One object in slices can be one of:
1. int: slice this index only. Negative index is supported.
2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`.
3. None: slice the whole dimension. Similar to `:` in python indexing.
4. Ellipses ...: slice all dimensions between the two slices.
Examples:
>>> # Data before
>>> # | col |
...
...
@@ -134,11 +135,13 @@ class Mask(cde.MaskOp):
"""
Mask content of the input tensor with the given predicate.
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
Args:
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
constant (python types (str, int, float, or bool): constant to be compared to.
Constant will be casted to the type of the input tensor
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
Examples:
>>> # Data before
>>> # | col1 |
...
...
@@ -163,11 +166,13 @@ class Mask(cde.MaskOp):
class
PadEnd
(
cde
.
PadEndOp
):
"""
Pad input tensor according to `pad_shape`, need to have same rank.
Args:
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty
string in case of Tensors of strings.
Examples:
>>> # Data before
>>> # | col |
...
...
@@ -201,13 +206,17 @@ class Concatenate(cde.ConcatenateOp):
@
check_concat_type
def
__init__
(
self
,
axis
=
0
,
prepend
=
None
,
append
=
None
):
# add some validations here later
if
prepend
is
not
None
:
prepend
=
cde
.
Tensor
(
np
.
array
(
prepend
))
if
append
is
not
None
:
append
=
cde
.
Tensor
(
np
.
array
(
append
))
super
().
__init__
(
axis
,
prepend
,
append
)
class
Duplicate
(
cde
.
DuplicateOp
):
"""
Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list.
Examples:
>>> # Data before
>>> # | x |
...
...
@@ -215,7 +224,7 @@ class Duplicate(cde.DuplicateOp):
>>> # | [1,2,3] |
>>> # +---------+
>>> data = data.map(input_columns=["x"], operations=Duplicate(),
>>> output_columns=["x", "y"],
output
_order=["x", "y"])
>>> output_columns=["x", "y"],
columns
_order=["x", "y"])
>>> # Data after
>>> # | x | y |
>>> # +---------+---------+
...
...
mindspore/dataset/transforms/validators.py
浏览文件 @
d137cefa
...
...
@@ -17,7 +17,6 @@
from
functools
import
wraps
import
numpy
as
np
import
mindspore._c_dataengine
as
cde
from
mindspore._c_expression
import
typing
# POS_INT_MIN is used to limit values from starting from 0
...
...
@@ -243,12 +242,13 @@ def check_mask_op(method):
if
not
isinstance
(
constant
,
(
str
,
float
,
bool
,
int
,
bytes
)):
raise
TypeError
(
"constant must be either a primitive python str, float, bool, bytes or int"
)
if
dtype
is
not
None
:
if
not
isinstance
(
dtype
,
typing
.
Type
):
raise
TypeError
(
"dtype is not a MindSpore data type."
)
kwargs
[
"dtype"
]
=
dtype
kwargs
[
"operator"
]
=
operator
kwargs
[
"constant"
]
=
constant
kwargs
[
"dtype"
]
=
dtype
return
method
(
self
,
**
kwargs
)
...
...
@@ -269,8 +269,10 @@ def check_pad_end(method):
if
pad_shape
is
None
:
raise
ValueError
(
"pad_shape is not provided."
)
if
pad_value
is
not
None
and
not
isinstance
(
pad_value
,
(
str
,
float
,
bool
,
int
,
bytes
)):
raise
TypeError
(
"pad_value must be either a primitive python str, float, bool, int or bytes."
)
if
pad_value
is
not
None
:
if
not
isinstance
(
pad_value
,
(
str
,
float
,
bool
,
int
,
bytes
)):
raise
TypeError
(
"pad_value must be either a primitive python str, float, bool, int or bytes"
)
kwargs
[
"pad_value"
]
=
pad_value
if
not
isinstance
(
pad_shape
,
list
):
raise
TypeError
(
"pad_shape must be a list"
)
...
...
@@ -283,7 +285,6 @@ def check_pad_end(method):
raise
TypeError
(
"a value in the list is not an integer."
)
kwargs
[
"pad_shape"
]
=
pad_shape
kwargs
[
"pad_value"
]
=
pad_value
return
method
(
self
,
**
kwargs
)
...
...
@@ -303,29 +304,21 @@ def check_concat_type(method):
if
"axis"
in
kwargs
:
axis
=
kwargs
.
get
(
"axis"
)
if
not
isinstance
(
axis
,
(
type
(
None
),
int
)):
raise
TypeError
(
"axis type is not valid, must be None or an integer."
)
if
isinstance
(
axis
,
type
(
None
)):
axis
=
0
if
axis
not
in
(
None
,
0
,
-
1
):
if
axis
is
not
None
:
if
not
isinstance
(
axis
,
int
):
raise
TypeError
(
"axis type is not valid, must be an integer."
)
if
axis
not
in
(
0
,
-
1
):
raise
ValueError
(
"only 1D concatenation supported."
)
kwargs
[
"axis"
]
=
axis
if
prepend
is
not
None
:
if
not
isinstance
(
prepend
,
(
type
(
None
),
np
.
ndarray
)):
raise
ValueError
(
"prepend type is not valid, must be None for no prepend tensor or a numpy array."
)
kwargs
[
"prepend"
]
=
prepend
if
append
is
not
None
:
if
not
isinstance
(
append
,
(
type
(
None
),
np
.
ndarray
)):
raise
ValueError
(
"append type is not valid, must be None for no append tensor or a numpy array."
)
if
isinstance
(
prepend
,
np
.
ndarray
):
prepend
=
cde
.
Tensor
(
prepend
)
if
isinstance
(
append
,
np
.
ndarray
):
append
=
cde
.
Tensor
(
append
)
kwargs
[
"axis"
]
=
axis
kwargs
[
"prepend"
]
=
prepend
kwargs
[
"append"
]
=
append
return
method
(
self
,
**
kwargs
)
...
...
tests/ut/python/dataset/test_mask_op.py
浏览文件 @
d137cefa
...
...
@@ -62,7 +62,7 @@ def mask_compare(array, op, constant, dtype=mstype.bool_):
np
.
testing
.
assert_array_equal
(
array
,
d
[
0
])
def
test_int_comparison
():
def
test_
mask_
int_comparison
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
...
...
@@ -74,7 +74,7 @@ def test_int_comparison():
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
GE
,
3
,
k
)
def
test_float_comparison
():
def
test_
mask_
float_comparison
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
...
...
@@ -86,7 +86,7 @@ def test_float_comparison():
mask_compare
([
1.5
,
2.5
,
3.
,
4.5
,
5.5
],
ops
.
Relational
.
GE
,
3
,
k
)
def
test_float_comparison2
():
def
test_
mask_
float_comparison2
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
...
...
@@ -98,7 +98,7 @@ def test_float_comparison2():
mask_compare
([
1
,
2
,
3
,
4
,
5
],
ops
.
Relational
.
GE
,
3.5
,
k
)
def
test_string_comparison
():
def
test_
mask_
string_comparison
():
for
k
in
mstype_to_np_type
:
if
k
==
mstype
.
string
:
continue
...
...
@@ -125,8 +125,8 @@ def test_mask_exceptions_str():
if
__name__
==
"__main__"
:
test_int_comparison
()
test_float_comparison
()
test_float_comparison2
()
test_string_comparison
()
test_
mask_
int_comparison
()
test_
mask_
float_comparison
()
test_
mask_
float_comparison2
()
test_
mask_
string_comparison
()
test_mask_exceptions_str
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录