Commit 8ac5672a authored by: F fary86

Add support for dynamic shape

Parent 779c668a
......@@ -113,6 +113,8 @@ inline const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransDat
inline const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
inline const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared<Primitive>("ArgMaxWithValue");
inline const PrimitivePtr kPrimUnique = std::make_shared<Primitive>("Unique");
inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueGrad");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
......
......@@ -148,5 +148,47 @@ AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &pri
ret->set_shape(std::make_shared<Shape>(shape));
return ret;
}
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// inputs: a 1-d Tensor
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto shape = input->shape();
if (shape->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
}
std::vector<int> ids_shape = {Shape::SHP_ANY};
std::vector<int> min_shape = {1};
std::vector<int> max_shape = shape->shape();
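// The number of unique elements is data-dependent: it is encoded as SHP_ANY, bounded below by 1
// (at least one unique value) and above by the input length.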
auto ids =
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
auto ids_idx = std::make_shared<AbstractTensor>(std::make_shared<Int>(32), shape->shape());
// outputs: ids, ids_idx
AbstractBasePtrList elements = {ids, ids_idx};
return std::make_shared<AbstractTuple>(elements);
}
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// inputs: dout (a tuple of two 1-d Tensors) and the forward outputs of Unique
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTuplePtr dout = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
CheckArgsSize(op_name + " dout", dout->elements(), 2);
auto ids = CheckArg<AbstractTensor>(op_name, dout->elements(), 0);
auto ids_idx = CheckArg<AbstractTensor>(op_name, dout->elements(), 1);
if (ids->shape()->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Dims of dout[0] of " << op_name << "' input must be 1.";
}
if (ids_idx->shape()->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Dims of dout[1] of " << op_name << "' input must be 1.";
}
// outputs: dx
return std::make_shared<AbstractTensor>(ids->element(), ids_idx->shape());
}
} // namespace abstract
} // namespace mindspore
......@@ -23,6 +23,7 @@
#include <mutex>
#include <string>
#include <utility>
#include <unordered_set>
#include "frontend/operator/cc_implementations.h"
#include "frontend/operator/ops.h"
......@@ -62,6 +63,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
{prim::kPrimPack, {InferImplPack, true}},
{prim::kPrimUnique, {InferImplUnique, true}},
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},
......@@ -389,6 +392,14 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if (abs_base->isa<AbstractTensor>()) {
auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
dic["shape"] = arg_tensor->shape()->shape();
if (MsContext::GetInstance()->execution_mode() == kGraphMode) {
const auto &min_shape = arg_tensor->shape()->min_shape();
const auto &max_shape = arg_tensor->shape()->max_shape();
if (!min_shape.empty() && !max_shape.empty()) {
dic["min_shape"] = min_shape;
dic["max_shape"] = max_shape;
}
}
dic["dtype"] = arg_tensor->BuildType();
dic["value"] = BuildValue(arg_tensor->BuildValue());
} else if (abs_base->isa<AbstractIndexedSlices>()) {
......@@ -503,7 +514,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
if (output["value"].is_none()) {
auto out_shape = output["shape"];
auto out_dtype = output["dtype"];
return PyListDtype2AbstractTensor(out_shape, out_dtype);
py::object min_shape = output.contains("min_shape") ? (py::object)output["min_shape"] : (py::object)py::none();
py::object max_shape = output.contains("max_shape") ? (py::object)output["max_shape"] : (py::object)py::none();
return PyListDtype2AbstractTensor(out_shape, out_dtype, min_shape, max_shape);
}
// Convert pyobject to Value, then to AbstractValue
ValuePtr converted_ret = nullptr;
......
......@@ -244,6 +244,10 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
......
......@@ -371,7 +371,8 @@ py::object VectorRefToPyData(const VectorRef &value_list) {
return ret;
}
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
const py::object &min_shape, const py::object &max_shape) {
if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
auto ret_vec = shape_obj.cast<std::vector<int>>();
auto ret_dtype = type_obj.cast<TypePtr>();
......@@ -382,12 +383,23 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
return abs_scalar;
}
AbstractBasePtr tensor = nullptr;
std::vector<int> min_shape_vec;
std::vector<int> max_shape_vec;
if (!min_shape.is_none()) {
min_shape_vec = min_shape.cast<std::vector<int>>();
}
if (!max_shape.is_none()) {
max_shape_vec = max_shape.cast<std::vector<int>>();
}
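// Empty min/max vectors mean the shape is fully static; the bounds are stored only when provided.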
auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
if (ret_dtype->isa<TensorType>()) {
auto tensor_type = type_obj.cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
tensor = std::make_shared<abstract::AbstractTensor>(tensor_type->element(), ret_vec);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
} else {
tensor = std::make_shared<abstract::AbstractTensor>(ret_dtype, ret_vec);
auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
}
return tensor;
} else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
......
......@@ -47,7 +47,9 @@ bool BaseRefToInt(const ValuePtr &v, int *value);
bool ValueToBool(const ValuePtr &in, bool *out);
py::object ValuePtrToPyData(const ValuePtr &value);
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj);
AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
const py::object &min_shape = py::none(),
const py::object &max_shape = py::none());
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
const std::shared_ptr<py::object> &ret_val);
......
......@@ -67,6 +67,9 @@ std::string Shape::DumpText() const {
buffer << "[";
for (size_t i = 0; i < shape_.size(); i++) {
buffer << (i > 0 ? ", " : "") << shape_[i];
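// A dynamic dimension is dumped together with its bounds, e.g. "-1_1^8" for SHP_ANY with min 1 and max 8.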
if (shape_[i] == SHP_ANY && min_shape_.size() == shape_.size() && max_shape_.size() == shape_.size()) {
buffer << "_" << min_shape_[i] << "^" << max_shape_[i];
}
}
buffer << "]";
return buffer.str();
......
......@@ -74,16 +74,22 @@ class Shape : public BaseShape {
(void)std::transform(list.begin(), list.end(), std::back_inserter(shape_),
[](const int64_t &value) { return static_cast<int>(value); });
}
Shape(const std::vector<int> &list, const std::vector<int> &min_shape, const std::vector<int> &max_shape)
: shape_(list), min_shape_(min_shape), max_shape_(max_shape) {}
~Shape() override = default;
MS_DECLARE_PARENT(Shape, BaseShape)
std::string ToString() const override;
std::string DumpText() const override;
bool operator==(const BaseShape &other) const override;
BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_); }
BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); }
void Broaden() override;
std::vector<int> &shape() { return shape_; }
std::vector<int> &min_shape() { return min_shape_; }
std::vector<int> &max_shape() { return max_shape_; }
std::vector<int> shape_; // use SHP_ANY to implement the any shape in python
std::vector<int> shape_; // use SHP_ANY to implement the any shape in python
std::vector<int> min_shape_;  // record the minimum length of each dynamic dimension
std::vector<int> max_shape_;  // record the maximum length of each dynamic dimension
};
using ShapePtr = std::shared_ptr<Shape>;
using ShapePtrList = std::vector<ShapePtr>;
......
......@@ -55,15 +55,66 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) {
return shape1;
}
std::vector<int> dims;
bool has_dynamic_shape = false;
dims.resize(shape1->shape().size());
for (std::size_t i = 0; i < shape1->shape().size(); i++) {
if (shape1->shape()[i] == shape2->shape()[i]) {
dims[i] = shape1->shape()[i];
if (shape1->shape()[i] == Shape::SHP_ANY) {
has_dynamic_shape = true;
}
} else {
dims[i] = Shape::SHP_ANY;
has_dynamic_shape = true;
}
}
return std::make_shared<Shape>(dims);
if (!has_dynamic_shape) {
return std::make_shared<Shape>(dims);
}
// calculate dynamic shape
std::vector<int> min_dims(dims.size());
std::vector<int> max_dims(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] != Shape::SHP_ANY) {
min_dims[i] = max_dims[i] = dims[i];
continue;
}
if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
min_dims[i] = std::min(shape1->shape()[i], shape2->shape()[i]);
max_dims[i] = std::max(shape1->shape()[i], shape2->shape()[i]);
continue;
}
if (shape1->shape()[i] == Shape::SHP_ANY && shape2->shape()[i] != Shape::SHP_ANY) {
if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
min_dims[i] = std::min(shape1->min_shape()[i], shape2->shape()[i]);
max_dims[i] = std::max(shape1->max_shape()[i], shape2->shape()[i]);
continue;
}
if (shape1->shape()[i] != Shape::SHP_ANY && shape2->shape()[i] == Shape::SHP_ANY) {
if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
min_dims[i] = std::min(shape1->shape()[i], shape2->min_shape()[i]);
max_dims[i] = std::max(shape1->shape()[i], shape2->max_shape()[i]);
continue;
}
// both shapes are dynamic in this dimension
if (shape1->min_shape().empty() || shape1->max_shape().empty()) {
MS_EXCEPTION(ValueError) << "Shape " << shape1->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
if (shape2->min_shape().empty() || shape2->max_shape().empty()) {
MS_EXCEPTION(ValueError) << "Shape " << shape2->ToString()
<< " has dynamic shape, but does not have min/max shape info.";
}
min_dims[i] = std::min(shape1->min_shape()[i], shape2->min_shape()[i]);
max_dims[i] = std::max(shape1->max_shape()[i], shape2->max_shape()[i]);
}
return std::make_shared<Shape>(dims, min_dims, max_dims);
}
AbstractBasePtr AbstractJoin(const AbstractBasePtrList &args_spec_list) {
......
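A minimal Python sketch (not part of the commit) of the per-dimension rules in ShapeJoin above, assuming -1 is the SHP_ANY marker and using plain lists instead of Shape objects:

def join_dim(d1, d2, bounds1=None, bounds2=None):
    # Join one dimension of two shapes; bounds are (min, max) pairs for dynamic (-1) dims.
    if d1 == d2 and d1 != -1:
        return d1, (d1, d1)
    lo1, hi1 = bounds1 if d1 == -1 else (d1, d1)
    lo2, hi2 = bounds2 if d2 == -1 else (d2, d2)
    return -1, (min(lo1, lo2), max(hi1, hi2))

# A static 3 joined with a dynamic dim bounded by [1, 8] stays dynamic with bounds [1, 8];
# two unequal static dims 2 and 5 become dynamic with bounds [2, 5].
print(join_dim(3, -1, bounds2=(1, 8)))   # (-1, (1, 8))
print(join_dim(2, 5))                    # (-1, (2, 5))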
......@@ -807,3 +807,23 @@ def get_bprop_trans_shape(self):
dx = op(dout, shape_op(x))
return (dx, zeros_like(shape))
return bprop
@bprop_getters.register(P.Unique)
def get_bprop_unique(self):
"""Generate bprop for Unique"""
op = G.UniqueGrad()
def bprop(x, out, dout):
dx = op(dout, out)
return (dx,)
return bprop
@bprop_getters.register(P.UnsortedSegmentSum)
def get_bprop_unsorted_segment_sum(self):
"""Generate bprop for UnsortedSegmentSum"""
op = G.UnsortedSegmentSumGrad()
def bprop(x, segment_ids, num_segments, out, dout):
dx = op(dout, segment_ids)
return (dx, zeros_like(segment_ids), zeros_like(num_segments))
return bprop
......@@ -82,5 +82,8 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
offset.append(all_shp)
all_shp += v[axis]
if all_shp == -1 or v[axis] == -1:
all_shp = -1
else:
all_shp += v[axis]
return offset, all_shp, axis
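Illustrative only (the helper name below is hypothetical, not from this file): the accumulation rule above means that once any element's concat dimension is dynamic (-1), the total length becomes -1 as well.

def accumulate_concat_dim(lengths):
    # Mirror the new all_shp update in get_concat_offset.
    total = 0
    for v in lengths:
        total = -1 if (total == -1 or v == -1) else total + v
    return total

print(accumulate_concat_dim([2, 3, 4]))    # 9
print(accumulate_concat_dim([2, -1, 4]))   # -1, the concat axis is dynamic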
......@@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup)
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
......
......@@ -491,6 +491,31 @@ class FusedBatchNormGrad(Primitive):
raise NotImplementedError
class UniqueGrad(Primitive):
"""Gradients of Unique operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['dy', 'y'], outputs=['dx'])
def __call__(self, dy, y):
raise NotImplementedError
class UnsortedSegmentSumGrad(PrimitiveWithInfer):
"""Gradients of UnsortedSegmentSum operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['grads', 'ids'], outputs=['y'])
def infer_shape(self, grads, ids):
return ids + grads[len(ids):]
def infer_dtype(self, grads, ids):
return grads
class BNTrainingReduceGrad(PrimitiveWithInfer):
"""Gradients of FusedBatchNorm operation."""
......
......@@ -27,7 +27,7 @@ import numpy as np
from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_rw as sig_rw
......@@ -556,6 +556,28 @@ class Transpose(PrimitiveWithInfer):
return out
class Unique(Primitive):
"""
Returns the unique elements of the input tensor, as well as a tensor containing the index of each element of the
input in the returned unique tensor.
Inputs:
- **x** (Tensor) - The input tensor. Only 1-D tensors are supported.
Outputs:
Tuple of tensor objects `(y, idx)`. `y` is a tensor with the same type as `x` that holds the unique elements,
and `idx` is a tensor containing the index of each element of the input in `y`.
Examples:
>>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.float32)
>>> out = P.Unique()(x)
(Tensor([1, 2, 5], mindspore.float32), Tensor([0, 1, 2, 1], mindspore.int32))
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
class GatherV2(PrimitiveWithInfer):
"""
Returns a slice of input tensor based on the specified indices and axis.
......
......@@ -20,6 +20,7 @@ import copy
from mindspore.common.api import _wrap_func
from mindspore.common import Parameter
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore import context
from .._c_expression import Primitive_, real_run_op, prim_type
from .._c_expression import signature_rw as sig_rw
from .._c_expression import signature_kind as sig_kind
......@@ -138,6 +139,8 @@ class Primitive(Primitive_):
return self
def __getattr__(self, item):
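# Primitives that do not define infer_dynamic_shape resolve it to None,
# so getattr(self, 'infer_dynamic_shape', None) in __infer__ can probe for it safely.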
if item == 'infer_dynamic_shape':
return None
if item in super().get_attr_dict():
return super().get_attr_dict()[item]
if item in self.attrs:
......@@ -282,13 +285,49 @@ class PrimitiveWithInfer(Primitive):
def __infer__(self, *args):
"""Infer shape, type, and value at the same time by using dictionary as arguments."""
is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
if is_graph_mode and fn_infer_dynamic_shape is not None:
out = fn_infer_dynamic_shape(*args)
tracks = ['dtype', 'value']
for track in tracks:
fn = getattr(self, 'infer_' + track)
# fn may return None
out[track] = fn(*(x[track] for x in args))
return out
tracks = ['dtype', 'shape', 'value']
out = {}
for track in tracks:
fn = getattr(self, 'infer_' + track)
# fn may return None
out[track] = fn(*(x[track] for x in args))
return out
# In non-graph mode, it is not necessary to infer min/max shape
if not is_graph_mode:
return out
def get_specified_shape(elems, attr):
has_specified_shape = False
ret_vals = []
for elem in elems:
if attr in elem:
has_specified_shape = True
ret_vals.append(elem[attr])
else:
ret_vals.append(elem['shape'])
return has_specified_shape, tuple(ret_vals)
has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
if not (has_min_shape or has_max_shape):
return out
if has_min_shape and has_max_shape:
fn_infer_shape = getattr(self, 'infer_shape')
out['min_shape'] = fn_infer_shape(*min_shapes)
out['max_shape'] = fn_infer_shape(*max_shapes)
return out
raise ValueError(f'Input args have invalid dynamic shape, args info: {args}')
def prim_attr_register(fn):
......
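A minimal sketch (not MindSpore code) of the rule implemented above: when every argument carries min_shape and max_shape, infer_shape is simply re-run on the per-argument minimum shapes and again on the maximum shapes. The infer_shape below is a made-up identity rule for a hypothetical element-wise op.

def infer_shape(x_shape, y_shape):
    # Hypothetical rule: output shape equals the first input's shape.
    return x_shape

args = [
    {'shape': [-1, 4], 'min_shape': [1, 4], 'max_shape': [8, 4]},
    {'shape': [-1, 4], 'min_shape': [1, 4], 'max_shape': [8, 4]},
]
out = {'shape': infer_shape(*(a['shape'] for a in args))}
out['min_shape'] = infer_shape(*(a['min_shape'] for a in args))
out['max_shape'] = infer_shape(*(a['max_shape'] for a in args))
print(out)   # {'shape': [-1, 4], 'min_shape': [1, 4], 'max_shape': [8, 4]}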