diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 727d66dfb36358983e373730a01f9ca28d5b38d6..5fbf2b706792705f79fe620d50938378c7464bc2 100644
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -135,6 +135,7 @@ extern const PrimitivePtr kPrimGatherV2;
 extern const PrimitivePtr kPrimSize;
 extern const PrimitivePtr kPrimArgMax;
 extern const PrimitivePtr kPrimPack;
+extern const PrimitivePtr kPrimUnpack;
 extern const PrimitivePtr kPrimUnsortedSegmentSum;
 extern const PrimitivePtr kPrimConcatOffset;
 extern const PrimitivePtr kPrimReshape;
diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index c400d1c5733e102be684e3751c7b027d108cac46..d1c4a3d42e86626d610f0853fd49377b32734bf6 100755
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -148,7 +148,8 @@ const char kNameSlice[] = "Slice";
 const char kNameAddN[] = "AddN";
 const char kNameLess[] = "Less";
 const char kNameGreater[] = "Greater";
-const char kNamePack[] = "Stack";
+const char kNameStack[] = "Stack";
+const char kNameUnstack[] = "Unstack";
 const char kNameMerge[] = "Merge";
 const char kNameGeSwitch[] = "GeSwitch";
 
@@ -199,7 +200,8 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma
     {string(kNameMaxPool), ADPT_DESC(MaxPool)},
     {string(kNameAvgPool), ADPT_DESC(AvgPool)},
     {string(kNameTopK), ADPT_DESC(TopKV2)},
-    {string(kNamePack), ADPT_DESC(Pack)},
+    {string(kNameStack), ADPT_DESC(Pack)},
+    {string(kNameUnstack), ADPT_DESC(Unpack)},
     {string(kNameSplitD), ADPT_DESC(SplitD)},
     {string(kNameAllReduce), ADPT_DESC(HcomAllReduce)},
     {string(kNameBroadcast), ADPT_DESC(HcomBroadcast)},
diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py
index 81d38a1e1e38778e9a81daafd7f715539a4ec2db..0a0caf471ee375fb46d000816206de119ad7d88d 100644
--- a/mindspore/ops/_grad/grad_array_ops.py
+++ b/mindspore/ops/_grad/grad_array_ops.py
@@ -266,6 +266,30 @@ def get_bprop_gather_v2(self):
     return bprop
 
 
+@bprop_getters.register(P.Stack)
+def get_bprop_stack(self):
+    """Generate bprop for Stack"""
+    axis = self.axis
+
+    def bprop(x, out, dout):
+        stack_grad = P.Unstack(axis)
+        out = stack_grad(dout)
+        return (out,)
+    return bprop
+
+
+@bprop_getters.register(P.Unstack)
+def get_bprop_unstack(self):
+    """Generate bprop for Unstack"""
+    axis = self.axis
+
+    def bprop(x, out, dout):
+        unstack_grad = P.Stack(axis)
+        out = unstack_grad(dout)
+        return (out,)
+    return bprop
+
+
 @bprop_getters.register(P.StridedSlice)
 def get_bprop_strided_slice(self):
     """Generate bprop for StridedSlice"""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 89a5ea02495c8431b6c84a5aa75953110bfb7b54..7a8655b46c565950a8f875c60944a15479e1ff66 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -19,7 +19,7 @@ Primitive operator classes.
 
 A collection of operators to build nerual networks or computing functions.
""" -from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, +from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Stack, Unstack, Diag, DiagPart, DType, ExpandDims, Eye, Fill, GatherNd, GatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, @@ -112,6 +112,8 @@ __all__ = [ 'OneHot', 'GatherV2', 'Concat', + 'Stack', + 'Unstack', 'Tile', 'BiasAdd', 'Gelu', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b91c2cbc7d5c8ad206a3945743c8d81346e8ba13..59d3083c5d2b640a3a28358782c043e87d7da0a3 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1350,6 +1350,150 @@ class Concat(PrimitiveWithInfer): return out +def _get_stack_shape(x_shape, x_type, axis): + """for satck output shape""" + validator.check_type("shape", x_shape, [tuple]) + validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) + validator.check_subclass("shape0", x_type[0], mstype.tensor) + validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) + rank_base = len(x_shape[0]) + N = len(x_shape) + out_shape = x_shape[0] + validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) + if axis < 0: + axis = axis + rank_base + 1 + for i in range(1, N): + v = x_shape[i] + validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base) + validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) + for j in range(rank_base): + if v[j] != x_shape[0][j]: + raise ValueError("Stack evaluator element %d shape in input can not stack with first element" % i) + out_shape.insert(axis, N) + return out_shape + +class Stack(PrimitiveWithInfer): + r""" + Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. + + Packs the list of tensors in `input_x` into a tensor with rank one higher than + each tensor in `input_x`, by packing them along the `axis` dimension. + Given a list of length `N` of tensors of shape `(A, B, C)`; + + If `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. + + If `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. Etc. + + Args: + axis (int): The axis to stack along. Negative values wrap around, + so the valid range is [-(R+1), R+1). Default: 0. + + Inputs: + - **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type. + + Outputs: + Tensor. A stacked Tensor with the same type as values. + + Examples: + >>> data1 = Tensor(np.array([0, 1]).astype(np.float32)) + >>> data2 = Tensor(np.array([2, 3]).astype(np.float32)) + >>> op = P.Stack() + >>> output = op([data1, data2]) + [[0, 1], [2, 3]] + """ + + @prim_attr_register + def __init__(self, axis=0): + """init Stack""" + self.__setattr_flag__ = True + validator.check_type("axis", axis, [int]) + self.axis = axis + + def __infer__(self, value): + x_shape = value['shape'] + x_type = value['dtype'] + self.add_prim_attr('num', len(x_shape)) + all_shape = _get_stack_shape(x_shape, x_type, self.axis) + out = {'shape': all_shape, + 'dtype': x_type[0], + 'value': None} + return out + + +class Unstack(PrimitiveWithInfer): + r""" + Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. + + Unpacks num tensors from value by chipping it along the axis dimension. + If num is not specified (the default), it is inferred from value's shape. + If value.shape[axis] is not known, ValueError is raised. 
+
+    For example, given a tensor of shape (A, B, C, D):
+
+    If `axis == 0`, the i-th tensor in the output is the slice `input_x[i, :, :, :]`,
+    and each tensor in the output will have shape (B, C, D). (Note that the
+    dimension unpacked along is gone, unlike `Split`.)
+
+    If `axis == 1`, the i-th tensor in the output is the slice `input_x[:, i, :, :]`,
+    and each tensor in the output will have shape (A, C, D). And so on.
+
+    This is the opposite of `Stack`.
+
+    Args:
+        axis (int): The axis to unstack along. Negative values wrap around,
+            so the valid range is [-R, R). Default: 0.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
+          A rank R > 0 Tensor to be unstacked.
+
+    Outputs:
+        A tuple of Tensors, each with the same shape: the input shape with the
+        `axis` dimension removed.
+
+    Raises:
+        ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())).
+
+    Examples:
+        >>> unstack = P.Unstack()
+        >>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
+        >>> output = unstack(x)
+        ([1, 1, 1, 1], [2, 2, 2, 2])
+    """
+
+    @prim_attr_register
+    def __init__(self, axis=0):
+        """init Unstack"""
+        self.__setattr_flag__ = True
+        validator.check_type("axis", axis, [int])
+        self.axis = axis
+
+    def __infer__(self, x):
+        validator.check_subclass("x", x['dtype'], mstype.tensor)
+        x_shape = list(x['shape'])
+        dim = len(x_shape)
+        validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
+        if self.axis < 0:
+            self.axis = self.axis + dim
+        output_num = x_shape[self.axis]
+        validator.check_type("num", output_num, [int])
+        validator.check_integer("output_num", output_num, 0, Rel.GT)
+        self.add_prim_attr('num', output_num)
+        output_valid_check = x_shape[self.axis] - output_num
+        validator.check_integer("the dimension which to unstack divides output_num", output_valid_check, 0, Rel.EQ)
+        out_shapes = []
+        out_dtypes = []
+        out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
+        for _ in range(output_num):
+            out_shapes.append(tuple(out_shape))
+            out_dtypes.append(x['dtype'])
+        out_shapes = tuple(out_shapes)
+        out_dtypes = tuple(out_dtypes)
+        out = {'shape': out_shapes,
+               'dtype': out_dtypes,
+               'value': None}
+        return out
+
+
 class Slice(PrimitiveWithInfer):
     """
     Slice a tensor in specified shape.
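
Review note: the shape semantics documented in the Stack/Unstack docstrings above can be sanity-checked with the standalone NumPy sketch below. np.stack follows the same axis convention, so this is an illustrative cross-check only, not part of the patch:

    import numpy as np

    # Three rank-2 tensors of shape (2, 3); stacking N=3 of them adds a new axis.
    xs = [np.full((2, 3), i, dtype=np.float32) for i in range(3)]

    # axis=0 inserts the new N dimension in front: (3, 2, 3).
    assert np.stack(xs, axis=0).shape == (3, 2, 3)

    # Negative axes wrap around: for rank R=2 the valid range is [-3, 3),
    # and axis=-2 resolves to position 1 of the output, giving (2, 3, 3).
    assert np.stack(xs, axis=-2).shape == (2, 3, 3)

    # Unstacking removes the chipped dimension (unlike Split, which keeps it):
    # splitting (3, 2, 3) along axis=0 and squeezing yields 3 tensors of (2, 3).
    parts = [np.squeeze(p, 0) for p in np.split(np.stack(xs), 3, axis=0)]
    assert all(p.shape == (2, 3) for p in parts)

    # Round trip: unstack is the inverse of stack along the same axis.
    assert all(np.array_equal(p, x) for p, x in zip(parts, xs))
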
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 0f5b716e390c5df31c6beba7daccc38ff8f7adf0..5dcd2d553afe8bf19bbdc3b278468027003c16f3 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -80,6 +80,29 @@ class NetForConcat1(nn.Cell):
         return self.concat((x1, x2))
 
 
+class NetForStackInput(nn.Cell):
+    def __init__(self, op):
+        super(NetForStackInput, self).__init__()
+        self.op = op
+        self.mul = P.Mul()
+
+    def construct(self, *args):
+        t = ()
+        for i in range(len(args)):
+            t = t + (self.mul(args[i], args[i]),)
+        return self.op(t)
+
+
+class NetForUnstackInput(nn.Cell):
+    def __init__(self, op):
+        super(NetForUnstackInput, self).__init__()
+        self.op = op
+        self.mul = P.Mul()
+
+    def construct(self, x1):
+        return self.op(self.mul(x1, x1))
+
+
 class NetForFlatten(nn.Cell):
     def __init__(self):
         super(NetForFlatten, self).__init__()
@@ -968,6 +991,36 @@ test_case_array_ops = [
                                    Tensor(np.array([1], np.float32)),
                                    Tensor(np.array([1], np.float32)))],
         'desc_bprop': [[3,]]}),
+    ('StackV2_0', {
+        'block': NetForStackInput(P.Stack()),
+        'desc_inputs': [[2, 2], [2, 2], [2, 2]],
+        'desc_bprop': [[3, 2, 2]],
+    }),
+    ('StackV2_1', {
+        'block': NetForStackInput(P.Stack(axis=-2)),
+        'desc_inputs': [[3, 2, 3], [3, 2, 3], [3, 2, 3]],
+        'desc_bprop': [[3, 2, 3, 3]],
+    }),
+    ('StackV2_2', {
+        'block': NetForStackInput(P.Stack()),
+        'desc_inputs': [[2, 2]],
+        'desc_bprop': [[1, 2, 2]],
+    }),
+    ('StackV2_3', {
+        'block': NetForStackInput(P.Stack()),
+        'desc_inputs': [[128, 128], [128, 128]],
+        'desc_bprop': [[2, 128, 128]],
+    }),
+    ('UnstackV2_0', {
+        'block': NetForUnstackInput(P.Unstack(axis=0)),
+        'desc_inputs': [[2, 4]],
+        'desc_bprop': [[4], [4]],
+    }),
+    ('UnstackV2_1', {
+        'block': NetForUnstackInput(P.Unstack(axis=-1)),
+        'desc_inputs': [Tensor(np.array([[1, 1, 1]], np.float32))],
+        'desc_bprop': [[1], [1], [1]],
+    }),
     ('Diag', {
         'block': P.Diag(),
         'desc_inputs': [[4]],
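
Review note: the bprop pair registered in grad_array_ops.py depends on Stack and Unstack being mutual inverses along the same axis; the gradient of Stack is an unstack of the incoming gradient, and the gradient of Unstack re-stacks the per-output gradients. Below is a minimal NumPy mirror of that round-trip (illustrative only; the helper names stack_bprop/unstack_bprop are hypothetical, not MindSpore APIs):

    import numpy as np

    def stack_bprop(dout, axis=0):
        # dout has the stacked shape; return one gradient per original input,
        # mirroring get_bprop_stack, which applies Unstack(axis) to dout.
        return tuple(np.squeeze(g, axis) for g in np.split(dout, dout.shape[axis], axis=axis))

    def unstack_bprop(douts, axis=0):
        # douts is the tuple of per-output gradients; re-stack them,
        # mirroring get_bprop_unstack, which applies Stack(axis) to dout.
        return np.stack(douts, axis=axis)

    dout = np.arange(12, dtype=np.float32).reshape(3, 2, 2)    # grad w.r.t. a Stack output
    grads = stack_bprop(dout, axis=0)                          # 3 gradients of shape (2, 2)
    assert np.array_equal(unstack_bprop(grads, axis=0), dout)  # the two bprops invert each other
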