提交 d89101b9 编写于 作者: N nhussain

add missing test

上级 722eafca
......@@ -576,7 +576,7 @@ Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
Tensor::HandleNeg(axis, input[0]->shape().Rank());
axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
std::shared_ptr<Tensor> out;
......
......@@ -166,8 +166,8 @@ class PadEnd(cde.PadEndOp):
Args:
pad_shape (list of `int`): list on integers representing the shape needed. Dimensions that set to `None` will
not be padded (i.e., original dim will be used). Shorter dimensions will truncate the values.
pad_value (str, bytes, int, float, or bool, optional): value used to pad. Default to 0 or empty string in case
of Tensors of strings.
pad_value (python types (str, bytes, int, float, or bool), optional): value used to pad. Default to 0 or empty
string in case of Tensors of strings.
Examples:
>>> # Data before
>>> # | col |
......
......@@ -233,10 +233,10 @@ def check_mask_op(method):
if operator is None:
raise ValueError("operator is not provided.")
from .c_transforms import Relational
if constant is None:
raise ValueError("constant is not provided.")
from .c_transforms import Relational
if not isinstance(operator, Relational):
raise TypeError("operator is not a Relational operator enum.")
......@@ -270,14 +270,17 @@ def check_pad_end(method):
raise ValueError("pad_shape is not provided.")
if pad_value is not None and not isinstance(pad_value, (str, float, bool, int, bytes)):
raise TypeError("pad_value must be either a primitive python str, float, bool, bytes or int")
raise TypeError("pad_value must be either a primitive python str, float, bool, int or bytes.")
if not isinstance(pad_shape, list):
raise TypeError("pad_shape must be a list")
for dim in pad_shape:
if dim is not None:
check_pos_int64(dim)
if isinstance(dim, int):
check_pos_int64(dim)
else:
raise TypeError("a value in the list is not an integer.")
kwargs["pad_shape"] = pad_shape
kwargs["pad_value"] = pad_value
......
......@@ -147,6 +147,21 @@ def test_concatenate_op_wrong_axis():
assert "only 1D concatenation supported." in repr(error_info.value)
def test_concatenate_op_negative_axis():
def gen():
yield (np.array([5., 6., 7., 8.], dtype=np.float),)
prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float)
append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float)
data = ds.GeneratorDataset(gen, column_names=["col"])
concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor)
data = data.map(input_columns=["col"], operations=concatenate_op)
expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3,
11., 12.])
for data_row in data:
np.testing.assert_array_equal(data_row[0], expected)
def test_concatenate_op_incorrect_input_dim():
def gen():
yield (np.array(["ss", "ad"], dtype='S'),)
......@@ -166,10 +181,11 @@ if __name__ == "__main__":
test_concatenate_op_all()
test_concatenate_op_none()
test_concatenate_op_string()
test_concatenate_op_multi_input_string()
test_concatenate_op_multi_input_numeric()
test_concatenate_op_type_mismatch()
test_concatenate_op_type_mismatch2()
test_concatenate_op_incorrect_dim()
test_concatenate_op_incorrect_input_dim()
test_concatenate_op_multi_input_numeric()
test_concatenate_op_multi_input_string()
test_concatenate_op_negative_axis()
test_concatenate_op_wrong_axis()
test_concatenate_op_incorrect_input_dim()
......@@ -22,6 +22,8 @@ import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as ops
# Extensive testing of PadEnd is already done in batch with Pad test cases
def pad_compare(array, pad_shape, pad_value, res):
data = ds.NumpySlicesDataset([array])
if pad_value is not None:
......@@ -32,8 +34,6 @@ def pad_compare(array, pad_shape, pad_value, res):
np.testing.assert_array_equal(res, d[0])
# Extensive testing of PadEnd is already done in batch with Pad test cases
def test_pad_end_basics():
pad_compare([1, 2], [3], -1, [1, 2, -1])
pad_compare([1, 2, 3], [3], -1, [1, 2, 3])
......@@ -57,6 +57,10 @@ def test_pad_end_exceptions():
pad_compare([b"1", b"2", b"3", b"4", b"5"], [2], 1, [])
assert "Source and pad_value tensors are not of the same type." in str(info.value)
with pytest.raises(TypeError) as info:
pad_compare([3, 4, 5], ["2"], 1, [])
assert "a value in the list is not an integer." in str(info.value)
if __name__ == "__main__":
test_pad_end_basics()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册