Commit fd7d75ae authored by mindspore-ci-bot, committed by Gitee

!143 Adapting ops Stack and Unstack in ME

Merge pull request !143 from liuxiao/temp
......@@ -135,6 +135,7 @@ extern const PrimitivePtr kPrimGatherV2;
extern const PrimitivePtr kPrimSize;
extern const PrimitivePtr kPrimArgMax;
extern const PrimitivePtr kPrimPack;
extern const PrimitivePtr kPrimUnpack;
extern const PrimitivePtr kPrimUnsortedSegmentSum;
extern const PrimitivePtr kPrimConcatOffset;
extern const PrimitivePtr kPrimReshape;
......
......@@ -148,7 +148,8 @@ const char kNameSlice[] = "Slice";
const char kNameAddN[] = "AddN";
const char kNameLess[] = "Less";
const char kNameGreater[] = "Greater";
-const char kNamePack[] = "Stack";
+const char kNameStack[] = "Stack";
+const char kNameUnstack[] = "Unstack";
const char kNameMerge[] = "Merge";
const char kNameGeSwitch[] = "GeSwitch";
......@@ -199,7 +200,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameMaxPool), ADPT_DESC(MaxPool)},
{string(kNameAvgPool), ADPT_DESC(AvgPool)},
{string(kNameTopK), ADPT_DESC(TopKV2)},
-{string(kNamePack), ADPT_DESC(Pack)},
+{string(kNameStack), ADPT_DESC(Pack)},
+{string(kNameUnstack), ADPT_DESC(Unpack)},
{string(kNameSplitD), ADPT_DESC(SplitD)},
{string(kNameAllReduce), ADPT_DESC(HcomAllReduce)},
{string(kNameBroadcast), ADPT_DESC(HcomBroadcast)},
......
......@@ -266,6 +266,30 @@ def get_bprop_gather_v2(self):
return bprop
@bprop_getters.register(P.Stack)
def get_bprop_stack(self):
"""Generate bprop for Stack"""
axis = self.axis
def bprop(x, out, dout):
stack_grad = P.Unstack(axis)
out = stack_grad(dout)
return (out,)
return bprop
@bprop_getters.register(P.Unstack)
def get_bprop_unstack(self):
"""Generate bprop for Unstack"""
axis = self.axis
def bprop(x, out, dout):
unstack_grad = P.Stack(axis)
out = unstack_grad(dout)
return (out,)
return bprop
@bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self):
"""Generate bprop for StridedSlice"""
......
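The two bprops above are mutually inverse: the gradient of Stack slices `dout` back apart (which is what Unstack computes), and the gradient of Unstack packs the per-output gradients back together (which is Stack). A minimal NumPy sketch of that duality, not part of the diff — the function names here are illustrative only:

import numpy as np

def stack_bprop(dout, axis=0):
    # The gradient of Stack w.r.t. each input is the matching slice of
    # dout along `axis` -- exactly what Unstack computes.
    return tuple(np.moveaxis(dout, axis, 0))

def unstack_bprop(douts, axis=0):
    # The gradient of Unstack is the tuple of output gradients, stacked.
    return np.stack(douts, axis=axis)

x = (np.ones((2, 3)), np.full((2, 3), 2.0))
out = np.stack(x, axis=0)                  # forward Stack: shape (2, 2, 3)
grads = stack_bprop(np.ones_like(out))     # two gradients of shape (2, 3)
assert unstack_bprop(grads).shape == out.shape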
......@@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators for building neural networks or computing functions.
"""
-from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat,
+from .array_ops import (Argmax, Argmin, Cast, ConcatOffset, Concat, Stack, Unstack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
......@@ -112,6 +112,8 @@ __all__ = [
'OneHot',
'GatherV2',
'Concat',
'Stack',
'Unstack',
'Tile',
'BiasAdd',
'Gelu',
......
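With Stack and Unstack exported from the operations package, they can be used like any other primitive. A hedged usage sketch, assuming a build that includes this change and that the primitives can be called eagerly:

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.array([0, 1], np.float32))
y = Tensor(np.array([2, 3], np.float32))
stacked = P.Stack(axis=0)((x, y))    # Tensor of shape (2, 2)
parts = P.Unstack(axis=0)(stacked)   # tuple of two Tensors of shape (2,)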
......@@ -1350,6 +1350,150 @@ class Concat(PrimitiveWithInfer):
return out
def _get_stack_shape(x_shape, x_type, axis):
"""for satck output shape"""
validator.check_type("shape", x_shape, [tuple])
validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT)
rank_base = len(x_shape[0])
N = len(x_shape)
out_shape = list(x_shape[0])  # copy, so insert() below does not mutate the caller's shape
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
if axis < 0:
axis = axis + rank_base + 1
for i in range(1, N):
v = x_shape[i]
validator.check('len of x_shape[%d]' % i, len(v), 'rank_base', rank_base)
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
for j in range(rank_base):
if v[j] != x_shape[0][j]:
raise ValueError("Stack evaluator element %d shape in input can not stack with first element" % i)
out_shape.insert(axis, N)
return out_shape
class Stack(PrimitiveWithInfer):
r"""
Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
Packs the list of tensors in `input_x` into a tensor with rank one higher than
each tensor in `input_x`, by packing them along the `axis` dimension.
Given a list of length `N` of tensors of shape `(A, B, C)`:
If `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
If `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. Etc.
Args:
axis (int): The axis to stack along. Negative values wrap around,
so the valid range is [-(R+1), R+1). Default: 0.
Inputs:
- **input_x** (Union[tuple, list]) - A Tuple or list of Tensor objects with the same shape and type.
Outputs:
Tensor. A stacked Tensor with the same type as the tensors in `input_x`.
Examples:
>>> data1 = Tensor(np.array([0, 1]).astype(np.float32))
>>> data2 = Tensor(np.array([2, 3]).astype(np.float32))
>>> op = P.Stack()
>>> output = op([data1, data2])
[[0, 1], [2, 3]]
"""
@prim_attr_register
def __init__(self, axis=0):
"""init Stack"""
self.__setattr_flag__ = True
validator.check_type("axis", axis, [int])
self.axis = axis
def __infer__(self, value):
x_shape = value['shape']
x_type = value['dtype']
self.add_prim_attr('num', len(x_shape))
all_shape = _get_stack_shape(x_shape, x_type, self.axis)
out = {'shape': all_shape,
'dtype': x_type[0],
'value': None}
return out
class Unstack(PrimitiveWithInfer):
r"""
Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
Unpacks tensors from `input_x` by chipping it along the `axis` dimension.
The number of output tensors is inferred from the shape of `input_x`;
if `input_x.shape[axis]` is not known, a ValueError is raised.
For example, given a tensor of shape (A, B, C, D):
If `axis == 0`, the i-th tensor in the output is the slice `input_x[i, :, :, :]`,
and each output tensor has shape (B, C, D). (Note that the unpacked dimension is gone, unlike Split.)
If `axis == 1`, the i-th tensor in the output is the slice `input_x[:, i, :, :]`,
and each output tensor has shape (A, C, D). Etc.
This is the opposite of Stack.
Args:
axis (int): The axis to unstack along. Negative values wrap around,
so the valid range is [-R, R). Default: 0.
Inputs:
- **input_x** (Tensor) - The shape is :math:`(x_1, x_2, ..., x_R)`.
A rank R > 0 Tensor to be unstacked.
Outputs:
A tuple of Tensors. Each output tensor has the same shape.
Raises:
ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())),
or if input_x.shape()[axis] is not equal to num.
Examples:
>>> unstack = P.Unstack()
>>> x = Tensor(np.array([[1, 1, 1, 1], [2, 2, 2, 2]]))
>>> output = unstack(x)
([1, 1, 1, 1], [2, 2, 2, 2])
"""
@prim_attr_register
def __init__(self, axis=0):
"""init Unstack"""
self.__setattr_flag__ = True
validator.check_type("axis", axis, [int])
self.axis = axis
def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor)
x_shape = list(x['shape'])
dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT)
if self.axis < 0:
self.axis = self.axis + dim
output_num = x_shape[self.axis]
validator.check_type("num", output_num, [int])
validator.check_integer("output_num", output_num, 0, Rel.GT)
self.add_prim_attr('num', output_num)
output_valid_check = x_shape[self.axis] - output_num
validator.check_integer("x_shape[axis] minus output_num", output_valid_check, 0, Rel.EQ)
out_shapes = []
out_dtypes = []
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
for _ in range(output_num):
out_shapes.append(tuple(out_shape))
out_dtypes.append(x['dtype'])
out_shapes = tuple(out_shapes)
out_dtypes = tuple(out_dtypes)
out = {'shape': out_shapes,
'dtype': out_dtypes,
'value': None}
return out
class Slice(PrimitiveWithInfer):
"""
Slice a tensor in specified shape.
......
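The inference code above is compact, so here is a standalone re-trace of what `_get_stack_shape` and `Unstack.__infer__` compute, in plain Python; the function names are ad hoc, not part of the change:

def stack_out_shape(shapes, axis):
    # Negative axes wrap around the output rank (input rank + 1).
    rank = len(shapes[0])
    if axis < 0:
        axis += rank + 1
    out = list(shapes[0])
    out.insert(axis, len(shapes))
    return out

def unstack_out_shapes(shape, axis):
    # One output per slice along `axis`; that dimension disappears.
    if axis < 0:
        axis += len(shape)
    out = tuple(shape[:axis] + shape[axis + 1:])
    return tuple(out for _ in range(shape[axis]))

assert stack_out_shape([[2, 3], [2, 3]], 0) == [2, 2, 3]
assert stack_out_shape([[2, 3], [2, 3]], -1) == [2, 3, 2]   # wraps to axis 2
assert unstack_out_shapes([2, 4], 0) == ((4,), (4,))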
......@@ -80,6 +80,29 @@ class NetForConcat1(nn.Cell):
return self.concat((x1, x2))
class NetForStackInput(nn.Cell):
def __init__(self, op):
super(NetForStackInput, self).__init__()
self.op = op
self.mul = P.Mul()
def construct(self, *args):
t = ()
for i in range(len(args)):
t = t + (self.mul(args[i], args[i]),)
return self.op(t)
class NetForUnstackInput(nn.Cell):
def __init__(self, op):
super(NetForUnstackInput, self).__init__()
self.op = op
self.mul = P.Mul()
def construct(self, x1):
return self.op(self.mul(x1, x1))
class NetForFlatten(nn.Cell):
def __init__(self):
super(NetForFlatten, self).__init__()
......@@ -968,6 +991,36 @@ test_case_array_ops = [
Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}),
('StackV2_0', {
'block': NetForStackInput(P.Stack()),
'desc_inputs':[[2, 2], [2, 2], [2, 2]],
'desc_bprop':[[3, 2, 2]],
}),
('StackV2_1', {
'block': NetForStackInput(P.Stack(axis=-2)),
'desc_inputs':[[3, 2, 3], [3, 2, 3], [3, 2, 3]],
'desc_bprop':[[3, 2, 3, 3]],
}),
('StackV2_2', {
'block': NetForStackInput(P.Stack()),
'desc_inputs':[[2, 2]],
'desc_bprop':[[2, 2, 2]],
}),
('StackV2_3', {
'block': NetForStackInput(P.Stack()),
'desc_inputs':[[128, 128], [128, 128]],
'desc_bprop':[[2, 128, 128]],
}),
('UnstackV2_0', {
'block': NetForUnstackInput(P.Unstack(axis=0)),
'desc_inputs':[[2, 4]],
'desc_bprop':[[4], [4]],
}),
('UnstackV2_1', {
'block': NetForUnstackInput(P.Unstack(axis=-1)),
'desc_inputs':[Tensor(np.array([[1, 1, 1]], np.float32))],
'desc_bprop':[[1], [1], [1]],
}),
('Diag', {
'block': P.Diag(),
'desc_inputs': [[4]],
......
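For orientation, the 'StackV2_0' case runs three (2, 2) inputs through NetForStackInput, which squares each input with Mul and then stacks the results, giving the (3, 2, 2) shape listed in desc_bprop. A NumPy sketch of the same computation:

import numpy as np

inputs = [np.full((2, 2), k, np.float32) for k in (1.0, 2.0, 3.0)]
out = np.stack([a * a for a in inputs], axis=0)   # square each input, then stack
assert out.shape == (3, 2, 2)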