提交 acaa66a7 编写于 作者: P panyifeng

sparse grad for gatherv2

上级 54991615
......@@ -30,6 +30,7 @@
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/resolve.h"
#include "operator/composite/composite.h"
#include "operator/composite/map.h"
#include "utils/ordered_map.h"
#include "utils/ordered_set.h"
#include "utils/utils.h"
......@@ -190,6 +191,8 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
* ├── MultitypeGraph
* ├── HyperMap
* │ └── HyperMapPy
* ├── Map
* │ └── MapPy
* ├── Tail
* ├── MakeTupleGradient
* ├── GradOperation
......@@ -208,17 +211,25 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
oss << GetMultitypeFuncGraphText(mt_func_graph);
} else if (meta_func_graph
->isa<prim::HyperMapPy>()) { // this statement must before 'meta_graph->isa<prim::HyperMap>()'
prim::HyperMapPyPtr hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
MS_EXCEPTION_IF_NULL(hyper_map);
auto hyper_map = meta_func_graph->cast<prim::HyperMapPyPtr>();
if (hyper_map->GetFnLeaf() != nullptr) {
oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
}
} else if (meta_func_graph->isa<prim::HyperMap>()) {
prim::HyperMapPtr hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
MS_EXCEPTION_IF_NULL(hyper_map);
auto hyper_map = meta_func_graph->cast<prim::HyperMapPtr>();
if (hyper_map->GetFnLeaf() != nullptr) {
oss << "{fn_leaf=" << GetMetaFuncGraphText(hyper_map->GetFnLeaf()) << "}";
}
} else if (meta_func_graph->isa<prim::MapPy>()) { // this statement must before 'meta_graph->isa<prim::Map>()'
auto map = meta_func_graph->cast<prim::MapPyPtr>();
if (map->GetFnLeaf() != nullptr) {
oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
}
} else if (meta_func_graph->isa<prim::Map>()) {
auto map = meta_func_graph->cast<prim::MapPtr>();
if (map->GetFnLeaf() != nullptr) {
oss << "{fn_leaf=" << GetMetaFuncGraphText(map->GetFnLeaf()) << "}";
}
} else if (meta_func_graph->isa<prim::GradOperation>()) {
prim::GradOperationPtr grad_op = meta_func_graph->cast<prim::GradOperationPtr>();
oss << "{get_all=" << grad_op->get_all_ << ", get_by_list=" << grad_op->get_by_list_
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "operator/composite/map.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/static_analysis/dshape.h"
#include "pybind_api/api_register.h"
#include "debug/trace.h"
#include "operator/ops.h"
#include "./common.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n";
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> inputs;
if (fn_arg != nullptr) {
inputs.emplace_back(fn_arg);
} else {
inputs.emplace_back(NewValueNode(fn_leaf_));
}
inputs.insert(inputs.end(), args.begin(), args.end());
return func_graph->NewCNode(inputs);
}
FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
// Generate func for leaf nodes
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("map");
AnfNodePtr ptrFnArg = nullptr;
if (fn_leaf_ == nullptr) {
ptrFnArg = ptrGraph->add_parameter();
}
AnfNodePtrList args;
for (size_t i = 0; i < args_size; ++i) {
args.emplace_back(ptrGraph->add_parameter());
}
ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
return ptrGraph;
}
AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
bool is_not_same =
std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::dynamic_pointer_cast<List>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
});
if (is_not_same) {
MS_LOG(EXCEPTION) << "List in Map should have same length";
}
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeList));
for (int i = 0; i < SizeToInt(size); ++i) {
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn);
if (fn_arg != nullptr) {
inputs2.push_back(fn_arg);
}
(void)std::transform(
arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
[&func_graph, i](const std::pair<AnfNodePtr, Any> &item) {
return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
});
inputs.push_back(func_graph->NewCNode(inputs2));
}
return func_graph->NewCNode(inputs);
}
AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(type);
std::size_t size = type->elements().size();
bool is_not_same =
std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
MS_EXCEPTION_IF_NULL(lhs);
return lhs->elements().size() != size;
});
if (is_not_same) {
MS_LOG(EXCEPTION) << "tuple in Map should have same length";
}
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
for (int i = 0; i < SizeToInt(size); ++i) {
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs";
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn);
if (fn_arg != nullptr) {
inputs2.push_back(fn_arg);
}
(void)std::transform(
arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
[&func_graph, &i](std::pair<AnfNodePtr, Any> item) {
return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)});
});
inputs.push_back(func_graph->NewCNode(inputs2));
}
return func_graph->NewCNode(inputs);
}
AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph,
const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
MS_EXCEPTION_IF_NULL(type);
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeRecord));
inputs.push_back(NewValueNode(type));
std::size_t attrSize = type->GetAttributes().size();
for (std::size_t i = 0; i < attrSize; ++i) {
MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs";
auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
auto fn = NewValueNode(ptrGraph);
std::vector<AnfNodePtr> inputs2;
inputs2.push_back(fn);
if (fn_arg != nullptr) {
inputs2.push_back(fn_arg);
}
int j = 0;
for (auto item : arg_pairs) {
inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)}));
j++;
}
inputs.push_back(func_graph->NewCNode(inputs2));
}
return func_graph->NewCNode(inputs);
}
AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
bool found = false;
TypeId id = kObjectTypeEnd;
std::pair<AnfNodePtr, TypePtr> pair;
for (auto &item : arg_pairs) {
pair = item;
MS_LOG(DEBUG) << "Map " << pair.second->ToString();
id = item.second->type_id();
if (nonleaf_.count(id)) {
found = true;
break;
}
}
if (found) {
// In a nonleaf situation, all arguments must have the same generic.
bool is_not_same =
std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
if (item.first != pair.first) {
return item.second->type_id() != pair.second->type_id();
}
return false;
});
if (is_not_same) {
std::ostringstream oss;
oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n";
int idx = 0;
for (auto &item : arg_pairs) {
oss << ++idx << ": " << item.second->ToString() << "\n";
}
MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
<< oss.str() << pair.second->ToString() << "\n";
}
}
switch (id) {
case kObjectTypeList: {
auto type = std::static_pointer_cast<List>(pair.second);
return FullMakeList(type, func_graph, fn_arg, arg_pairs);
}
case kObjectTypeTuple: {
auto type = std::static_pointer_cast<Tuple>(pair.second);
return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
}
case kObjectTypeClass: {
auto type = std::static_pointer_cast<Class>(pair.second);
return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
}
default:
MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class "
<< ", but got " << pair.second->ToString();
}
}
FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("map");
AnfNodePtr ptrFnArg = nullptr;
std::size_t i = 0;
if (fn_leaf_ == nullptr) {
ptrFnArg = ptrGraph->add_parameter();
i = 1;
}
ArgsPairList arg_pairs;
std::size_t size = args_spec_list.size();
for (; i < size; ++i) {
MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
}
ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
return ptrGraph;
}
abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
if (fn_leaf_ == nullptr) {
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
// Assert that map's function param does not contain free variables
if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
auto func_graph = graph_func->func_graph();
if (func_graph->parent() != nullptr) {
MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet.";
}
}
}
AbstractBasePtrList broadened;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
[](const AbstractBasePtr &arg) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->Broaden();
});
return broadened;
}
REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
.def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
.def(py::init<>());
}));
} // namespace prim
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_
#include <memory>
#include <set>
#include <utility>
#include <vector>
#include "ir/dtype.h"
#include "ir/meta_func_graph.h"
#include "operator/composite/multitype_funcgraph.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
class Map : public MetaFuncGraph {
public:
explicit Map(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
: MetaFuncGraph("map"),
fn_leaf_(fn_leaf),
broadcast_(false),
nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
Init();
}
Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) {
Init();
}
Map &operator=(const Map &h) {
if (this != &h) {
fn_leaf_ = h.fn_leaf_;
broadcast_ = h.broadcast_;
nonleaf_ = h.nonleaf_;
if (fn_leaf_) {
name_ = "map[" + fn_leaf_->name() + "]";
}
}
return *this;
}
~Map() override = default;
MS_DECLARE_PARENT(Map, MetaFuncGraph)
abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
private:
FuncGraphPtr GenerateLeafFunc(const size_t &args_size);
AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args);
AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_pairs);
AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_pairs);
AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_pairs);
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
void Init() {
if (fn_leaf_ != nullptr) {
name_ = "map[" + fn_leaf_->name() + "]";
}
signatures_ =
// def map(func:read, *args:ref):
std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
{"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
}
MultitypeFuncGraphPtr fn_leaf_;
bool broadcast_;
std::set<TypeId> nonleaf_;
};
using MapPtr = std::shared_ptr<Map>;
class MapPy : public Map {
public:
explicit MapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : Map(fn_leaf) {}
~MapPy() override = default;
MS_DECLARE_PARENT(MapPy, Map)
};
using MapPyPtr = std::shared_ptr<MapPy>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_
......@@ -14,9 +14,14 @@
* limitations under the License.
*/
#include <string>
#include <sstream>
#include "ir/dtype.h"
#include "common/utils.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/param_validator.h"
#include "pipeline/static_analysis/prim.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/utils.h"
#include "utils/symbolic.h"
......@@ -50,6 +55,65 @@ AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primit
return AbstractFunction::MakeAbstractFunction(jv);
}
class UndeterminedShapeType {
public:
explicit UndeterminedShapeType(const std::string &env_str) {
// param_name indices_shape indices_type values_shape values_type dense_shape
// export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2"
std::vector<string> fields;
string tmp;
std::stringstream input(env_str);
while (std::getline(input, tmp, ':')) {
fields.push_back(tmp);
}
if (fields.size() != fields_num) {
MS_LOG(EXCEPTION) << "Expect " << fields_num << " fields, but got " << fields.size();
}
param_name_ = fields[0];
indices_shape_ = GetShape(fields[1]);
indices_type_ = StringToType(fields[2]);
values_shape_ = GetShape(fields[3]);
values_type_ = StringToType(fields[4]);
auto dense_shape_vec = GetShape(fields[5]);
AbstractBasePtrList dense_shape_list;
(void)std::transform(dense_shape_vec.begin(), dense_shape_vec.end(), std::back_inserter(dense_shape_list),
[](const auto &elem) { return FromValue(elem, false); });
dense_shape_ = dense_shape_list;
}
const std::string &param_name() { return param_name_; }
const std::vector<int> &indices_shape() { return indices_shape_; }
const TypePtr &indices_type() { return indices_type_; }
const std::vector<int> &values_shape() { return values_shape_; }
const TypePtr &values_type() { return values_type_; }
const AbstractBasePtrList &dense_shape() { return dense_shape_; }
private:
std::string param_name_;
std::vector<int> indices_shape_;
TypePtr indices_type_;
std::vector<int> values_shape_;
TypePtr values_type_;
AbstractBasePtrList dense_shape_;
static const size_t fields_num;
std::vector<int> GetShape(const std::string &shape_str);
};
std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
std::vector<int> ret;
std::istringstream iss(shape_str);
int elem;
while (iss.good()) {
iss >> elem;
ret.emplace_back(elem);
}
return ret;
}
const size_t UndeterminedShapeType::fields_num = 6;
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
......@@ -62,6 +126,31 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
if (type->type_id() != kObjectTypeSymbolicKeyType) {
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
}
if (key->sparse_grad()) {
// Will be fixed once undetermined type ready
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
if (sparse_shape_types.empty()) {
sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2";
}
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is "
<< sparse_shape_types;
auto shape_types = UndeterminedShapeType(sparse_shape_types);
AbstractBasePtrList sparse_list;
// indices
auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.indices_type());
auto indices = std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types.indices_shape()));
sparse_list.emplace_back(indices);
// values
auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.values_type());
auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types.values_shape()));
sparse_list.emplace_back(dout);
// dense_shape
sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types.dense_shape()));
return std::make_shared<AbstractTuple>(sparse_list);
}
if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
return dflt;
}
......@@ -80,8 +169,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
CheckArgsSize(primitive->name(), args_spec_list, 3);
auto key = args_spec_list[1];
auto value = args_spec_list[2];
ValuePtr key_value_ptr = key->GetValueTrack();
MS_EXCEPTION_IF_NULL(key_value_ptr);
auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
......@@ -91,7 +178,6 @@ AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePt
}
auto expected = key_value_track->abstract();
MS_EXCEPTION_IF_NULL(expected);
(void)expected->Join(value);
return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
}
......@@ -126,7 +212,9 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
if (type->type_id() != kObjectTypeRefKey) {
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
}
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
return ret;
}
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
......
......@@ -38,6 +38,7 @@
#include "pipeline/remove_value_node_dup.h"
#include "optimizer/optimizer.h"
#include "vm/transform.h"
#include "parse/python_adapter.h"
namespace mindspore {
namespace pipeline {
......@@ -228,6 +229,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
if (param_node->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
auto sparse_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
ptr->set_sparse_grad(sparse_grad);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);
......
......@@ -51,6 +51,7 @@ ValuePtr AbstractBase::BuildValue() const {
AbstractBasePtr AbstractBase::Broaden() const {
AbstractBasePtr clone = Clone();
clone->set_value(kAnyValue);
clone->set_sparse_grad(sparse_grad_);
return clone;
}
......@@ -63,7 +64,8 @@ std::string AbstractBase::ToString() const {
MS_EXCEPTION_IF_NULL(type_);
MS_EXCEPTION_IF_NULL(shape_);
buffer << type_name() << "("
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
<< " sparse_grad: " << sparse_grad_ << ")";
return buffer.str();
}
......@@ -72,16 +74,22 @@ AbstractBasePtr AbstractScalar::Broaden() const { return AbstractBase::Broaden()
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (*this == *other) {
return shared_from_base<AbstractBase>();
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
return ret;
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
if (res_value == value_self) {
return shared_from_base<AbstractBase>();
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
return ret;
}
return std::make_shared<AbstractScalar>(res_value, res_type);
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
ret->set_sparse_grad(sparse_grad());
return ret;
}
AbstractBasePtr AbstractType::Clone() const {
......@@ -423,7 +431,9 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
}
auto element = element_->Join(other_tensor->element_);
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
return std::make_shared<AbstractTensor>(element, shape);
auto ret = std::make_shared<AbstractTensor>(element, shape);
ret->set_sparse_grad(sparse_grad());
return ret;
}
bool AbstractTensor::operator==(const AbstractTensor &other) const {
......@@ -463,6 +473,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
ShapePtr shp = shape();
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_sparse_grad(sparse_grad());
return clone;
}
......@@ -472,6 +483,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
return broaden;
}
......@@ -482,6 +494,7 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
return broaden;
}
......@@ -502,7 +515,8 @@ std::string AbstractTensor::ToString() const {
MS_EXCEPTION_IF_NULL(value_track);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
<< ")";
return buffer.str();
}
......
......@@ -44,7 +44,7 @@ class AbstractBase : public Base {
public:
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
const BaseShapePtr &shape = kNoShape)
: value_(value), type_(type), shape_(shape) {}
: value_(value), type_(type), shape_(shape), sparse_grad_(false) {}
~AbstractBase() override = default;
MS_DECLARE_PARENT(AbstractBase, Base)
......@@ -53,11 +53,13 @@ class AbstractBase : public Base {
virtual bool operator==(const AbstractBase &other) const;
void set_value(const ValuePtr &value) { value_ = value; }
void set_sparse_grad(const bool &sparse_grad) { sparse_grad_ = sparse_grad; }
void set_type(const TypePtr &type) { type_ = type; }
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
const std::string &value_desc() const { return value_desc_; }
ValuePtr GetValueTrack() const { return value_; }
bool sparse_grad() const { return sparse_grad_; }
TypePtr GetTypeTrack() const { return type_; }
BaseShapePtr GetShapeTrack() const { return shape_; }
......@@ -85,6 +87,7 @@ class AbstractBase : public Base {
TypePtr type_;
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
bool sparse_grad_;
};
class AbstractScalar : public AbstractBase {
......
......@@ -851,7 +851,11 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
}
auto refkey = key_value->cast<RefKeyPtr>();
if (refkey == nullptr) {
return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
auto ret = std::make_shared<AbstractScalar>(type);
auto ref_value = ref_abs->ref();
MS_EXCEPTION_IF_NULL(ref_value);
ret->set_sparse_grad(ref_value->sparse_grad());
return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
}
std::string name = refkey->tag();
......@@ -865,6 +869,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
x = SensitivityTransform(x);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
abs_scalar->set_sparse_grad(x->sparse_grad());
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
}
};
......
......@@ -50,12 +50,14 @@ class Parameter:
requires_grad (bool): True if the parameter requires gradient. Default: True.
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
broadcast and gradients communication would not be applied on parameters. Default: False.
sparse_grad (bool): True if the parameter's gradient is sparse. Default: False.
"""
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=False):
self.set_parameter_data(default_input)
self.name = name
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.sparse_grad = sparse_grad
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
......@@ -168,6 +170,17 @@ class Parameter:
raise TypeError("`requires_grad` parameter must be bool type")
self._requires_grad = value
@property
def sparse_grad(self):
"""Return whether the parameter's gradient is sparse."""
return self._sparse_grad
@sparse_grad.setter
def sparse_grad(self, value=True):
if not isinstance(value, bool):
raise TypeError("`sparse_grad` parameter must be bool type")
self._sparse_grad = value
@property
def data(self):
return self.default_input
......
......@@ -30,6 +30,7 @@ unsorted_segment_sum = P.UnsortedSegmentSum()
transpose = P.Transpose()
shape_op = P.Shape()
reshape = P.Reshape()
size_op = P.Size()
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
......@@ -284,6 +285,37 @@ def get_bprop_gather_v2(self):
return bprop
@bprop_getters.register(P.SparseGatherV2)
def get_bprop_sparse_gather_v2(self):
"""Generate bprop for SparseGatherV2"""
def bprop(x, indices, axis, out, dout):
x_shp = shape_op(x)
if axis == 0:
indices_size = (size_op(indices),)
x_tail_shp = x_shp[1:]
values_shape = indices_size + x_tail_shp
values = reshape(dout, values_shape)
indices = reshape(indices, indices_size)
return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
if F.rank(dout) == 0:
dout = P.ExpandDims()(dout, -1)
if F.rank(indices) == 0:
indices = P.ExpandDims()(indices, -1)
out_shp = shape_op(dout)
ind_shp = shape_op(indices)
# Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
values_transpose = transpose(dout, perm_1)
params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
# Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
perm_2 = _generate_inverse_index(x_shp, axis)
params_grad = transpose(params_grad, perm_2)
return params_grad, zeros_like(indices), zeros_like(axis)
return bprop
@bprop_getters.register(P.Range)
def get_bprop_range(self):
"""Generate bprop for Range"""
......
......@@ -20,7 +20,7 @@ Pre-defined combination of operators.
"""
from .base import GradOperation, HyperMap, MultitypeFuncGraph, add_flags, \
from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \
core, env_get, tail, zip_operation
from .clip_ops import clip_by_value
......
......@@ -19,7 +19,7 @@
from functools import partial
from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec
......@@ -240,6 +240,69 @@ class HyperMap(HyperMap_):
return func(*args_list)
return tuple(map(hypermap, *args_list))
class Map(Map_):
"""
Map will apply the set operation on input sequences.
Which will apply the operations of every elements of the sequence.
Args:
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
the operations should be putted in the first input of the instance.
Inputs:
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence
`(args[0][i], args[1][i])` will be the input of the operation.
If `ops` is not `None`, the first input is the operation, and the other is inputs.
Outputs:
sequence, the output will be same type and same length of sequence from input and the value of each element
is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
"""
def __init__(self, ops=None):
self.ops = ops
if ops:
Map_.__init__(self, ops)
else:
Map_.__init__(self)
def __call__(self, *args):
func = args[0]
count = 0
count_max = 1
args_list = args[1:]
if self.ops is not None:
func = self.ops
args_list = args
for item in args_list:
if isinstance(item, (tuple, list)):
count_max = len(item)
break
def get_item(x):
nonlocal count
if isinstance(x, (tuple, list)):
return x[count]
return x
for i in range(count_max):
true_args = tuple(map(get_item, args_list))
func(*true_args)
count = i + 1
return True
def register(self, *type_names):
"""Register a function for the given type string."""
def deco(fn):
self.register_fn(type_names, fn)
return fn
return deco
class _ListAppend(ListAppend_):
"""
A metafuncgraph class that append one element to list.
......
......@@ -21,7 +21,7 @@ A collection of operators to build nerual networks or computing functions.
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, InvertPermutation,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
......@@ -122,6 +122,7 @@ __all__ = [
'Transpose',
'OneHot',
'GatherV2',
'SparseGatherV2',
'Concat',
'Pack',
'Unpack',
......
......@@ -526,6 +526,29 @@ class GatherV2(PrimitiveWithInfer):
return out
class SparseGatherV2(GatherV2):
"""
Returns a slice of input tensor based on the specified indices and axis.
Inputs:
- **input_params** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
The original Tensor.
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Must be in the range
`[0, input_param.shape()[axis])`.
- **axis** (int) - Specifies the dimension index to gather indices.
Outputs:
Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
Examples:
>>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
>>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
>>> axis = 1
>>> out = P.GatherV2()(input_params, input_indices, axis)
"""
class Range(PrimitiveWithInfer):
r"""
Creates a sequence of numbers.
......
......@@ -332,6 +332,8 @@ class CheckBprop(PrimitiveWithInfer):
def infer_shape(self, xshapes, yshapes):
tips = f'Bprop of {self.prim_to_check}'
validator.check_value_type('grads', xshapes, (tuple,), tips)
validator.check_value_type('params', yshapes, (tuple,), tips)
if len(xshapes) < len(yshapes):
raise TypeError(f"{tips}, the size of output should be {len(yshapes)},"
f" but got {len(xshapes)}.")
......@@ -348,6 +350,8 @@ class CheckBprop(PrimitiveWithInfer):
def infer_dtype(self, xdtypes, ydtypes):
tips = f'Bprop of {self.prim_to_check}'
validator.check_value_type('grads', xdtypes, (tuple,), tips)
validator.check_value_type('params', ydtypes, (tuple,), tips)
if len(xdtypes) < len(ydtypes):
raise TypeError(f"{tips}, the size of output should be {len(ydtypes)},"
f" but got {len(xdtypes)}.")
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Optimizer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
update = next_m / (op_sqrt(next_v) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))
return next_v
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tuple", "Bool")
def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
return gradient[2][2]
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class AdamWeightDecaySparse(Optimizer):
"""
Implements Adam algorithm weight decay fix.
Args:
params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
should be class mindspore.Parameter.
learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
Iterable or a Tensor and the dims of the Tensor is 1,
use dynamic learning rate, then the i-th step will
take the i-th value as the learning rate.
When the learning_rate is float or learning_rate is a Tensor
but the dims of the Tensor is 0, use fixed learning rate.
Other cases are not supported. Default: 1e-3.
beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
Should be in range (0.0, 1.0).
beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
Should be in range (0.0, 1.0).
eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
Should be greater than 0.
weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`,
and might be in sparse format.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
Examples:
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.AdamWeightDecay(params=net.trainable_params())
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
if self.is_group:
raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
self.eps = Tensor(np.array([eps]).astype(np.float32))
self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
self.params = self.parameters
self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
self.decay_flag = tuple(decay_filter(x) for x in self.params)
self.map = C.Map()
def construct(self, gradients):
lr = self.get_lr()
updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
return updated_velocity
def test_AdamWeightDecaySparse():
""" test_AdamWeightDecaySparse """
context.set_context(mode=context.GRAPH_MODE)
class Loss(nn.Cell):
def __init__(self):
super(Loss, self).__init__()
def construct(self, base, target):
return base
class NetWithSparseGatherV2(nn.Cell):
def __init__(self):
super(NetWithSparseGatherV2, self).__init__()
self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad=True)
self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
self.gatherv2 = P.SparseGatherV2()
self.axis = 0
def construct(self, indices):
return self.gatherv2(self.w1, indices, self.axis) * self.w2
inputs = Tensor(np.array([0, 1]).astype(np.int32))
label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
net = NetWithSparseGatherV2()
net.set_train()
loss = Loss()
optimizer = AdamWeightDecaySparse(net.trainable_params())
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册