提交 4a19e6b8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3114 add coo_tensor

Merge pull request !3114 from riemann_penn/coo_tensor
......@@ -17,7 +17,7 @@
"""Resources for ast tree parse."""
import ast
import math
from mindspore import IndexedSlices
from mindspore import IndexedSlices, SparseTensor
from mindspore.ops.composite import multitype_ops
from mindspore.ops import functional as F, composite as C
from . import standard_method as M
......@@ -140,4 +140,5 @@ convert_object_map = {
# user defined
IndexedSlices: F.make_indexed_slices,
SparseTensor: F.make_sparse_tensor,
}
......@@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
// Do Nothing
} else if (type->isa<UndeterminedType>()) {
// Do Nothing
} else if (type->isa<SparseTensorType>()) {
// Do Nothing
} else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE);
......
......@@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
if (a_tuple == nullptr || b_tuple == nullptr) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType();
});
auto stub = GenerateStubFunc(types);
if (stub != nullptr) {
MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
<< ", function: " << stub->ToString();
return stub;
}
MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
<< args_spec_list[1]->ToString();
}
......
......@@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
return py::none();
}
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();
for (size_t i = 0; i < types.size(); ++i) {
auto param = stub->add_parameter();
parameters.push_back(param);
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
undetermined_param = param;
}
}
if (undetermined_param != nullptr) {
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < types.size(); ++i) {
if (types[i]->type_id() == kObjectTypeFunction) {
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
inputs.push_back(stub->NewCNode(call_prim));
} else {
inputs.push_back(parameters[i]);
}
}
auto stub_output = stub->NewCNode(inputs);
stub->set_output(stub_output);
stub->set_stub(true);
return stub;
}
return nullptr;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
auto py_fn = SignMatch(types);
std::ostringstream buffer;
......
......@@ -283,6 +283,11 @@ const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeInd
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
// SparseTensor
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
} // namespace prim
} // namespace mindspore
......@@ -292,7 +292,12 @@ extern const PrimitivePtr kPrimMakeIndexedSlices;
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
extern const PrimitivePtr kPrimIsIndexedSlices;
// SparseTensor
extern const PrimitivePtr kPrimMakeSparseTensor;
extern const PrimitivePtr kPrimSparseTensorGetValues;
extern const PrimitivePtr kPrimSparseTensorGetIndices;
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
......
......@@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 1) {
MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
for (auto elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
......@@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
}
}
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices);
ret->set_values(values);
......@@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
return indexed_slices->dense_shape();
}
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
bool ret = false;
if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
ret = true;
CheckArgsSize(op_name, args_spec_list, 3);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 2) {
MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (values_shp.size() != 1) {
MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
<< " dimension tensor";
}
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
for (auto elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
std::vector<int> dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int {
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
}
MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
return std::make_shared<AbstractScalar>(ret);
}
auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
return ret;
}
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
return sparse_tensor->values();
}
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
return sparse_tensor->indices();
}
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
return sparse_tensor->dense_shape();
}
} // namespace abstract
} // namespace mindspore
......@@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
return IsPrimitiveCNode(user.first, prim);
});
if (cnode == users.end()) {
MS_LOG(EXCEPTION) << "Fail to find cnode.";
MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
}
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
......
......@@ -43,6 +43,7 @@
#include "frontend/optimizer/irpass/transpose_eliminate.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
namespace mindspore {
namespace opt {
......@@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
indexed_slices_eliminate_ = MakeSubstitution(
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
// SparseTensor Eliminate
sparse_tensor_eliminate_ = MakeSubstitution(
std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
}
ResolveIRPassLib::ResolveIRPassLib() {
......
......@@ -107,6 +107,9 @@ class OptimizeIRPassLib {
// IndexedSlices Eliminate
SubstitutionPtr indexed_slices_eliminate_;
// SparseTensor Eliminate
SubstitutionPtr sparse_tensor_eliminate_;
};
// the collection of irpass for resolve action
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
class SparseTensorEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
if (is_match_) {
return tuple_->input(3);
}
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}
void Reset() {
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
......@@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.make_ref_eliminate_,
irpass.get_ref_param_eliminate_,
irpass.indexed_slices_eliminate_,
irpass.sparse_tensor_eliminate_,
});
OptPassGroupMap map({
{"b_1", b_1},
......
......@@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
}},
{kObjectTypeSparseTensorType,
{
{"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
}},
{kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}},
{kObjectTypeEnvType, {}}};
......
......@@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
{prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
// SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
};
return prim_eval_implement_map;
}
......
......@@ -358,7 +358,13 @@ AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, cons
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
} // namespace abstract
} // namespace mindspore
......
......@@ -36,6 +36,7 @@ using mindspore::abstract::AbstractIndexedSlices;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractType;
......@@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
ptrBase->isa<abstract::AbstractRefKey>()) {
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
return;
}
......
......@@ -17,10 +17,10 @@ from . import dtype
from .api import ms_function
from .dtype import *
from .parameter import Parameter, ParameterTuple
from .tensor import MetaTensor, Tensor, IndexedSlices
from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
__all__ = [
"MetaTensor", "Tensor", "IndexedSlices", # tensor
"MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype"
......
......@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)
......@@ -211,3 +211,7 @@ class Tensor(Tensor_):
class IndexedSlices:
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
class SparseTensor:
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
......@@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
<< ", dense_shape: " << dense_shape_->ToString();
return buffer.str();
}
// SparseTensor
TypePtr AbstractSparseTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element());
TypePtr element_type = element()->BuildType();
return std::make_shared<SparseTensorType>(element_type);
}
AbstractBasePtr AbstractSparseTensor::Clone() const {
MS_EXCEPTION_IF_NULL(element());
auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
ShapePtr shp = shape();
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return clone;
}
AbstractBasePtr AbstractSparseTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return broaden;
}
AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape()->Clone();
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return broaden;
}
std::string AbstractSparseTensor::ToString() const {
std::ostringstream buffer;
BaseShapePtr shape_track = GetShapeTrack();
MS_EXCEPTION_IF_NULL(shape_track);
MS_EXCEPTION_IF_NULL(element());
auto value_track = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
<< ", indices: " << indices_->ToString() << ", values" << values_->ToString()
<< ", dense_shape: " << dense_shape_->ToString();
return buffer.str();
}
} // namespace abstract
} // namespace mindspore
......@@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
private:
AbstractTensorPtr indices_;
AbstractTensorPtr values_;
AbstractTuplePtr dense_shape_;
};
// SparseTensor
class AbstractSparseTensor : public AbstractUndetermined {
public:
explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
: AbstractUndetermined(element_type, shape) {}
~AbstractSparseTensor() override = default;
MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
......
......@@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
ABSTRACT_REPORT_NAME_TRAITS(Class)
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
template <typename T>
......
......@@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
return *element_type_ == *other_elem_type;
}
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
Function::Function() : Object(kObjectTypeFunction) {
args_ = std::vector<TypePtr>();
retval_ = nullptr;
......
......@@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
};
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
class Function : public Object {
public:
Function();
......
......@@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
return "kObjectTypeTensorType";
case kObjectTypeIndexedSlicesType:
return "kObjectTypeIndexedSlicesType";
case kObjectTypeSparseTensorType:
return "kObjectTypeSparseTensorType";
case kObjectTypeUndeterminedType:
return "kObjectTypeUndeterminedType";
case kObjectTypeDictionary:
......
......@@ -51,6 +51,7 @@ enum TypeId : int {
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeIndexedSlicesType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,
......
......@@ -207,6 +207,23 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
return std::make_shared<IndexedSlicesType>(element_type);
}
TypePtr SparseTensorStrToType(const std::string &type_name) {
if (type_name == "SparseTensor") {
return std::make_shared<SparseTensorType>();
}
auto start = type_name.find_first_of('[') + 1;
auto end = type_name.find_last_of(']');
if (start >= type_name.size()) {
return nullptr;
}
auto element_str = type_name.substr(start, end - start);
auto element_type = StringToType(element_str);
if (element_type == nullptr) {
return nullptr;
}
return std::make_shared<SparseTensorType>(element_type);
}
TypePtr UndeterminedStrToType(const std::string &type_name) {
if (type_name == "Undetermined") {
return std::make_shared<UndeterminedType>();
......@@ -349,6 +366,8 @@ TypePtr StringToType(const std::string &type_name) {
type = UndeterminedStrToType(type_name);
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
type = IndexedSlicesStrToType(type_name);
} else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
type = SparseTensorStrToType(type_name);
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
type = ListStrToType(type_name);
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
......@@ -428,6 +447,7 @@ const TypePtr kTypeEnv = std::make_shared<EnvType>();
const TypePtr kTypeType = std::make_shared<TypeType>();
const TypePtr kTensorType = std::make_shared<TensorType>();
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
const TypePtr kString = std::make_shared<String>();
const TypePtr kList = std::make_shared<List>();
......
......@@ -139,6 +139,8 @@ REGISTER_PYBIND_DEFINE(
}));
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
.def(py::init());
(void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
.def(py::init());
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
.def(py::init());
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
......
......@@ -17,9 +17,49 @@
*/
#include "ir/meta_func_graph.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "utils/context/ms_context.h"
#include "frontend/operator/ops.h"
// namespace to support intermediate representation definition
namespace mindspore {
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();
for (size_t i = 0; i < types.size(); ++i) {
auto param = stub->add_parameter();
parameters.push_back(param);
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
undetermined_param = param;
}
}
if (undetermined_param != nullptr) {
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < types.size(); ++i) {
if (types[i]->type_id() == kObjectTypeFunction) {
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
inputs.push_back(stub->NewCNode(call_prim));
} else {
inputs.push_back(parameters[i]);
}
}
auto stub_output = stub->NewCNode(inputs);
stub->set_output(stub_output);
stub->set_stub(true);
return stub;
}
return nullptr;
}
FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
......
......@@ -79,6 +79,7 @@ class MetaFuncGraph : public FuncGraphBase {
std::shared_ptr<Derived> shared_from_base() {
return std::static_pointer_cast<Derived>(shared_from_this());
}
FuncGraphPtr GenerateStubFunc(const TypePtrList &types);
std::string name_;
std::vector<Signature> signatures_;
std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;
......
......@@ -40,18 +40,12 @@ class ParamValue {
const std::string &name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }
const std::string &sparse_grad() const { return sparse_grad_; }
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
bool requires_grad() const { return requires_grad_; }
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
bool layerwise_parallel() const { return layerwise_parallel_; }
void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
// Whether the parameter clone from other parameter.
bool cloned() const { return cloned_; }
......@@ -81,10 +75,8 @@ class ParamValue {
private:
tensor::MetaTensorPtr value_;
std::string name_{"Parameter"};
std::string sparse_grad_;
bool requires_grad_{true};
bool layerwise_parallel_{false};
bool has_indexed_slices_grad_{false};
bool be_cloned_{false};
bool cloned_{false};
std::vector<int32_t> be_cloned_index_;
......
......@@ -29,14 +29,10 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
&ParamValue::set_layerwise_parallel)
.def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
&ParamValue::set_has_indexed_slices_grad)
.def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
.def(py::pickle(
[](const ParamValue &p) { // __getstate__
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
p.layerwise_parallel(), p.has_indexed_slices_grad(),
p.sparse_grad());
p.layerwise_parallel());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 6) {
......@@ -47,8 +43,6 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
p->set_name(t[1].cast<std::string>());
p->set_requires_grad(t[2].cast<bool>());
p->set_layerwise_parallel(t[3].cast<bool>());
p->set_has_indexed_slices_grad(t[4].cast<bool>());
p->set_sparse_grad(t[5].cast<std::string>());
return p;
}));
}));
......
......@@ -159,6 +159,10 @@ indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
......
......@@ -616,5 +616,18 @@ TEST_F(TestOptLib, test_indexed_slices) {
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
}
TEST_F(TestOptLib, test_sparse_tensor) {
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices");
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices");
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values");
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values");
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape");
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape");
auto patterns = std::vector<SubstitutionPtr>({irpass.sparse_tensor_eliminate_});
ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
}
} // namespace opt
} // namespace mindspore
......@@ -1163,3 +1163,38 @@ def test_indexed_slices(tag):
return z
return fns[tag]
def test_sparse_tensor(tag):
""" test_add_zero """
fns = FnDict()
make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
@fns
def before_get_indices(x, y, z):
return sparse_tensor_get_indices(make_sparse_tensor(x, y, z))
@fns
def after_get_indices(x, y, z):
return x
@fns
def before_get_values(x, y, z):
return sparse_tensor_get_values(make_sparse_tensor(x, y, z))
@fns
def after_get_values(x, y, z):
return y
@fns
def before_get_dense_shape(x, y, z):
return sparse_tensor_get_dense_shape(make_sparse_tensor(x, y, z))
@fns
def after_get_dense_shape(x, y, z):
return z
return fns[tag]
......@@ -35,6 +35,9 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.train import Model
from ....dataset_mock import MindData
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
......@@ -47,6 +50,40 @@ size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
def get_axis(x):
shape = shape_op(x)
length = F.tuple_len(shape)
perm = F.make_range(0, length)
return perm
class MSELoss(nn.Cell):
def __init__(self):
super(MSELoss, self).__init__()
self.reduce_sum = P.ReduceSum()
self.square = P.Square()
self.reduce_mean = P.ReduceMean()
def construct(self, data, label):
diff = data - label
return self.reduce_mean(self.square(diff), get_axis(diff))
class MindDataSet(MindData):
def __init__(self, dataset_types, dataset_shapes):
super(MindDataSet, self).__init__(size=2, batch_size=32,
np_types=dataset_types,
output_shapes=dataset_shapes,
input_indexs=(0, 1))
def __next__(self):
if self._size < self._iter_num:
raise StopIteration
self._iter_num += 1
lst = []
for shape_, type_ in zip(self._output_shapes, self._np_types):
lst.append(Tensor(np.ones(shape_).astype(type_)))
return tuple(lst)
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
out_rank = len(out_shape)
......@@ -189,8 +226,8 @@ def test_indexed_slices_make_indexed_slices():
def construct(self, indices, values):
ret = (IndexedSlices(indices, values, self.dense_shape),)
return ret[0]
indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
indices = Tensor([1, 2])
values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
MakeIndexedSlices()(indices, values)
......@@ -202,8 +239,8 @@ def test_indexed_slices_attr():
def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape()
indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
indices = Tensor([0])
values = Tensor([[1, 2]], dtype=ms.float32)
IndexedSlicesGetAttr()(indices, values)
......@@ -279,3 +316,29 @@ def test_indexed_slices_env_get():
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network(inputs, label)
def test_indexed_slices_model_train():
class Net(nn.Cell):
def __init__(self, in_features, out_features):
super(Net, self).__init__()
self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
self.add = P.TensorAdd()
self.cast = P.Cast()
self.flag = True
def construct(self, inputs, label):
x = self.add(inputs, self.weight)
if self.flag:
x = self.cast(x, mstype.float32)
return x
dataset_types = (np.float32, np.float32)
dataset_shapes = ((16, 16), (16, 16))
dataset = MindDataSet(dataset_types, dataset_shapes)
net = Net(16, 16)
net.set_train()
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_sparse_tensor.py
@Author:
@Date : 2020-07-16
@Desc : test mindspore sparse_tensor's operation
"""
import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore import Tensor, SparseTensor, context
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
def test_sparse_tensor_make_sparse_tensor():
class MakeSparseTensor(nn.Cell):
def __init__(self):
super(MakeSparseTensor, self).__init__()
self.dense_shape = (3, 4)
def construct(self, indices, values):
ret = (SparseTensor(indices, values, self.dense_shape),)
return ret[0]
indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
MakeSparseTensor()(indices, values)
def test_sparse_tensor_attr():
grad_op = C.GradOperation('get_all', get_all=True)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, input1, input2):
gout = grad_op(self.network)(input1, input2)
return gout
class SparseTensorGetAttr(nn.Cell):
def __init__(self):
super(SparseTensorGetAttr, self).__init__()
self.dense_shape = (3, 4)
def construct(self, indices, values):
x = SparseTensor(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape()
indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
SparseTensorGetAttr()(indices, values)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册