Commit 44e74ad5 authored by: panyifeng

Apply indexed_slices

Parent: e03bd975
@@ -45,7 +45,8 @@ FuncGraph::FuncGraph()
       hyper_param_count_(0),
       is_generated_(false),
       return_(nullptr),
-      manager_(std::weak_ptr<FuncGraphManager>()) {
+      manager_(std::weak_ptr<FuncGraphManager>()),
+      stub_(false) {
   debug_info_ = std::make_shared<GraphDebugInfo>();
 }
......
@@ -344,6 +344,9 @@ class FuncGraph : public FuncGraphBase {
   void SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs);
   bool HasEffect(const CNodePtr &cnode);
+  bool stub() const { return stub_; }
+  void set_stub(bool stub) { stub_ = stub; }
  private:
   // graph is manipulated by manager and others
   friend FuncGraphManager;
@@ -402,6 +405,7 @@ class FuncGraph : public FuncGraphBase {
   // CNode order which relates to origin code order
   std::list<CNodePtr> order_;
+  bool stub_;
 };

 inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
......
@@ -218,6 +218,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
   (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
   (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
   (*target_func_graph)->set_is_generate(func_graph->is_generated());
+  (*target_func_graph)->set_stub(func_graph->stub());
   TraceManager::EndTrace();
 }
@@ -629,6 +630,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
   new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
   new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
   new_func_graph->set_is_generate(func_graph->is_generated());
+  new_func_graph->set_stub(func_graph->stub());
   for (auto &item : func_graph->parameter_default_value()) {
     new_func_graph->set_param_default_value(item.first, cloner[item.second]);
   }
......
@@ -30,6 +30,7 @@
 #include "pipeline/static_analysis/param_validator.h"
 #include "operator/cc_implementations.h"
 #include "optimizer/opt.h"
+#include "utils/context/ms_context.h"
 #include "utils/symbolic.h"
 #include "pybind_api/api_register.h"
 #include "./common.h"
@@ -115,36 +116,43 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
     }
     return item.second;
   }
-  // Try best match
-  py::function py_fn_subclass;
-  size_t subclass_match_cnt = 0;
-  for (auto &item : fn_cache_py_) {
-    TypePtrList sign = item.first;
-    if (sign.size() != types.size()) {
-      continue;
-    }
-    auto match = true;
-    for (size_t i = 0; i < sign.size(); ++i) {
-      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
-          !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
-        match = false;
-        break;
-      }
-    }
-    if (!match) {
-      continue;
-    }
-    py_fn_subclass = item.second;
-    subclass_match_cnt++;
-  }
-  if (subclass_match_cnt > 1) {
-    MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
-  }
-  if (subclass_match_cnt == 1) {
-    MS_LOG(DEBUG) << "Found one subclass match";
-    return py_fn_subclass;
-  }
   return py::none();
 }

+FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse = context->enable_sparse();
+  if (!enable_sparse) {
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> parameters;
+  ParameterPtr undetermined_param = nullptr;
+  auto stub = std::make_shared<FuncGraph>();
+  for (size_t i = 0; i < types.size(); ++i) {
+    auto param = stub->add_parameter();
+    parameters.push_back(param);
+    if (types[i]->type_id() == kObjectTypeUndeterminedType) {
+      undetermined_param = param;
+    }
+  }
+  if (undetermined_param != nullptr) {
+    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
+    for (size_t i = 0; i < types.size(); ++i) {
+      if (types[i]->type_id() == kObjectTypeFunction) {
+        std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
+        inputs.push_back(stub->NewCNode(call_prim));
+      } else {
+        inputs.push_back(parameters[i]);
+      }
+    }
+    auto stub_output = stub->NewCNode(inputs);
+    stub->set_output(stub_output);
+    stub->set_stub(true);
+    return stub;
+  }
+  return nullptr;
+}

 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
@@ -159,6 +167,11 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
     MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
     return func_graph;
   }
+  auto stub = GenerateStubFunc(types);
+  if (stub != nullptr) {
+    MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString();
+    return stub;
+  }
   std::ostringstream oss;
   oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_
       << "`, corresponding location info:\n";
......
@@ -23,8 +23,8 @@
 #include "pipeline/static_analysis/param_validator.h"
 #include "pipeline/static_analysis/prim.h"
 #include "pipeline/static_analysis/utils.h"
-#include "utils/symbolic.h"
 #include "utils/context/ms_context.h"
+#include "utils/symbolic.h"

 namespace mindspore {
 namespace abstract {
@@ -56,79 +56,6 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
   return AbstractFunction::MakeAbstractFunction(jv);
 }

-class UndeterminedShapeType {
- public:
-  explicit UndeterminedShapeType(const std::string &env_str) {
-    // param_name indices_shape indices_type values_shape values_type dense_shape
-    // export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1
-    // 2:Float32:3 1 2"
-    std::vector<string> fields;
-    string tmp;
-    std::stringstream input(env_str);
-    while (std::getline(input, tmp, ':')) {
-      fields.push_back(tmp);
-    }
-    if (fields.size() != fields_num) {
-      MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
-    }
-    param_name_ = fields[0];
-    indices_shape_ = GetShape(fields[1]);
-    indices_type_ = StringToType(fields[2]);
-    values_shape_ = GetShape(fields[3]);
-    values_type_ = StringToType(fields[4]);
-    auto dense_shape_vec = GetShape(fields[5]);
-    AbstractBasePtrList dense_shape_list;
-    (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
-                         [](const auto &elem) { return FromValue(elem, false); });
-    dense_shape_ = dense_shape_list;
-  }
-  ~UndeterminedShapeType() = default;
-  const std::string &param_name() { return param_name_; }
-  const std::vector<int> &indices_shape() { return indices_shape_; }
-  const TypePtr &indices_type() { return indices_type_; }
-  const std::vector<int> &values_shape() { return values_shape_; }
-  const TypePtr &values_type() { return values_type_; }
-  const AbstractBasePtrList &dense_shape() { return dense_shape_; }
-
- private:
-  std::string param_name_;
-  std::vector<int> indices_shape_;
-  TypePtr indices_type_;
-  std::vector<int> values_shape_;
-  TypePtr values_type_;
-  AbstractBasePtrList dense_shape_;
-  static const size_t fields_num;
-
-  std::vector<int> GetShape(const std::string &shape_str);
-};
-
-std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
-  std::vector<int> ret;
-  std::istringstream iss(shape_str);
-  int elem;
-  while (iss.good()) {
-    iss >> elem;
-    ret.emplace_back(elem);
-  }
-  return ret;
-}
-const size_t UndeterminedShapeType::fields_num = 6;
-std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
-
-void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
-  std::string tmp;
-  std::stringstream input(sparse_shape_types);
-  g_undetermined_configs.clear();
-  while (std::getline(input, tmp, ';')) {
-    auto config = UndeterminedShapeType(tmp);
-    g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
-    MS_LOG(DEBUG) << "Undetermined config from env: " << tmp;
-  }
-}
-
 AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
   MS_EXCEPTION_IF_NULL(primitive);
@@ -142,45 +69,14 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
     MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
   }

-  if (!key->sparse_grad().empty()) {
-    // Will be fixed once undetermined type ready
-    if (g_undetermined_configs.empty()) {
-      auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
-      MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types;
-      if (sparse_shape_types.empty()) {
-        sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
-      }
-      InitUndeterminedFromEnv(sparse_shape_types);
-    }
-    auto shape_types = g_undetermined_configs.find(key->sparse_grad());
-    if (shape_types == g_undetermined_configs.end()) {
-      MS_LOG(EXCEPTION) << "Param " << key->ToString()
-                        << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES";
-    }
-    MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
-    AbstractBasePtrList sparse_list;
-    // indices
-    auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type());
-    auto indices =
-      std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types->second.indices_shape()));
-    sparse_list.emplace_back(indices);
-    // values
-    auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type());
-    auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape()));
-    sparse_list.emplace_back(dout);
-    // dense_shape
-    sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape()));
-    return std::make_shared<AbstractTuple>(sparse_list);
-  }
-
   auto context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
+  bool enable_sparse = context->enable_sparse();
+  if (enable_sparse && dflt->isa<AbstractTensor>()) {
     auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
     return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
   }
   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
     return dflt;
   }
@@ -242,10 +138,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
   if (type->type_id() != kObjectTypeRefKey) {
     MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
   }
-  auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
-  ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
-  ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
-  return ret;
+  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
 }

 AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
......
@@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
     }

     auto fg = GetValueNode<FuncGraphPtr>(node);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
       return nullptr;
     }
@@ -110,7 +110,7 @@ class InlinerBase : public AnfVisitor {
     // G
     auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
       return nullptr;
     }
     // Do not inline GraphKernel to Cell.
......
@@ -1367,7 +1367,6 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
   std::string env = common::GetEnv("SLICE_ENV");
   if (!env.empty()) {
     MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
-    abstract::InitUndeterminedFromEnv(env);
   }
 }
......
@@ -232,8 +232,6 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
         ValuePtr value = param_value->value();
         constexpr bool broaden = true;
         AbstractBasePtr ptr = abstract::FromValue(value, broaden);
-        ptr->set_sparse_grad(param_value->sparse_grad());
-        ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad());
         parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
         args_spec.push_back(ptr);
......
@@ -155,8 +155,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
          "Set the GraphKernel switch to on or off.")
     .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
-    .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
-    .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
+    .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
+    .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
......
@@ -321,21 +321,19 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
   return true;
 }

-std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
-                                   {"opt_a", OptPassAGroup},
+std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
+                                   {"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_b", OptPassBGroup},
                                    {"cconv", CconvPass},
                                    {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                    {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                    {"add_control_depend", AddControlDependPass}};

-std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
-                                   {"opt_a", OptPassAGroup},
-                                   {"opt_b", OptPassBGroup},
-                                   {"add_control_depend", AddControlDependPass},
-                                   {"opt_control", ControlGroup},
-                                   {"opt_prepare", PrepareGroup},
-                                   {"cconv", CconvPass}};
+std::vector<PassItem> kGePasses = {
+  {"opt_a", OptPassAGroup},      {"simplify_data_structures", SimplifyDataStructuresPass},
+  {"opt_b", OptPassBGroup},      {"add_control_depend", AddControlDependPass},
+  {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
+  {"cconv", CconvPass}};

 std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
 }  // namespace pipeline
......
@@ -146,37 +146,35 @@ MethodMap &GetMethodMap() {
      }},
     {kObjectTypeTensorType,
      {
       {"__add__", std::string("add")},              // C.add
       {"__sub__", std::string("sub")},              // C.sub
       {"__mul__", std::string("mul")},              // C.mul
       {"__truediv__", std::string("truediv")},      // C.truediv
       {"__floordiv__", std::string("floordiv")},    // C.floordiv
       {"__mod__", std::string("mod")},              // C.mod
       {"__pow__", std::string("pow_")},             // C.pow
       {"__floor__", std::string("array_floor")},    // C.array_floor
       {"__trunc__", std::string("array_trunc")},    // C.array_trunc
       {"__pos__", std::string("array_uadd")},       // C.array_uadd
       {"__neg__", std::string("array_usub")},       // C.array_usub
       {"__eq__", std::string("eq")},                // C.eq
       {"__ne__", std::string("ne")},                // C.ne
       {"__lt__", std::string("lt")},                // C.lt
       {"__gt__", std::string("gt")},                // C.gt
       {"__le__", std::string("le")},                // C.le
       {"__ge__", std::string("ge")},                // C.ge
       {"__matmul__", prim::kPrimDot},               // P.dot,
       {"__len__", prim::kPrimArrayLen},             // P.array_len,
       {"__getitem__", prim::kPrimArrayGetItem},     // P.array_getitem,
       {"__setitem__", prim::kPrimArraySetItem},     // P.array_setitem,
       {"__ms_iter__", std::string("array_iter")},   // C.array_iter
       {"__ms_to_array__", prim::kPrimIdentity},     // P.identity,
       {"item", prim::kPrimArrayToScalar},           // P.array_to_scalar,
       {"transpose", std::string("transpose")},      // P.transpose
       {"__bool__", std::string("tensor_bool")},     // C.tensor_bool
-      {"is_indexed_slices", prim::kPrimIsIndexedSlices},  // F.is_indexed_slices
      }},
     {kObjectTypeIndexedSlicesType,
      {
+      {"is_indexed_slices", prim::kPrimIsIndexedSlices},        // F.is_indexed_slices
       {"values", prim::kPrimIndexedSlicesGetValues},            // F.indexed_slices_get_values
       {"indices", prim::kPrimIndexedSlicesGetIndices},          // F.indexed_slices_get_indices
       {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape},   // F.indexed_slices_get_dense_shape
......
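The method map above attaches `values`, `indices`, `dense_shape`, and the relocated `is_indexed_slices` to the IndexedSlices type. A minimal sketch of how those accessors might be reached from Python in graph mode follows; the `make_indexed_slices` primitive and the import paths are taken from later hunks in this commit, while the exact wiring (shapes, dense_shape argument) is an assumption, not something this diff confirms:

# Hypothetical sketch only: exercises the IndexedSlices accessors named in the
# method map above. Import paths and the MakeIndexedSlices signature are assumptions.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import functional as F  # assumed to expose make_indexed_slices

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class MakeAndUnpack(nn.Cell):
    def construct(self, indices, values):
        # Build an IndexedSlices with an assumed (3, 2) dense shape, then read it
        # back through the accessors registered for kObjectTypeIndexedSlicesType.
        slices = F.make_indexed_slices(indices, values, (3, 2))
        return slices.values(), slices.indices(), slices.dense_shape()

net = MakeAndUnpack()
out = net(Tensor(np.array([0, 2], np.int32)), Tensor(np.ones((2, 2), np.float32)))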
@@ -55,7 +55,6 @@ ValuePtr AbstractBase::BuildValue() const {
 AbstractBasePtr AbstractBase::Broaden() const {
   AbstractBasePtr clone = Clone();
   clone->set_value(kAnyValue);
-  clone->set_sparse_grad(sparse_grad_);
   return clone;
 }
@@ -68,8 +67,7 @@ std::string AbstractBase::ToString() const {
   MS_EXCEPTION_IF_NULL(type_);
   MS_EXCEPTION_IF_NULL(shape_);
   buffer << type_name() << "("
-         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
-         << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
+         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
   return buffer.str();
 }
@@ -78,25 +76,16 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden(); }
 AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   MS_EXCEPTION_IF_NULL(other);
   if (*this == *other) {
-    auto ret = shared_from_base<AbstractBase>();
-    ret->set_sparse_grad(sparse_grad());
-    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-    return ret;
+    return shared_from_base<AbstractBase>();
   }
   auto value_self = GetValueTrack();
   MS_EXCEPTION_IF_NULL(value_self);
   ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
   TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
   if (res_value == value_self) {
-    auto ret = shared_from_base<AbstractBase>();
-    ret->set_sparse_grad(sparse_grad());
-    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-    return ret;
+    return shared_from_base<AbstractBase>();
   }
-  auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
-  ret->set_sparse_grad(sparse_grad());
-  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-  return ret;
+  return std::make_shared<AbstractScalar>(res_value, res_type);
 }

 AbstractBasePtr AbstractType::Clone() const {
@@ -452,16 +441,11 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
   }
   if (*this == *other) {
-    if (sparse_grad() == other->sparse_grad()) {
-      return shared_from_base<AbstractBase>();
-    }
+    return shared_from_base<AbstractBase>();
   }
   auto element = element_->Join(other_tensor->element_);
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
-  auto ret = std::make_shared<AbstractTensor>(element, shape);
-  ret->set_sparse_grad(sparse_grad());
-  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
-  return ret;
+  return std::make_shared<AbstractTensor>(element, shape);
 }

 bool AbstractTensor::operator==(const AbstractTensor &other) const {
@@ -501,8 +485,6 @@ AbstractBasePtr AbstractTensor::Clone() const {
   ShapePtr shp = shape();
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
-  clone->set_sparse_grad(sparse_grad());
-  clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return clone;
 }
@@ -512,8 +494,6 @@ AbstractBasePtr AbstractTensor::Broaden() const {
   auto shp = shape();
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
-  broaden->set_sparse_grad(sparse_grad());
-  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
@@ -524,8 +504,6 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
   shp->Broaden();
   broaden->set_shape(shp);
   broaden->set_value(kAnyValue);
-  broaden->set_sparse_grad(sparse_grad());
-  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
@@ -538,8 +516,7 @@ std::string AbstractTensor::ToString() const {
   MS_EXCEPTION_IF_NULL(value_track);
   buffer << type_name() << "("
          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
-         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
-         << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
   return buffer.str();
 }
......
@@ -44,7 +44,7 @@ class AbstractBase : public Base {
  public:
   explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
                         const BaseShapePtr &shape = kNoShape)
-      : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
+      : value_(value), type_(type), shape_(shape) {}
   ~AbstractBase() override = default;
   MS_DECLARE_PARENT(AbstractBase, Base)
@@ -53,17 +53,11 @@ class AbstractBase : public Base {
   virtual bool operator==(const AbstractBase &other) const;
   void set_value(const ValuePtr &value) { value_ = value; }
-  void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
-  void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
-    has_indexed_slices_grad_ = has_indexed_slices_grad;
-  }
   void set_type(const TypePtr &type) { type_ = type; }
   void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
   void set_value_desc(const std::string &desc) { value_desc_ = desc; }
   const std::string &value_desc() const { return value_desc_; }
   ValuePtr GetValueTrack() const { return value_; }
-  const std::string &sparse_grad() const { return sparse_grad_; }
-  const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
   TypePtr GetTypeTrack() const { return type_; }
   BaseShapePtr GetShapeTrack() const { return shape_; }
@@ -91,8 +85,6 @@ class AbstractBase : public Base {
   TypePtr type_;
   BaseShapePtr shape_;
   std::string value_desc_;  // store initial value description for error report
-  std::string sparse_grad_;
-  bool has_indexed_slices_grad_;
 };

 class AbstractScalar : public AbstractBase {
......
@@ -126,7 +126,11 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
   }
   MS_EXCEPTION_IF_NULL(ret_base);
-  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
+  MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
+                << ", is stub: " << fg->stub();
+  if (fg->stub()) {
+    return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
+  }
   return std::make_shared<EvalResult>(ret_base, nullptr);
 }
......
@@ -25,6 +25,7 @@
 #include <vector>

 #include "pipeline/static_analysis/static_analysis.h"
+#include "utils/context/ms_context.h"

 namespace mindspore {
 namespace abstract {
@@ -59,6 +60,13 @@ class Evaluator : public Base {
   }

   virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse = context->enable_sparse();
+    if (!enable_sparse) {
+      return nullptr;
+    }
     auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
       if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) {
         return true;
......
@@ -146,10 +146,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 using mindspore::parse::PyObjectWrapper;

 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
+  if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
     auto ret_abstract = AbstractEval(args);
     if (ret_abstract != nullptr) {
       MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
@@ -167,6 +164,14 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
+  auto ret_abstract = AbstractEval(args_spec_list);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
+    return ret_abstract;
+  }
+
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
@@ -181,9 +186,6 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
   }
   AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};

-  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
-                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
   ScopePtr scope = kDefaultScope;
   if (out_conf != nullptr) {
     scope = out_conf->node()->scope();
@@ -509,15 +511,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
 }  // end anonymous namespace

 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag) {
-    auto ret_abstract = AbstractEval(args);
-    if (ret_abstract != nullptr) {
-      MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
-      return ret_abstract;
-    }
+  auto ret_abstract = AbstractEval(args);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
+    return ret_abstract;
   }
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
@@ -546,15 +543,10 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
 }

 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse_flag = context->enable_sparse_flag();
-  if (enable_sparse_flag) {
-    auto ret_abstract = AbstractEval(args);
-    if (ret_abstract != nullptr) {
-      MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
-      return ret_abstract;
-    }
+  auto ret_abstract = AbstractEval(args);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
+    return ret_abstract;
   }
   // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
   if (nargs_ != args.size()) {
@@ -914,8 +906,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     auto ret = std::make_shared<AbstractScalar>(type);
     auto ref_value = ref_abs->ref();
     MS_EXCEPTION_IF_NULL(ref_value);
-    ret->set_sparse_grad(ref_value->sparse_grad());
-    ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
   }
@@ -930,8 +920,6 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     x = SensitivityTransform(x);
     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
-    abs_scalar->set_sparse_grad(x->sparse_grad());
-    abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
   }
 };
@@ -943,15 +931,10 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
-    auto context = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context);
-    bool enable_sparse_flag = context->enable_sparse_flag();
-    if (enable_sparse_flag) {
-      auto ret_abstract = AbstractEval(args_spec_list);
-      if (ret_abstract != nullptr) {
-        MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
-        return ret_abstract;
-      }
+    auto ret_abstract = AbstractEval(args_spec_list);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
+      return ret_abstract;
     }
     // Inputs: data, item
     if (args_spec_list.size() != 2) {
......
@@ -349,7 +349,6 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
-void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
 AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list);
......
@@ -321,7 +321,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
   AbstractFunctionPtr func = real_a->GetUnique();
   SpecializeStatusCode errcode;
   ScopeGuard scope_guard(node->scope());
-  AnfNodePtr repl = BuildSpecializedNodeInner(abs, func, argvals, &errcode);
+  AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode);
   if (repl == nullptr) {
     if (errcode == kSpecializeFindUniqueArgvalDead) {
       const auto error_dead_node = std::make_shared<AbstractError>(kDeadNode, node);
@@ -340,7 +340,8 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
   return repl;
 }

-AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
+AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
+                                                           const AbstractFunctionPtr &func,
                                                            const AbstractBasePtrList &args,
                                                            SpecializeStatusCode *errcode) {
   MS_EXCEPTION_IF_NULL(abs);
@@ -384,7 +385,14 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
   AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
   MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size()
                 << ", graph: " << context->func_graph()->get_return()->DebugString();
+  if (context->func_graph()->stub()) {
+    MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString()
+                  << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString()
+                  << ", " << node->ToString();
+    return node;
+  }
   FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context);
+  v->set_flag(kFuncGraphFlagUndetermined, false);
   return BuildValueNode(v, abs);
 }
@@ -613,7 +621,8 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
     *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
     return kSpecializeSuccess;
   } else if (choices->empty()) {
-    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
+    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
+                  << func->type_name();
     return kSpecializeFindUniqueArgvalDead;
   } else {
     if (IsPolyFunc(func, argvals)) {
......
@@ -118,8 +118,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
   // Build a specialized node from given argvals;
   AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
                                   const AbstractBasePtrList &argvals);
-  AnfNodePtr BuildSpecializedNodeInner(const AbstractBasePtr &abs, const AbstractFunctionPtr &func,
-                                       const AbstractBasePtrList &args, SpecializeStatusCode *errcode);
+  AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
+                                       const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
+                                       SpecializeStatusCode *errcode);
   // Find the unique argument values which can be used to specialize a primitive or graph function.
   SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,
......
@@ -89,7 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   max_device_memory_ = kDefaultMaxDeviceMemory;
   print_file_path_ = "";
   enable_graph_kernel_ = false;
-  enable_sparse_flag_ = false;
+  enable_sparse_ = false;
 }

 std::shared_ptr<MsContext> MsContext::GetInstance() {
......
@@ -161,8 +161,8 @@ class MsContext {
   void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
   bool enable_graph_kernel() const { return enable_graph_kernel_; }

-  bool enable_sparse_flag() const { return enable_sparse_flag_; }
-  void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; }
+  bool enable_sparse() const { return enable_sparse_; }
+  void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }

  private:
   MsContext(const std::string &backend_policy, const std::string &target);
@@ -207,7 +207,7 @@ class MsContext {
   float max_device_memory_;
   std::string print_file_path_;
   bool enable_graph_kernel_;
-  bool enable_sparse_flag_;
+  bool enable_sparse_;
 };
 }  // namespace mindspore
......
@@ -51,18 +51,13 @@ class Parameter:
         requires_grad (bool): True if the parameter requires gradient. Default: True.
         layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
             broadcast and gradients communication would not be applied on parameters. Default: False.
-        sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
-        has_indexed_slices (bool): Set if the parameter's gradient is indexed_slices. Default: false.
     """

-    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
-                 sparse_grad="", has_indexed_slices_grad=False):
+    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
         self._value = ParamValue()
         self.set_parameter_data(default_input)
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
-        self.sparse_grad = sparse_grad
-        self.has_indexed_slices_grad = has_indexed_slices_grad
         self._is_init = False
         self._sliced = False
         if context.get_context("mode") == context.PYNATIVE_MODE:
@@ -177,28 +172,6 @@ class Parameter:
             raise TypeError("`requires_grad` parameter must be bool type")
         self._value.requires_grad = value

-    @property
-    def sparse_grad(self):
-        """Return whether the parameter's gradient is sparse."""
-        return self._value.sparse_grad
-
-    @sparse_grad.setter
-    def sparse_grad(self, value=""):
-        if not isinstance(value, str):
-            raise TypeError("`sparse_grad` parameter must be str type")
-        self._value.sparse_grad = value
-
-    @property
-    def has_indexed_slices_grad(self):
-        """Return whether the parameter's gradient is indexed_slices."""
-        return self._value.has_indexed_slices_grad
-
-    @has_indexed_slices_grad.setter
-    def has_indexed_slices_grad(self, value=False):
-        if not isinstance(value, bool):
-            raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
-        self._value.has_indexed_slices_grad = value
-
     @property
     def data(self):
         return self.default_input
......
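With the `sparse_grad` and `has_indexed_slices_grad` keywords removed, a `Parameter` is now declared with only the remaining arguments; sparsity is driven by the global context switch instead of per-parameter flags. A minimal sketch of the post-change constructor call (shapes and names are illustrative only):

# Post-change Parameter construction: no sparse_grad / has_indexed_slices_grad kwargs.
import numpy as np
from mindspore import Tensor, Parameter

embedding_table = Parameter(Tensor(np.zeros((8, 4), np.float32)),
                            name="embedding_table",
                            requires_grad=True,
                            layerwise_parallel=False)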
@@ -367,14 +367,6 @@ class _Context:
     def check_bprop(self, check_bprop_flag):
         self._context_handle.set_check_bprop_flag(check_bprop_flag)

-    @property
-    def enable_sparse(self):
-        return self._context_handle.get_enable_sparse_flag()
-
-    @enable_sparse.setter
-    def enable_sparse(self, enable_sparse_flag):
-        self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
-
     @property
     def max_device_memory(self):
         return self._context_handle.get_max_device_memory()
@@ -408,6 +400,13 @@ class _Context:
             full_file_name = print_file_path
         self._context_handle.set_print_file_path(full_file_name)

+    @property
+    def enable_sparse(self):
+        return self._context_handle.get_enable_sparse()
+
+    @enable_sparse.setter
+    def enable_sparse(self, enable_sparse):
+        self._context_handle.set_enable_sparse(enable_sparse)
+

 def check_input_format(x):
     import re
@@ -601,7 +600,7 @@ def set_context(**kwargs):
         print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
             a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
             suffix to the file.
-        enable_sparse (bool): Whether to enable sparse feature. Default: False.
+        enable_sparse (bool): Whether to enable sparsity feature. Default: False.

     Raises:
         ValueError: If input key is not an attribute in context.
......
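The relocated `enable_sparse` property now forwards to the renamed `get_enable_sparse`/`set_enable_sparse` bindings, so the public keyword is unchanged. A minimal usage sketch:

# Toggling the sparse feature through the public context API; the property above
# forwards to the renamed C++ bindings get_enable_sparse / set_enable_sparse.
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
print(context.get_context("enable_sparse"))  # expected: True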
@@ -162,8 +162,8 @@ class Adam(Optimizer):
         To improve parameter groups performance, the customized order of parameters can be supported.

-        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
-        `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse feature is under continuous development. The sparse
         behavior is currently performed on the CPU.

     Args:
......
@@ -72,8 +72,8 @@ class FTRL(Optimizer):
         <https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf>`_ for engineering document.

     Note:
-        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
-        `sparse_grad` of `Parameter` being set. The sparse feature is under continuous development. The sparse
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse feature is under continuous development. The sparse
         behavior is currently performed on the CPU.

     Args:
......
@@ -91,8 +91,8 @@ class LazyAdam(Optimizer):
         value of weight_decay > 0. When not separating parameter groups, the `weight_decay` in the API will be
         applied on the parameters if `weight_decay` > 0 and the 'beta' and 'gamma' are not in the name of parameters.

-        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the
-        `sparse_grad` of `Parameter` being set. The sparse behavior, to be notice, is not equivalent to the
+        The sparse strategy is applied while the SparseGatherV2 operator being used for forward network.
+        The sparse behavior, to be notice, is not equivalent to the
         original Adam algorithm, as only the current indices parames will be updated. The sparse feature is under
         continuous development. The sparse behavior is currently performed on the CPU.
......
...@@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer): ...@@ -59,8 +59,8 @@ class ProximalAdagrad(Optimizer):
<http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_. <http://papers.nips.cc//paper/3793-efficient-learning-using-forward-backward-splitting.pdf>`_.
Note: Note:
The sparse strategy is applied while the SparseGatherV2 operator being used for forward network and the The sparse strategy is applied while the SparseGatherV2 operator is used for the forward network.
`sparse_grad` of `Parameter` being set as True. The sparse feature is under continuous development. The sparse The sparse feature is under continuous development. The sparse
behavior is currently performed on the CPU. behavior is currently performed on the CPU.
Args: Args:
......
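The four optimizer notes above describe the same trigger: the sparse path is taken when SparseGatherV2 feeds the forward network, with no `sparse_grad` or `has_indexed_slices_grad` flag on the Parameter anymore. A hedged sketch of that setup follows; the class and parameter names are illustrative, and Adam stands in for any of the optimizers listed above.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class SparseEmbeddingNet(nn.Cell):
    """Looks rows up with SparseGatherV2, so the gradient of `weight` is sparse."""
    def __init__(self):
        super(SparseEmbeddingNet, self).__init__()
        self.weight = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight")
        self.gather = P.SparseGatherV2()
        self.axis = 0

    def construct(self, indices):
        return self.gather(self.weight, indices, self.axis)

net = SparseEmbeddingNet()
# When this pair is wrapped in TrainOneStepCell, the gradient for `weight`
# reaches the optimizer as an IndexedSlices, and the sparse update currently
# runs on the CPU, as the notes above state.
optimizer = Adam(net.trainable_params(), learning_rate=0.1)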
...@@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices') ...@@ -158,7 +158,6 @@ make_indexed_slices = Primitive('MakeIndexedSlices')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues') indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
is_indexed_slices = Primitive('IsIndexedSlices')
tensor_operator_registry.register('__add__', tensor_add) tensor_operator_registry.register('__add__', tensor_add)
......
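The `IsIndexedSlices` primitive is dropped above, leaving the constructor and the three accessors. A hedged sketch mirroring the tests later in this commit (it assumes `IndexedSlices` is importable from the top-level `mindspore` package and, as in those tests, is used inside a Cell under graph mode):

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, IndexedSlices, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class MakeAndUnpack(nn.Cell):
    """Builds an IndexedSlices and returns its three components."""
    def __init__(self):
        super(MakeAndUnpack, self).__init__()
        self.dense_shape = (3, 4)

    def construct(self, indices, values):
        sparse = IndexedSlices(indices, values, self.dense_shape)
        # Type-based dispatch now replaces the removed is_indexed_slices() probe.
        return sparse.values(), sparse.indices(), sparse.dense_shape()

indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
MakeAndUnpack()(indices, values)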
...@@ -36,6 +36,8 @@ from mindspore._checkparam import Rel ...@@ -36,6 +36,8 @@ from mindspore._checkparam import Rel
from mindspore.nn import Optimizer from mindspore.nn import Optimizer
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
reduce_sum = P.ReduceSum() reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum() unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose() transpose = P.Transpose()
...@@ -44,7 +46,6 @@ reshape = P.Reshape() ...@@ -44,7 +46,6 @@ reshape = P.Reshape()
size_op = P.Size() size_op = P.Size()
invert_permutation = P.InvertPermutation() invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd() logical_and = P.LogicalAnd()
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
@constexpr @constexpr
def _generate_shape_index(out_shape, indices_shape, axis): def _generate_shape_index(out_shape, indices_shape, axis):
...@@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self): ...@@ -103,10 +104,15 @@ def get_bprop_sparse_gather_v2(self):
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Undetermined", "Bool") "Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param,
if gradient.is_indexed_slices(): m, v, gradient, decay_flag):
return gradient.values() return gradient.values()
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
m, v, gradient, decay_flag):
op_mul = P.Mul() op_mul = P.Mul()
op_square = P.Square() op_square = P.Square()
op_sqrt = P.Sqrt() op_sqrt = P.Sqrt()
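The hunk above replaces a runtime `gradient.is_indexed_slices()` branch with two registrations on the same MultitypeFuncGraph, so the sparse and dense cases are selected by the declared type signature. A stripped-down sketch of that pattern (the graph name and the trivial bodies are illustrative only, not the optimizer's real update rule):

from mindspore.ops import composite as C
from mindspore.ops import functional as F

apply_grad = C.MultitypeFuncGraph("apply_grad")

@apply_grad.register("Tensor", "IndexedSlices")
def _apply_grad_indexed_slices(weight, gradient):
    # Sparse branch: only the rows named by gradient.indices() carry updates.
    return F.depend(weight, gradient.values())

@apply_grad.register("Tensor", "Tensor")
def _apply_grad_tensor(weight, gradient):
    # Dense branch: the full gradient tensor is applied.
    return F.depend(weight, gradient)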
...@@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices(): ...@@ -182,7 +188,7 @@ def test_indexed_slices_make_indexed_slices():
self.dense_shape = (3, 4) self.dense_shape = (3, 4)
def construct(self, indices, values): def construct(self, indices, values):
ret = (IndexedSlices(indices, values, self.dense_shape),) ret = (IndexedSlices(indices, values, self.dense_shape),)
return ret[0].is_indexed_slices() return ret[0]
indices = Tensor([[0, 0], [1, 2]]) indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32) values = Tensor([1, 2], dtype=ms.float32)
MakeIndexedSlices()(indices, values) MakeIndexedSlices()(indices, values)
...@@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): ...@@ -209,7 +215,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
self.network = network self.network = network
def construct(self, x, y): def construct(self, x, y):
grad = grad_all(self.network)(x, y) grad = grad_all(self.network)(x, y)
return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices() return grad, grad[0], grad[1]
class SparseGatherV2(nn.Cell): class SparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(SparseGatherV2, self).__init__() super(SparseGatherV2, self).__init__()
...@@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): ...@@ -233,14 +239,13 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
weights = self.weights weights = self.weights
grad = grad_by_list(self.network, weights)(x) grad = grad_by_list(self.network, weights)(x)
x = grad[0] x = grad[0]
return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape() return x, x.values(), x.indices(), x.dense_shape()
class SparseGatherV2(nn.Cell): class SparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(SparseGatherV2, self).__init__() super(SparseGatherV2, self).__init__()
self.sparse_gatherv2 = MySparseGatherV2() self.sparse_gatherv2 = MySparseGatherV2()
self.axis = 0 self.axis = 0
self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)), name="params")
name="params", has_indexed_slices_grad=True)
def construct(self, indices): def construct(self, indices):
return self.sparse_gatherv2(self.params, indices, self.axis) return self.sparse_gatherv2(self.params, indices, self.axis)
indices = Tensor(np.array([0, 1]).astype(np.int32)) indices = Tensor(np.array([0, 1]).astype(np.int32))
...@@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): ...@@ -248,20 +253,6 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
network(indices) network(indices)
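A hedged, self-contained version of the gradient test above: a Parameter consumed only through SparseGatherV2 comes back from autodiff as an IndexedSlices that can be unpacked directly, with no `has_indexed_slices_grad` flag on the Parameter and no `is_indexed_slices()` probe. `grad_by_list` is assumed to be the old-style C.GradOperation('get_by_list', get_by_list=True) used elsewhere in these tests.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple, context
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

class SparseGatherNet(nn.Cell):
    def __init__(self):
        super(SparseGatherNet, self).__init__()
        self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="params")
        self.gather = P.SparseGatherV2()
        self.axis = 0

    def construct(self, indices):
        return self.gather(self.params, indices, self.axis)

class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())

    def construct(self, indices):
        grads = grad_by_list(self.network, self.weights)(indices)
        g = grads[0]
        # The sparse gradient is unpacked directly into its three components.
        return g.values(), g.indices(), g.dense_shape()

indices = Tensor(np.array([0, 1]).astype(np.int32))
GradWrap(SparseGatherNet())(indices)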
def test_indexed_slices_is_indexed_slices():
class MakeIndexedSlices(nn.Cell):
def __init__(self):
super(MakeIndexedSlices, self).__init__()
self.dense_shape = (3, 4)
def construct(self, indices, values):
indexed_slices = IndexedSlices(indices, values, self.dense_shape)
ret = indexed_slices.is_indexed_slices()
return ret
indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
MakeIndexedSlices()(indices, values)
def test_indexed_slices_env_get(): def test_indexed_slices_env_get():
class Loss(nn.Cell): class Loss(nn.Cell):
def __init__(self): def __init__(self):
...@@ -271,7 +262,7 @@ def test_indexed_slices_env_get(): ...@@ -271,7 +262,7 @@ def test_indexed_slices_env_get():
class NetWithSparseGatherV2(nn.Cell): class NetWithSparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(NetWithSparseGatherV2, self).__init__() super(NetWithSparseGatherV2, self).__init__()
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", has_indexed_slices_grad=True) self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
self.gatherv2 = MySparseGatherV2() self.gatherv2 = MySparseGatherV2()
self.axis = 0 self.axis = 0
......
...@@ -17,12 +17,13 @@ import numpy as np ...@@ -17,12 +17,13 @@ import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell): class Net(nn.Cell):
""" Net definition """ """ Net definition """
...@@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell): ...@@ -53,8 +54,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """ """ NetWithSparseGatherV2 definition """
def __init__(self): def __init__(self):
super(NetWithSparseGatherV2, self).__init__() super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
name="weight1", sparse_grad="sparse_key_w1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0 self.axis = 0
self.gather = P.SparseGatherV2() self.gather = P.SparseGatherV2()
......
...@@ -27,6 +27,7 @@ from mindspore.ops import functional as F ...@@ -27,6 +27,7 @@ from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
context.set_context(enable_sparse=True)
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
...@@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse(): ...@@ -154,7 +155,7 @@ def test_AdamWeightDecaySparse():
class NetWithSparseGatherV2(nn.Cell): class NetWithSparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(NetWithSparseGatherV2, self).__init__() super(NetWithSparseGatherV2, self).__init__()
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1") self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
self.gatherv2 = P.SparseGatherV2() self.gatherv2 = P.SparseGatherV2()
self.axis = 0 self.axis = 0
......
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import FTRL from mindspore.nn.optim import FTRL
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
...@@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell): ...@@ -41,8 +42,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """ """ NetWithSparseGatherV2 definition """
def __init__(self): def __init__(self):
super(NetWithSparseGatherV2, self).__init__() super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
name="weight1", sparse_grad="sparse_key_w1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0 self.axis = 0
self.gather = P.SparseGatherV2() self.gather = P.SparseGatherV2()
......
...@@ -17,12 +17,13 @@ import numpy as np ...@@ -17,12 +17,13 @@ import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import LazyAdam from mindspore.nn.optim import LazyAdam
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell): class Net(nn.Cell):
""" Net definition """ """ Net definition """
...@@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell): ...@@ -43,8 +44,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """ """ NetWithSparseGatherV2 definition """
def __init__(self): def __init__(self):
super(NetWithSparseGatherV2, self).__init__() super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
name="weight1", sparse_grad="sparse_key_w1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype((np.float32))), name="weight2")
self.axis = 0 self.axis = 0
self.gather = P.SparseGatherV2() self.gather = P.SparseGatherV2()
......
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import ProximalAdagrad from mindspore.nn.optim import ProximalAdagrad
from mindspore.ops import operations as P from mindspore.ops import operations as P
context.set_context(enable_sparse=True)
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
...@@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell): ...@@ -40,8 +41,7 @@ class NetWithSparseGatherV2(nn.Cell):
""" NetWithSparseGatherV2 definition """ """ NetWithSparseGatherV2 definition """
def __init__(self): def __init__(self):
super(NetWithSparseGatherV2, self).__init__() super(NetWithSparseGatherV2, self).__init__()
self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1", self.weight1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="weight1")
sparse_grad="sparse_key_w1")
self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2") self.weight2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="weight2")
self.axis = 0 self.axis = 0
self.gather = P.SparseGatherV2() self.gather = P.SparseGatherV2()
......
...@@ -53,4 +53,4 @@ def test_hypermap_specialize_param(): ...@@ -53,4 +53,4 @@ def test_hypermap_specialize_param():
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32))) expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
ret = hypermap_specialize_param() ret = hypermap_specialize_param()
assert ret == (expected_ret, expected_ret) assert ret == (expected_ret, list(expected_ret))