提交 d137cefa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2372 Cleanup work for BERT special ops

Merge pull request !2372 from h.farahat/cleanup_0619
......@@ -403,7 +403,7 @@ def check_to_number(method):
if not isinstance(data_type, typing.Type):
raise TypeError("data_type is not a MindSpore data type.")
if not data_type in mstype.number_type:
if data_type not in mstype.number_type:
raise TypeError("data_type is not numeric data type.")
kwargs["data_type"] = data_type
......
......@@ -79,12 +79,13 @@ class Slice(cde.SliceOp):
(Currently only rank 1 Tensors are supported)
Args:
*slices: Maximum n number of objects to slice a tensor of rank n.
One object in slices can be one of:
*slices(Variable length argument list): Maximum `n` number of arguments to slice a tensor of rank `n`.
One object in slices can be one of:
1. int: slice this index only. Negative index is supported.
2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`.
3. None: slice the whole dimension. Similar to `:` in python indexing.
4. Ellipses ...: slice all dimensions between the two slices.
Examples:
>>> # Data before
>>> # | col |
......@@ -134,11 +135,13 @@ class Mask(cde.MaskOp):
"""
Mask content of the input tensor with the given predicate.
Any element of the tensor that matches the predicate will be evaluated to True, otherwise False.
Args:
operator (Relational): One of the relational operator EQ, NE LT, GT, LE or GE
constant (python types (str, int, float, or bool): constant to be compared to.
Constant will be casted to the type of the input tensor
dtype (optional, mindspore.dtype): type of the generated mask. Default to bool
Examples:
>>> # Data before
>>> # | col1 |
......@@ -163,11 +166,13 @@ class Mask(cde.MaskOp):
class PadEnd(cde.PadEndOp):
"""
Pad input tensor according to `pad_shape`, need to have same rank.
Args:
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty
string in case of Tensors of strings.
Examples:
>>> # Data before
>>> # | col |
......@@ -201,21 +206,25 @@ class Concatenate(cde.ConcatenateOp):
@check_concat_type
def __init__(self, axis=0, prepend=None, append=None):
# add some validations here later
if prepend is not None:
prepend = cde.Tensor(np.array(prepend))
if append is not None:
append = cde.Tensor(np.array(append))
super().__init__(axis, prepend, append)
class Duplicate(cde.DuplicateOp):
"""
Duplicate the input tensor to a new output tensor. The input tensor is carried over to the output list.
Examples:
Examples:
>>> # Data before
>>> # | x |
>>> # +---------+
>>> # | [1,2,3] |
>>> # +---------+
>>> data = data.map(input_columns=["x"], operations=Duplicate(),
>>> output_columns=["x", "y"], output_order=["x", "y"])
>>> output_columns=["x", "y"], columns_order=["x", "y"])
>>> # Data after
>>> # | x | y |
>>> # +---------+---------+
......
......@@ -17,7 +17,6 @@
from functools import wraps
import numpy as np
import mindspore._c_dataengine as cde
from mindspore._c_expression import typing
# POS_INT_MIN is used to limit values from starting from 0
......@@ -243,12 +242,13 @@ def check_mask_op(method):
if not isinstance(constant, (str, float, bool, int, bytes)):
raise TypeError("constant must be either a primitive python str, float, bool, bytes or int")
if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
if dtype is not None:
if not isinstance(dtype, typing.Type):
raise TypeError("dtype is not a MindSpore data type.")
kwargs["dtype"] = dtype
kwargs["operator"] = operator
kwargs["constant"] = constant
kwargs["dtype"] = dtype
return method(self, **kwargs)
......@@ -269,8 +269,10 @@ def check_pad_end(method):
if pad_shape is None:
raise ValueError("pad_shape is not provided.")
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)):
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes.")
if pad_value is not None:
if not isinstance(pad_value, (str, float, bool, int, bytes)):
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes")
kwargs["pad_value"] = pad_value
if not isinstance(pad_shape, list):
raise TypeError("pad_shape must be a list")
......@@ -283,7 +285,6 @@ def check_pad_end(method):
raise TypeError("a value in the list is not an integer.")
kwargs["pad_shape"] = pad_shape
kwargs["pad_value"] = pad_value
return method(self, **kwargs)
......@@ -303,30 +304,22 @@ def check_concat_type(method):
if "axis" in kwargs:
axis = kwargs.get("axis")
if not isinstance(axis, (type(None), int)):
raise TypeError("axis type is not valid, must be None or an integer.")
if isinstance(axis, type(None)):
axis = 0
if axis not in (None, 0, -1):
raise ValueError("only 1D concatenation supported.")
if not isinstance(prepend, (type(None), np.ndarray)):
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
if not isinstance(append, (type(None), np.ndarray)):
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
if isinstance(prepend, np.ndarray):
prepend = cde.Tensor(prepend)
if isinstance(append, np.ndarray):
append = cde.Tensor(append)
kwargs["axis"] = axis
kwargs["prepend"] = prepend
kwargs["append"] = append
if axis is not None:
if not isinstance(axis, int):
raise TypeError("axis type is not valid, must be an integer.")
if axis not in (0, -1):
raise ValueError("only 1D concatenation supported.")
kwargs["axis"] = axis
if prepend is not None:
if not isinstance(prepend, (type(None), np.ndarray)):
raise ValueError("prepend type is not valid, must be None for no prepend tensor or a numpy array.")
kwargs["prepend"] = prepend
if append is not None:
if not isinstance(append, (type(None), np.ndarray)):
raise ValueError("append type is not valid, must be None for no append tensor or a numpy array.")
kwargs["append"] = append
return method(self, **kwargs)
......
......@@ -62,7 +62,7 @@ def mask_compare(array, op, constant, dtype=mstype.bool_):
np.testing.assert_array_equal(array, d[0])
def test_int_comparison():
def test_mask_int_comparison():
for k in mstype_to_np_type:
if k == mstype.string:
continue
......@@ -74,7 +74,7 @@ def test_int_comparison():
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3, k)
def test_float_comparison():
def test_mask_float_comparison():
for k in mstype_to_np_type:
if k == mstype.string:
continue
......@@ -86,7 +86,7 @@ def test_float_comparison():
mask_compare([1.5, 2.5, 3., 4.5, 5.5], ops.Relational.GE, 3, k)
def test_float_comparison2():
def test_mask_float_comparison2():
for k in mstype_to_np_type:
if k == mstype.string:
continue
......@@ -98,7 +98,7 @@ def test_float_comparison2():
mask_compare([1, 2, 3, 4, 5], ops.Relational.GE, 3.5, k)
def test_string_comparison():
def test_mask_string_comparison():
for k in mstype_to_np_type:
if k == mstype.string:
continue
......@@ -125,8 +125,8 @@ def test_mask_exceptions_str():
if __name__ == "__main__":
test_int_comparison()
test_float_comparison()
test_float_comparison2()
test_string_comparison()
test_mask_int_comparison()
test_mask_float_comparison()
test_mask_float_comparison2()
test_mask_string_comparison()
test_mask_exceptions_str()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册