diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 274cd43914f5f9c437b0bd9f86f0f2523ebf8a1e..36fdd5a300f851eec52678db3f90c35dbc7a161b 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -30,6 +30,7 @@ #include "pipeline/parse/python_adapter.h" #include "pipeline/parse/resolve.h" #include "operator/composite/composite.h" +#include "operator/composite/map.h" #include "utils/ordered_map.h" #include "utils/ordered_set.h" #include "utils/utils.h" @@ -190,6 +191,8 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap * ├── MultitypeGraph * ├── HyperMap * │ └── HyperMapPy + * ├── Map + * │ └── MapPy * ├── Tail * ├── MakeTupleGradient * ├── GradOperation @@ -208,17 +211,25 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_ oss << GetMultitypeFuncGraphText(mt_func_graph); } else if (meta_func_graph ->isa()) { // this statement must before 'meta_graph->isa()' - prim::HyperMapPyPtr hyper_map = meta_func_graph->cast(); - MS_EXCEPTION_IF_NULL(hyper_map); + auto hyper_map = meta_func_graph->cast(); if (hyper_map->GetFnLeaf() != nullptr) { oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; } } else if (meta_func_graph->isa()) { - prim::HyperMapPtr hyper_map = meta_func_graph->cast(); - MS_EXCEPTION_IF_NULL(hyper_map); + auto hyper_map = meta_func_graph->cast(); if (hyper_map->GetFnLeaf() != nullptr) { oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}"; } + } else if (meta_func_graph->isa()) { // this statement must before 'meta_graph->isa()' + auto map = meta_func_graph->cast(); + if (map->GetFnLeaf() != nullptr) { + oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}"; + } + } else if (meta_func_graph->isa()) { + auto map = meta_func_graph->cast(); + if (map->GetFnLeaf() != nullptr) { + oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}"; + } } else if (meta_func_graph->isa()) { prim::GradOperationPtr grad_op = meta_func_graph->cast(); oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_ diff --git a/mindspore/ccsrc/operator/composite/map.cc b/mindspore/ccsrc/operator/composite/map.cc new file mode 100644 index 0000000000000000000000000000000000000000..6752cfe0789a3fc02d39f2e54e14ae254f3e6977 --- /dev/null +++ b/mindspore/ccsrc/operator/composite/map.cc @@ -0,0 +1,289 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "operator/composite/map.h" +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "pipeline/static_analysis/abstract_value.h" +#include "pipeline/static_analysis/abstract_function.h" +#include "pipeline/static_analysis/dshape.h" +#include "pybind_api/api_register.h" +#include "debug/trace.h" +#include "operator/ops.h" +#include "./common.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; + +AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) { + MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n"; + MS_EXCEPTION_IF_NULL(func_graph); + std::vector inputs; + if (fn_arg != nullptr) { + inputs.emplace_back(fn_arg); + } else { + inputs.emplace_back(NewValueNode(fn_leaf_)); + } + inputs.insert(inputs.end(), args.begin(), args.end()); + return func_graph->NewCNode(inputs); +} + +FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { + // Generate func for leaf nodes + FuncGraphPtr ptrGraph = std::make_shared(); + ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptrGraph->debug_info()->set_name("map"); + AnfNodePtr ptrFnArg = nullptr; + if (fn_leaf_ == nullptr) { + ptrFnArg = ptrGraph->add_parameter(); + } + AnfNodePtrList args; + for (size_t i = 0; i < args_size; ++i) { + args.emplace_back(ptrGraph->add_parameter()); + } + ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args)); + return ptrGraph; +} + +AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = + std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { + auto lhs = std::dynamic_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "List in Map should have same length"; + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeList)); + + for (int i = 0; i < SizeToInt(size); ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, i](const std::pair &item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = + std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { + auto lhs = std::dynamic_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "tuple in Map 
should have same length"; + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (int i = 0; i < SizeToInt(size); ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, &i](std::pair item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(type); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); + inputs.push_back(NewValueNode(type)); + + std::size_t attrSize = type->GetAttributes().size(); + for (std::size_t i = 0; i < attrSize; ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + int j = 0; + for (auto item : arg_pairs) { + inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + j++; + } + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + bool found = false; + TypeId id = kObjectTypeEnd; + std::pair pair; + for (auto &item : arg_pairs) { + pair = item; + MS_LOG(DEBUG) << "Map " << pair.second->ToString(); + id = item.second->type_id(); + if (nonleaf_.count(id)) { + found = true; + break; + } + } + + if (found) { + // In a nonleaf situation, all arguments must have the same generic. 
+ bool is_not_same = + std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair &item) { + if (item.first != pair.first) { + return item.second->type_id() != pair.second->type_id(); + } + return false; + }); + if (is_not_same) { + std::ostringstream oss; + oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n" + << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; + int idx = 0; + for (auto &item : arg_pairs) { + oss << ++idx << ": " << item.second->ToString() << "\n"; + } + MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n" + << oss.str() << pair.second->ToString() << "\n"; + } + } + + switch (id) { + case kObjectTypeList: { + auto type = std::static_pointer_cast(pair.second); + return FullMakeList(type, func_graph, fn_arg, arg_pairs); + } + case kObjectTypeTuple: { + auto type = std::static_pointer_cast(pair.second); + return FullMakeTuple(type, func_graph, fn_arg, arg_pairs); + } + case kObjectTypeClass: { + auto type = std::static_pointer_cast(pair.second); + return FullMakeClass(type, func_graph, fn_arg, arg_pairs); + } + default: + MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class " + << ", but got " << pair.second->ToString(); + } +} + +FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { + FuncGraphPtr ptrGraph = std::make_shared(); + ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptrGraph->debug_info()->set_name("map"); + + AnfNodePtr ptrFnArg = nullptr; + std::size_t i = 0; + if (fn_leaf_ == nullptr) { + ptrFnArg = ptrGraph->add_parameter(); + i = 1; + } + ArgsPairList arg_pairs; + std::size_t size = args_spec_list.size(); + for (; i < size; ++i) { + MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString(); + arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); + } + + ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs)); + return ptrGraph; +} + +abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { + if (fn_leaf_ == nullptr) { + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + // Assert that map's function param does not contain free variables + if (args_spec_list[0]->isa()) { + auto graph_func = dyn_cast(args_spec_list[0]); + auto func_graph = graph_func->func_graph(); + if (func_graph->parent() != nullptr) { + MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet."; + } + } + } + + AbstractBasePtrList broadened; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->Broaden(); + }); + return broadened; +} + +REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { + (void)py::class_>(*m, "Map_") + .def(py::init>(), py::arg("leaf")) + .def(py::init<>()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/map.h b/mindspore/ccsrc/operator/composite/map.h new file mode 100644 index 0000000000000000000000000000000000000000..02d374214adac83dc79f6c62ff81793852868a70 --- /dev/null +++ b/mindspore/ccsrc/operator/composite/map.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ + +#include +#include +#include +#include + +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "operator/composite/multitype_funcgraph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using ArgsPairList = std::vector>; + +class Map : public MetaFuncGraph { + public: + explicit Map(const std::shared_ptr &fn_leaf = nullptr) + : MetaFuncGraph("map"), + fn_leaf_(fn_leaf), + broadcast_(false), + nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { + Init(); + } + Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + Init(); + } + Map &operator=(const Map &h) { + if (this != &h) { + fn_leaf_ = h.fn_leaf_; + broadcast_ = h.broadcast_; + nonleaf_ = h.nonleaf_; + if (fn_leaf_) { + name_ = "map[" + fn_leaf_->name() + "]"; + } + } + return *this; + } + ~Map() override = default; + MS_DECLARE_PARENT(Map, MetaFuncGraph) + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; + MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } + + private: + FuncGraphPtr GenerateLeafFunc(const size_t &args_size); + AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); + AnfNodePtr FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); + void Init() { + if (fn_leaf_ != nullptr) { + name_ = "map[" + fn_leaf_->name() + "]"; + } + signatures_ = + // def map(func:read, *args:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); + } + + MultitypeFuncGraphPtr fn_leaf_; + bool broadcast_; + std::set nonleaf_; +}; +using MapPtr = std::shared_ptr; +class MapPy : public Map { + public: + explicit MapPy(const std::shared_ptr &fn_leaf = nullptr) : Map(fn_leaf) {} + ~MapPy() override = default; + MS_DECLARE_PARENT(MapPy, Map) +}; +using MapPyPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc index b8e89378e6a387f329a33e31a834e7559159846f..9d05ecef971bc2e711ccfbe6e9912cc4ba2d8f8d 100644 --- a/mindspore/ccsrc/operator/prim_others.cc +++ b/mindspore/ccsrc/operator/prim_others.cc @@ -14,9 +14,14 @@ * limitations under the License. 
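The prim_others.cc changes that follow introduce UndeterminedShapeType, which reads a colon-separated UNDETERMINED_SPARSE_SHAPE_TYPES environment string as a stop-gap until undetermined types are ready. A hedged Python sketch of that format, with field names taken from the class below:

```python
# Sketch of the UNDETERMINED_SPARSE_SHAPE_TYPES string parsed below (illustrative only).
# Layout: param_name:indices_shape:indices_type:values_shape:values_type:dense_shape
def parse_sparse_shape_types(env_str="w1:2:Int32:2 1 2:Float32:3 1 2"):
    fields = env_str.split(":")
    if len(fields) != 6:
        raise ValueError("expect 6 ':'-separated fields, got %d" % len(fields))
    name, indices_shape, indices_type, values_shape, values_type, dense_shape = fields
    to_shape = lambda s: [int(dim) for dim in s.split()]
    return {"param_name": name,
            "indices_shape": to_shape(indices_shape), "indices_type": indices_type,
            "values_shape": to_shape(values_shape), "values_type": values_type,
            "dense_shape": to_shape(dense_shape)}
```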
*/ +#include +#include + +#include "ir/dtype.h" +#include "common/utils.h" +#include "operator/ops.h" #include "pipeline/static_analysis/param_validator.h" #include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" #include "pipeline/static_analysis/utils.h" #include "utils/symbolic.h" @@ -50,6 +55,65 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit return AbstractFunction::MakeAbstractFunction(jv); } +class UndeterminedShapeType { + public: + explicit UndeterminedShapeType(const std::string &env_str) { + // param_name indices_shape indices_type values_shape values_type dense_shape + // export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2" + std::vector fields; + string tmp; + std::stringstream input(env_str); + while (std::getline(input, tmp, ':')) { + fields.push_back(tmp); + } + if (fields.size() != fields_num) { + MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size(); + } + + param_name_ = fields[0]; + + indices_shape_ = GetShape(fields[1]); + indices_type_ = StringToType(fields[2]); + + values_shape_ = GetShape(fields[3]); + values_type_ = StringToType(fields[4]); + + auto dense_shape_vec = GetShape(fields[5]); + AbstractBasePtrList dense_shape_list; + (void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list), + [](const auto &elem) { return FromValue(elem, false); }); + dense_shape_ = dense_shape_list; + } + const std::string ¶m_name() { return param_name_; } + const std::vector &indices_shape() { return indices_shape_; } + const TypePtr &indices_type() { return indices_type_; } + const std::vector &values_shape() { return values_shape_; } + const TypePtr &values_type() { return values_type_; } + const AbstractBasePtrList &dense_shape() { return dense_shape_; } + + private: + std::string param_name_; + std::vector indices_shape_; + TypePtr indices_type_; + std::vector values_shape_; + TypePtr values_type_; + AbstractBasePtrList dense_shape_; + static const size_t fields_num; + + std::vector GetShape(const std::string &shape_str); +}; +std::vector UndeterminedShapeType::GetShape(const std::string &shape_str) { + std::vector ret; + std::istringstream iss(shape_str); + int elem; + while (iss.good()) { + iss >> elem; + ret.emplace_back(elem); + } + return ret; +} +const size_t UndeterminedShapeType::fields_num = 6; + AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { MS_EXCEPTION_IF_NULL(primitive); @@ -62,6 +126,31 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt if (type->type_id() != kObjectTypeSymbolicKeyType) { MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); } + + if (key->sparse_grad()) { + // Will be fixed once undetermined type ready + auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); + if (sparse_shape_types.empty()) { + sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2"; + } + MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is " + << sparse_shape_types; + + auto shape_types = UndeterminedShapeType(sparse_shape_types); + AbstractBasePtrList sparse_list; + // indices + auto indices_ele = std::make_shared(kAnyValue, shape_types.indices_type()); + auto indices = std::make_shared(indices_ele, std::make_shared(shape_types.indices_shape())); + sparse_list.emplace_back(indices); + // values + 
auto dout_ele = std::make_shared(kAnyValue, shape_types.values_type()); + auto dout = std::make_shared(dout_ele, std::make_shared(shape_types.values_shape())); + sparse_list.emplace_back(dout); + // dense_shape + sparse_list.emplace_back(std::make_shared(shape_types.dense_shape())); + return std::make_shared(sparse_list); + } + if (!key->GetValueTrack()->isa()) { return dflt; } @@ -80,8 +169,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt CheckArgsSize(primitive->name(), args_spec_list, 3); auto key = args_spec_list[1]; - auto value = args_spec_list[2]; - ValuePtr key_value_ptr = key->GetValueTrack(); MS_EXCEPTION_IF_NULL(key_value_ptr); auto key_value_track = key_value_ptr->cast(); @@ -91,7 +178,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt } auto expected = key_value_track->abstract(); MS_EXCEPTION_IF_NULL(expected); - (void)expected->Join(value); return std::make_shared(kAnyValue, std::make_shared()); } @@ -126,7 +212,9 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr & if (type->type_id() != kObjectTypeRefKey) { MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); } - return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); + auto ret = std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); + ret->set_sparse_grad(args_spec_list[2]->sparse_grad()); + return ret; } AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index b2a7a75703d7c4163f3f8d3929b341ad61dfbff2..f127305d1baeec30d0b88a1d311704a74b0b77d0 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -38,6 +38,7 @@ #include "pipeline/remove_value_node_dup.h" #include "optimizer/optimizer.h" #include "vm/transform.h" +#include "parse/python_adapter.h" namespace mindspore { namespace pipeline { @@ -228,6 +229,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { if (param_node->has_default()) { auto param_value = std::dynamic_pointer_cast(param_node->default_param()); AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true); + auto sparse_grad = py::cast(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad")); + ptr->set_sparse_grad(sparse_grad); parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); args_spec.push_back(ptr); diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc index d4f0c6f8d4dbd80937d6483bfeb0e0965f39fdf9..f23c6e31c4b7f88ed6b69af3c84df9d0ad7257dd 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc @@ -51,6 +51,7 @@ ValuePtr AbstractBase::BuildValue() const { AbstractBasePtr AbstractBase::Broaden() const { AbstractBasePtr clone = Clone(); clone->set_value(kAnyValue); + clone->set_sparse_grad(sparse_grad_); return clone; } @@ -63,7 +64,8 @@ std::string AbstractBase::ToString() const { MS_EXCEPTION_IF_NULL(type_); MS_EXCEPTION_IF_NULL(shape_); buffer << type_name() << "(" - << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")"; + << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() + << " sparse_grad: " << sparse_grad_ << 
")"; return buffer.str(); } @@ -72,16 +74,22 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden() AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); if (*this == *other) { - return shared_from_base(); + auto ret = shared_from_base(); + ret->set_sparse_grad(sparse_grad()); + return ret; } auto value_self = GetValueTrack(); MS_EXCEPTION_IF_NULL(value_self); ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack()); if (res_value == value_self) { - return shared_from_base(); + auto ret = shared_from_base(); + ret->set_sparse_grad(sparse_grad()); + return ret; } - return std::make_shared(res_value, res_type); + auto ret = std::make_shared(res_value, res_type); + ret->set_sparse_grad(sparse_grad()); + return ret; } AbstractBasePtr AbstractType::Clone() const { @@ -423,7 +431,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) { } auto element = element_->Join(other_tensor->element_); auto shape = ShapeJoin(this->shape(), other_tensor->shape()); - return std::make_shared(element, shape); + auto ret = std::make_shared(element, shape); + ret->set_sparse_grad(sparse_grad()); + return ret; } bool AbstractTensor::operator==(const AbstractTensor &other) const { @@ -463,6 +473,7 @@ AbstractBasePtr AbstractTensor::Clone() const { ShapePtr shp = shape(); clone->set_shape(shp->Clone()); clone->set_value(GetValueTrack()); + clone->set_sparse_grad(sparse_grad()); return clone; } @@ -472,6 +483,7 @@ AbstractBasePtr AbstractTensor::Broaden() const { auto shp = shape(); broaden->set_shape(shp->Clone()); broaden->set_value(kAnyValue); + broaden->set_sparse_grad(sparse_grad()); return broaden; } @@ -482,6 +494,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const { shp->Broaden(); broaden->set_shape(shp); broaden->set_value(kAnyValue); + broaden->set_sparse_grad(sparse_grad()); return broaden; } @@ -502,7 +515,8 @@ std::string AbstractTensor::ToString() const { MS_EXCEPTION_IF_NULL(value_track); buffer << type_name() << "(" << "shape: " << shape_track->ToString() << ", element: " << element_->ToString() - << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"; + << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad() + << ")"; return buffer.str(); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h index 939976bb95be15e184a84d6410b93ce4a45df85e..dcd6f8f951c75af8d4a5f921c476dc8ba246d5d0 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h @@ -44,7 +44,7 @@ class AbstractBase : public Base { public: explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, const BaseShapePtr &shape = kNoShape) - : value_(value), type_(type), shape_(shape) {} + : value_(value), type_(type), shape_(shape), sparse_grad_(false) {} ~AbstractBase() override = default; MS_DECLARE_PARENT(AbstractBase, Base) @@ -53,11 +53,13 @@ class AbstractBase : public Base { virtual bool operator==(const AbstractBase &other) const; void set_value(const ValuePtr &value) { value_ = value; } + void set_sparse_grad(const bool &sparse_grad) { sparse_grad_ = sparse_grad; } void set_type(const TypePtr &type) { type_ = type; } void set_shape(const BaseShapePtr &shape) { shape_ = shape; } void 
set_value_desc(const std::string &desc) { value_desc_ = desc; } const std::string &value_desc() const { return value_desc_; } ValuePtr GetValueTrack() const { return value_; } + bool sparse_grad() const { return sparse_grad_; } TypePtr GetTypeTrack() const { return type_; } BaseShapePtr GetShapeTrack() const { return shape_; } @@ -85,6 +87,7 @@ class AbstractBase : public Base { TypePtr type_; BaseShapePtr shape_; std::string value_desc_; // store initial value description for error report + bool sparse_grad_; }; class AbstractScalar : public AbstractBase { diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 4cd52f7f478d96b5ffb97911ef3fbb6f6efc41fb..f6c78f0cd21dbe6cd14cdffb86a11a4d52baaf82 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -851,7 +851,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { } auto refkey = key_value->cast(); if (refkey == nullptr) { - return std::make_shared(std::make_shared(type), std::make_shared()); + auto ret = std::make_shared(type); + auto ref_value = ref_abs->ref(); + MS_EXCEPTION_IF_NULL(ref_value); + ret->set_sparse_grad(ref_value->sparse_grad()); + return std::make_shared(ret, std::make_shared()); } std::string name = refkey->tag(); @@ -865,6 +869,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator { x = SensitivityTransform(x); std::shared_ptr key = std::make_shared(node, x); std::shared_ptr abs_scalar = std::make_shared(key, type); + abs_scalar->set_sparse_grad(x->sparse_grad()); return std::make_shared(abs_scalar, std::make_shared()); } }; diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 788c2d03073a17a89aa2421fcc92a6a7f5ea9c81..e760e23536dbc09c68b7b3c5c7d9d2385139ecf8 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -50,12 +50,14 @@ class Parameter: requires_grad (bool): True if the parameter requires gradient. Default: True. layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, broadcast and gradients communication would not be applied on parameters. Default: False. + sparse_grad (bool): True if the parameter's gradient is sparse. Default: False. 
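The new sparse_grad flag is read off the parameter's default value in pipeline/action.cc above and propagated onto its abstract value. A minimal usage sketch on the Python side, mirroring the unit test at the end of this patch:

```python
# Minimal sketch: marking a parameter as having a sparse gradient.
import numpy as np
from mindspore import Parameter, Tensor

# The flag is forwarded to the compiler front end, which then lets EnvGetItem
# produce an (indices, values, dense_shape) triple for this parameter's gradient.
w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
```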
""" - def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): + def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=False): self.set_parameter_data(default_input) self.name = name self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel + self.sparse_grad = sparse_grad self._is_init = False self._sliced = False self.clone_info = _CloneInfo() @@ -168,6 +170,17 @@ class Parameter: raise TypeError("`requires_grad` parameter must be bool type") self._requires_grad = value + @property + def sparse_grad(self): + """Return whether the parameter's gradient is sparse.""" + return self._sparse_grad + + @sparse_grad.setter + def sparse_grad(self, value=True): + if not isinstance(value, bool): + raise TypeError("`sparse_grad` parameter must be bool type") + self._sparse_grad = value + @property def data(self): return self.default_input diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index 9ec9b0f08049343adc08ebcd5de5241fec02ea8e..43a35f99e0970be7a81261e39a48752ef88afeae 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -30,6 +30,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum() transpose = P.Transpose() shape_op = P.Shape() reshape = P.Reshape() +size_op = P.Size() invert_permutation = P.InvertPermutation() logical_and = P.LogicalAnd() @@ -284,6 +285,37 @@ def get_bprop_gather_v2(self): return bprop +@bprop_getters.register(P.SparseGatherV2) +def get_bprop_sparse_gather_v2(self): + """Generate bprop for SparseGatherV2""" + + def bprop(x, indices, axis, out, dout): + x_shp = shape_op(x) + if axis == 0: + indices_size = (size_op(indices),) + x_tail_shp = x_shp[1:] + values_shape = indices_size + x_tail_shp + values = reshape(dout, values_shape) + indices = reshape(indices, indices_size) + return (indices, values, x_shp), zeros_like(indices), zeros_like(axis) + if F.rank(dout) == 0: + dout = P.ExpandDims()(dout, -1) + if F.rank(indices) == 0: + indices = P.ExpandDims()(indices, -1) + out_shp = shape_op(dout) + ind_shp = shape_op(indices) + # Example: out_shape:(3,2,3) axis 1 -> (1,0,2) + perm_1 = _generate_shape_index(out_shp, ind_shp, axis) + values_transpose = transpose(dout, perm_1) + params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis]) + # Example: out_shape:(3,2,3) axis 2 -> (1,2,0) + perm_2 = _generate_inverse_index(x_shp, axis) + params_grad = transpose(params_grad, perm_2) + return params_grad, zeros_like(indices), zeros_like(axis) + + return bprop + + @bprop_getters.register(P.Range) def get_bprop_range(self): """Generate bprop for Range""" diff --git a/mindspore/ops/composite/__init__.py b/mindspore/ops/composite/__init__.py index e4c6e35d3a34a0f736eaac849a63f4f66ef81b5a..a531503d940384dd595004c5905e5f0183aba53d 100644 --- a/mindspore/ops/composite/__init__.py +++ b/mindspore/ops/composite/__init__.py @@ -20,7 +20,7 @@ Pre-defined combination of operators. 
""" -from .base import GradOperation, HyperMap, MultitypeFuncGraph, add_flags, \ +from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \ grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \ core, env_get, tail, zip_operation from .clip_ops import clip_by_value diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 007d147bf08e3d2a92ac102eff1c36da2c0d3ce5..79e4fe76e39cc0b78450fa8f1be7c9309e31f924 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -19,7 +19,7 @@ from functools import partial from mindspore import context -from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ +from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \ TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_ from ...common import dtype as mstype from ...common.api import ms_function, _pynative_exec @@ -240,6 +240,69 @@ class HyperMap(HyperMap_): return func(*args_list) return tuple(map(hypermap, *args_list)) +class Map(Map_): + """ + Map will apply the set operation on input sequences. + + Which will apply the operations of every elements of the sequence. + + Args: + ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`, + the operations should be putted in the first input of the instance. + + Inputs: + - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences, + and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence + `(args[0][i], args[1][i])` will be the input of the operation. + + If `ops` is not `None`, the first input is the operation, and the other is inputs. + + Outputs: + sequence, the output will be same type and same length of sequence from input and the value of each element + is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`. + """ + + def __init__(self, ops=None): + self.ops = ops + if ops: + Map_.__init__(self, ops) + else: + Map_.__init__(self) + + def __call__(self, *args): + func = args[0] + count = 0 + count_max = 1 + args_list = args[1:] + if self.ops is not None: + func = self.ops + args_list = args + for item in args_list: + if isinstance(item, (tuple, list)): + count_max = len(item) + break + + def get_item(x): + nonlocal count + if isinstance(x, (tuple, list)): + return x[count] + return x + + for i in range(count_max): + true_args = tuple(map(get_item, args_list)) + func(*true_args) + count = i + 1 + return True + + def register(self, *type_names): + """Register a function for the given type string.""" + + def deco(fn): + self.register_fn(type_names, fn) + return fn + return deco + + class _ListAppend(ListAppend_): """ A metafuncgraph class that append one element to list. diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index f71a8614737b2fe5ad5b8e12f3668178f8d6c600..4a30d3b7c6a82e0e4389f4614aa42d49b769d245 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -21,7 +21,7 @@ A collection of operators to build nerual networks or computing functions. 
 from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Diag, DiagPart, DType, ExpandDims, Eye,
-                        Fill, GatherNd, GatherV2, InvertPermutation,
+                        Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
                         IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
                         Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, SameTypeShape,
                         ScatterAdd, ScatterMax, ScatterUpdate,
@@ -122,6 +122,7 @@ __all__ = [
     'Transpose',
     'OneHot',
     'GatherV2',
+    'SparseGatherV2',
     'Concat',
     'Pack',
     'Unpack',
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 68e0c1b7aeb8e6a908aff0a716a2cda53fc982e6..a21cb9d955e0b71ef5cec6429474dae31ec6eecd 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -526,6 +526,29 @@ class GatherV2(PrimitiveWithInfer):
         return out
 
 
+class SparseGatherV2(GatherV2):
+    """
+    Returns a slice of the input tensor based on the specified indices and axis.
+
+    Inputs:
+        - **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+          The original Tensor.
+        - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
+          Specifies the indices of elements of the original Tensor. Must be in the range
+          `[0, input_param.shape()[axis])`.
+        - **axis** (int) - Specifies the dimension index to gather indices.
+
+    Outputs:
+        Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
+
+    Examples:
+        >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
+        >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
+        >>> axis = 1
+        >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
+    """
+
+
 class Range(PrimitiveWithInfer):
     r"""
     Creates a sequence of numbers.
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index 6205f6b275addcd0f3436c332ee9f9529aa28862..44c7a5635a9e5a0c0bb4726a628bb04edf3263a8 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -332,6 +332,8 @@ class CheckBprop(PrimitiveWithInfer):
 
     def infer_shape(self, xshapes, yshapes):
         tips = f'Bprop of {self.prim_to_check}'
+        validator.check_value_type('grads', xshapes, (tuple,), tips)
+        validator.check_value_type('params', yshapes, (tuple,), tips)
         if len(xshapes) < len(yshapes):
             raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
                             f" but got {len(xshapes)}.")
@@ -348,6 +350,8 @@
 
     def infer_dtype(self, xdtypes, ydtypes):
         tips = f'Bprop of {self.prim_to_check}'
+        validator.check_value_type('grads', xdtypes, (tuple,), tips)
+        validator.check_value_type('params', ydtypes, (tuple,), tips)
         if len(xdtypes) < len(ydtypes):
             raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
                             f" but got {len(xdtypes)}.")
diff --git a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ea99b1ae3fc6264167a6a3325b44160468bbd9
--- /dev/null
+++ b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
@@ -0,0 +1,173 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
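A small forward-usage sketch for the new SparseGatherV2 primitive, with values taken from the docstring above. Its forward semantics match GatherV2; the difference is the registered bprop, which for axis 0 emits an (indices, values, dense_shape) triple instead of a dense gradient:

```python
# Hedged sketch: SparseGatherV2 forward call.
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
input_indices = Tensor(np.array([1, 2]), mindspore.int32)
out = P.SparseGatherV2()(input_params, input_indices, 0)  # gathers rows 1 and 2
```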
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test adam """ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor, Parameter, context +from mindspore.common.api import _executor +from mindspore.common import dtype as mstype +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Optimizer +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Rel + + +adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") +@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tensor", "Bool") +def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): + op_mul = P.Mul() + op_square = P.Square() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + + param_fp32 = op_cast(param, mstype.float32) + m_fp32 = op_cast(m, mstype.float32) + v_fp32 = op_cast(v, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + + next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32) + + next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) + - beta2, op_square(gradient_fp32)) + + update = next_m / (op_sqrt(next_v) + eps) + if decay_flag: + update = update + op_mul(weight_decay_tensor, param_fp32) + + update_with_lr = op_mul(lr, update) + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + + next_v = F.depend(next_v, F.assign(param, next_param)) + next_v = F.depend(next_v, F.assign(m, next_m)) + next_v = F.depend(next_v, F.assign(v, next_v)) + return next_v + + +@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", + "Tensor", "Tensor", "Tuple", "Bool") +def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag): + return gradient[2][2] + +def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): + """Check the type of inputs.""" + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) + + +class AdamWeightDecaySparse(Optimizer): + """ + Implements Adam algorithm weight decay fix. + + Args: + params (list[Parameter]): A list of parameter, which will be updated. The element in `params` + should be class mindspore.Parameter. 
+ learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is + Iterable or a Tensor and the dims of the Tensor is 1, + use dynamic learning rate, then the i-th step will + take the i-th value as the learning rate. + When the learning_rate is float or learning_rate is a Tensor + but the dims of the Tensor is 0, use fixed learning rate. + Other cases are not supported. Default: 1e-3. + beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9. + Should be in range (0.0, 1.0). + beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999. + Should be in range (0.0, 1.0). + eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6. + Should be greater than 0. + weight_decay (float): Weight decay (L2 penalty). Default: 0.0. + decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: + lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. + + Inputs: + - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`, + and might be in sparse format. + + Outputs: + tuple[Parameter], the updated velocity value, the shape is the same as `params`. + + Examples: + >>> net = Net() + >>> loss = nn.SoftmaxCrossEntropyWithLogits() + >>> optim = nn.AdamWeightDecay(params=net.trainable_params()) + >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) + """ + def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, + decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): + super(AdamWeightDecaySparse, self).__init__(learning_rate, params) + if self.is_group: + raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) + self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) + self.eps = Tensor(np.array([eps]).astype(np.float32)) + self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32)) + + self.params = self.parameters + self.moments1 = self.params.clone(prefix="adam_m", init='zeros') + self.moments2 = self.params.clone(prefix="adam_v", init='zeros') + self.decay_flag = tuple(decay_filter(x) for x in self.params) + + self.map = C.Map() + + def construct(self, gradients): + lr = self.get_lr() + updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr, + self.weight_decay_tensor), + self.params, self.moments1, self.moments2, gradients, self.decay_flag) + + return updated_velocity + + +def test_AdamWeightDecaySparse(): + """ test_AdamWeightDecaySparse """ + context.set_context(mode=context.GRAPH_MODE) + class Loss(nn.Cell): + def __init__(self): + super(Loss, self).__init__() + def construct(self, base, target): + return base + class NetWithSparseGatherV2(nn.Cell): + def __init__(self): + super(NetWithSparseGatherV2, self).__init__() + self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True) + self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2") + self.gatherv2 = P.SparseGatherV2() + self.axis = 0 + def construct(self, indices): + return self.gatherv2(self.w1, indices, self.axis) * self.w2 + + inputs = Tensor(np.array([0, 1]).astype(np.int32)) + label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) + net = NetWithSparseGatherV2() + net.set_train() + loss 
= Loss() + optimizer = AdamWeightDecaySparse(net.trainable_params()) + + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepCell(net_with_loss, optimizer) + _executor.compile(train_network, inputs, label)
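For readers of the dense branch of adam_opt_for_map in the test above, this is the same update written out in plain NumPy (a hedged restatement, not code from the patch):

```python
# Hedged NumPy restatement of the dense Adam step implemented by adam_opt_for_map above.
import numpy as np

def adam_step(beta1, beta2, eps, lr, weight_decay, param, m, v, grad, decay_flag):
    next_m = beta1 * m + (1.0 - beta1) * grad
    next_v = beta2 * v + (1.0 - beta2) * np.square(grad)
    update = next_m / (np.sqrt(next_v) + eps)
    if decay_flag:
        update = update + weight_decay * param
    next_param = param - lr * update
    return next_param, next_m, next_v
```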