提交 5cccfbc6 编写于 作者: L leilei_snow 提交者: leilei_snow

Add TransShape Operator.

上级 f42e36b6
......@@ -134,6 +134,7 @@ const char kNameAssignSub[] = "AssignSub";
const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus";
const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus";
const char kNameReshape[] = "Reshape";
const char kNameTransShape[] = "TransShape";
const char kNameRealDiv[] = "RealDiv";
const char kNameTile[] = "Tile";
const char kNameCos[] = "Cos";
......@@ -242,6 +243,7 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
{string(kNameBatchNorm), ADPT_DESC(BatchNorm)},
{string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)},
{string(kNameReshape), ADPT_DESC(Reshape)},
{string(kNameTransShape), ADPT_DESC(TransShape)},
{string(kNameFlattenGrad), ADPT_DESC(Reshape)},
{prim::kPrimFlatten->name(), ADPT_DESC(Flatten)},
{string(kNameAddN), ADPT_DESC(AddN)},
......
......@@ -442,6 +442,12 @@ INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
// TransShape
INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
// BiasAdd
INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
......
......@@ -112,6 +112,9 @@ DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
DECLARE_OP_ADAPTER(Reshape)
DECLARE_OP_USE_OUTPUT(Reshape)
DECLARE_OP_ADAPTER(TransShape)
DECLARE_OP_USE_INPUT_ATTR(TransShape)
DECLARE_OP_USE_OUTPUT(TransShape)
DECLARE_OP_ADAPTER(Iou)
DECLARE_OP_USE_OUTPUT(Iou)
DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
......
......@@ -696,3 +696,13 @@ def get_bprop_reverse_sequence(self):
dx = reverse_sequence_grad(dout, seq_lengths)
return dx, zeros_like(seq_lengths)
return bprop
@bprop_getters.register(P.TransShape)
def get_bprop_trans_shape(self):
"""Generate bprop for TransShape"""
op = P.TransShape()
def bprop(x, shape, out, dout):
dx = op(dout, shape_op(x))
return (dx, zeros_like(shape))
return bprop
......@@ -27,7 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Shape, Size, Slice, Split, TransShape,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
......
......@@ -3106,3 +3106,28 @@ class ReverseSequence(PrimitiveWithInfer):
validator.check_tensor_type_same({"x_dtype": x}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_tensor_type_same({"seq_lengths_dtype": seq_lengths}, [mstype.int32, mstype.int64], self.name)
return x
class TransShape(PrimitiveWithInfer):
"""
Transform the shape of input tensor to target shape.
Inputs:
- **input_x** (Tensor) - A input tensor.
- **out_shape** (tuple[int]) - The shape of output data.
Outputs:
Tensor, a tensor whose data type is same as 'input_x', and the shape is same as the `out_shape`.
"""
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
def __infer__(self, x, shape):
shp = shape['value']
dtype = x['dtype']
validator.check_tensor_type_same({'x': dtype}, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('out_shape', tuple(shp))
return {'shape': shp,
'dtype': dtype,
'value': None}
......@@ -1865,6 +1865,12 @@ test_case_array_ops = [
Tensor(np.arange(-12, 0).reshape(3, 2, 2), mstype.float32)],
'skip': ['backward'],
}),
('TransShape', {
'block': P.TransShape(),
'desc_const': [(1, 12, 24, 24)],
'desc_inputs': [[1, 3, 24, 24]],
'desc_bprop': [[1, 12, 24, 24]],
}),
]
test_case_other_ops = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册