提交 fe2c2e83 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4160 [refine]remove ref origin

Merge pull request !4160 from vlne-v1/remove-ref-origin
...@@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL ...@@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
} }
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("hyper_map"); ptr_graph->debug_info()->set_name("hyper_map");
AnfNodePtr ptrFnArg = nullptr; AnfNodePtr ptrFnArg = nullptr;
std::size_t i = 0; std::size_t i = 0;
ArgsPairList argmap; ArgsPairList argmap;
ArgsPairList argmap2; ArgsPairList argmap2;
if (fn_leaf_ == nullptr) { if (fn_leaf_ == nullptr) {
ptrFnArg = ptrGraph->add_parameter(); ptrFnArg = ptr_graph->add_parameter();
i = 1; i = 1;
} }
std::size_t size = args_spec_list.size(); std::size_t size = args_spec_list.size();
for (; i < size; ++i) { for (; i < size; ++i) {
argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
} }
argmap2 = Harmonize(ptrGraph, argmap); argmap2 = Harmonize(ptr_graph, argmap);
ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
return ptrGraph; return ptr_graph;
} }
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
...@@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, ...@@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
inputs.push_back(opsTupleItem); inputs.push_back(opsTupleItem);
inputs.push_back(cnode); inputs.push_back(cnode);
inputs.push_back(NewValueNode(1)); inputs.push_back(NewValueNode(1));
AnfNodePtr ptrBprop = ret->NewCNode(inputs); AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
return ret; return ret;
} }
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
ValueNodePtr opsTupleItem) { ValueNodePtr opsTupleItem) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtr ptrBPropArg = nullptr; AnfNodePtr ptr_bprop_arg = nullptr;
if (sens_param_) { if (sens_param_) {
ptrBPropArg = func_graph->add_parameter(); ptr_bprop_arg = func_graph->add_parameter();
} else { } else {
auto ones_like = prim::GetPythonOps("ones_like"); auto ones_like = prim::GetPythonOps("ones_like");
ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
} }
AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
CNodePtr fv_bprop = nullptr; CNodePtr fv_bprop = nullptr;
if (get_by_list_) { if (get_by_list_) {
// python code: grads = hyper_map(F.partial(env_get, env), weights) // python code: grads = hyper_map(F.partial(env_get, env), weights)
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
AnfNodePtr partial_env_get = AnfNodePtr partial_env_get =
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>(); MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
...@@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An ...@@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
CNodePtr inputs_bprop = nullptr; CNodePtr inputs_bprop = nullptr;
if (get_all_) { if (get_all_) {
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
} }
// Gradients wrt inputs and parameters // Gradients wrt inputs and parameters
...@@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An ...@@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
} }
// Gradients wrt first input. // Gradients wrt first input.
// ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
} }
// Generate the graph. // Generate the graph.
...@@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp ...@@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn); auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
MS_EXCEPTION_IF_NULL(real_fn); MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr ptrGraph = real_fn->func_graph(); FuncGraphPtr ptr_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(ptrGraph); MS_EXCEPTION_IF_NULL(ptr_graph);
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>(); FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
TraceManager::EndTrace(); TraceManager::EndTrace();
auto nparam = ptrGraph->parameters().size(); auto nparam = ptr_graph->parameters().size();
std::ostringstream ss; std::ostringstream ss;
ss << "grad{" << nparam << "}"; ss << "grad{" << nparam << "}";
dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
dfBuilder->debug_info()->set_name(ss.str()); df_builder->debug_info()->set_name(ss.str());
ParameterPtr param_graph = dfBuilder->add_parameter(); ParameterPtr param_graph = df_builder->add_parameter();
AnfNodePtr weights = nullptr; AnfNodePtr weights = nullptr;
if (get_by_list_) { if (get_by_list_) {
weights = dfBuilder->add_parameter(); weights = df_builder->add_parameter();
} }
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimJ)); inputs.push_back(NewValueNode(prim::kPrimJ));
inputs.push_back(param_graph); inputs.push_back(param_graph);
auto jf = dfBuilder->NewCNode(inputs); auto jf = df_builder->NewCNode(inputs);
// df is checked in GetGrad // df is checked in GetGrad
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info())); TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
auto df = GetGrad(jf, weights, ptrGraph->parameters()); auto df = GetGrad(jf, weights, ptr_graph->parameters());
TraceManager::EndTrace(); TraceManager::EndTrace();
dfBuilder->set_output(NewValueNode(df)); df_builder->set_output(NewValueNode(df));
return dfBuilder; return df_builder;
} }
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
......
...@@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ ...@@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_
bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id,
TypeId *arg_type = nullptr) { TypeId *arg_type = nullptr) {
if (arg_value->isa<abstract::AbstractRef>()) { if (arg_value->isa<abstract::AbstractRef>()) {
if (is_write) { auto ref = arg_value->cast<abstract::AbstractRefPtr>();
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref_origin(); arg_value = ref->ref();
} else { if (!is_write && ref->need_cast()) {
arg_value = arg_value->cast<abstract::AbstractRefPtr>()->ref(); auto tensor_type = ref->target_type();
*arg_type_id = tensor_type->type_id();
if (arg_type != nullptr) {
*arg_type = kObjectTypeTensorType;
}
return true;
} }
} }
if (arg_value->isa<abstract::AbstractTensor>()) { if (arg_value->isa<abstract::AbstractTensor>()) {
...@@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign ...@@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector<Signature> &sign
if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) { if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
continue; continue;
} }
MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id
<< " to " << it->second;
(*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
} }
} }
...@@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func ...@@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
TypePtr type = args_spec_list[i]->GetTypeTrack(); TypePtr type = args_spec_list[i]->GetTypeTrack();
if (type && type->type_id() == kObjectTypeRef) { if (type && type->type_id() == kObjectTypeRef) {
auto ref_abs = args_spec_list[i]->cast<abstract::AbstractRefPtr>();
if (sig == SignatureEnumRW::kRWRead) { if (sig == SignatureEnumRW::kRWRead) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
if (ref_abs && ref_abs->need_cast()) {
auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph);
}
} else if (sig == SignatureEnumRW::kRWWrite) { } else if (sig == SignatureEnumRW::kRWWrite) {
param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph);
write_indices.insert(i); write_indices.insert(i);
} }
// If sig is SignatureEnumRW::kRWRef, not do anything. // If sig is SignatureEnumRW::kRWRef, not do anything.
} else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
} }
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
<< args_spec_list[i]->ToString();
op_inputs.push_back(param); op_inputs.push_back(param);
} }
// process default // process default
......
...@@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ ...@@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
} }
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0); // No need to check, check will be done in infer.
auto ret_graph = std::make_shared<FuncGraph>(); auto ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret_graph->debug_info()->set_name("UnpackCall");
AnfNodePtr fnNode = ret_graph->add_parameter(); AnfNodePtr fn_node = ret_graph->add_parameter();
std::vector<AnfNodePtr> elems; std::vector<AnfNodePtr> elems;
elems.push_back(fnNode); elems.push_back(fn_node);
for (size_t index = 1; index < arg_length; index++) { for (size_t index = 1; index < arg_length; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]); MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (args_spec_list[index]->isa<AbstractTuple>()) { if (args_spec_list[index]->isa<AbstractTuple>()) {
......
...@@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt ...@@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// arguments: key, value, original value // arguments: key, value, target type(None if no target type)
if (args_spec_list.size() != 3) { if (args_spec_list.size() != 3) {
MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
<< "."; << ".";
} }
TypePtr type = args_spec_list[0]->GetTypeTrack(); TypePtr type = args_spec_list[0]->GetTypeTrack();
ValuePtr tensor_target_v = args_spec_list[2]->BuildValue();
if (type->type_id() != kObjectTypeRefKey) { if (type->type_id() != kObjectTypeRefKey) {
MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
} }
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]); auto need_cast = !tensor_target_v->isa<None>();
if (need_cast && !tensor_target_v->isa<Type>()) {
MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString();
}
TypePtr cast_target = tensor_target_v->cast<TypePtr>();
return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], need_cast, cast_target);
} }
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
...@@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP ...@@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP
} }
TypePtr type = args_spec_list[0]->GetTypeTrack(); TypePtr type = args_spec_list[0]->GetTypeTrack();
if (type->type_id() != kObjectTypeRef) { if (type->type_id() != kObjectTypeRef) {
MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); return args_spec_list[0];
} }
return args_spec_list[0]->cast<AbstractRefPtr>()->ref(); return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
} }
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
// arguments: value
if (args_spec_list.size() != 1) {
MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size()
<< ".";
}
TypePtr type = args_spec_list[0]->GetTypeTrack();
if (type->type_id() != kObjectTypeRef) {
MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
}
return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin();
}
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// args: Two objects of a subclass of AbstractBase, key and value. // args: Two objects of a subclass of AbstractBase, key and value.
......
...@@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Ref eliminate // Ref eliminate
make_ref_eliminate_ = make_ref_eliminate_ =
MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef); MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", get_ref_param_eliminate_ =
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue});
get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate", get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param", replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
IsValueNode<RefKey>, opt::FORCE_RENORM); IsValueNode<RefKey>, opt::FORCE_RENORM);
......
...@@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller { ...@@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller {
}; };
// {prim::kPrimGetRefValue, Parameter} -> Parameter // {prim::kPrimGetRefValue, Parameter} -> Parameter
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
class GetRefParamEliminater : public OptimizerCaller { class GetRefParamEliminater : public OptimizerCaller {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x; PatternNode<AnfNodePtr> x;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x);
return nullptr; return nullptr;
} }
}; };
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class GetMakeRefEliminater : public OptimizerCaller { class GetMakeRefEliminater : public OptimizerCaller {
public: public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, y, z; PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
return nullptr; return nullptr;
} }
......
...@@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo ...@@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo
return func_graph; return func_graph;
} }
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
TypePtr dst_type;
if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
return kFloat32;
} else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
return kFloat16;
} else {
return kNone;
}
}
// if any mixed precision flag add a cast node after the parameter node. // if any mixed precision flag add a cast node after the parameter node.
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) { AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
TypePtr dst_type; TypePtr dst_type;
......
...@@ -359,6 +359,7 @@ class ParseAst { ...@@ -359,6 +359,7 @@ class ParseAst {
bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph);
AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param); AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
} // namespace parse } // namespace parse
} // namespace mindspore } // namespace mindspore
......
...@@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() { ...@@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() {
} }
namespace { namespace {
// if any mixed precision flag add a cast node after the parameter node.
// argument obj should be python Parameter object // argument obj should be python Parameter object
// it will be converted to Parameter node here // it will be converted to Parameter node here
AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
...@@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object ...@@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
} }
auto iter = func_graph->make_ref_params().find(para_node); auto iter = func_graph->make_ref_params().find(para_node);
if (iter == func_graph->make_ref_params().end()) { if (iter == func_graph->make_ref_params().end()) {
AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name)); AnfNodePtr ref_key = NewValueNode(std::make_shared<RefKey>(param_name));
AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); AnfNodePtr target_type_node = NewValueNode(target_type);
AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node});
func_graph->make_ref_params()[para_node] = ref_node; func_graph->make_ref_params()[para_node] = ref_node;
func_graph->add_parameter_obj_node(ref_node); func_graph->add_parameter_obj_node(ref_node);
return ref_node; return ref_node;
......
...@@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { ...@@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMakeRef, {InferImplMakeRef, true}}, {prim::kPrimMakeRef, {InferImplMakeRef, true}},
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, {prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, {prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
{prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}},
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, {prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
{prim::kPrimDepend, {InferImplDepend, true}}, {prim::kPrimDepend, {InferImplDepend, true}},
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
......
...@@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh ...@@ -1117,11 +1117,12 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
free_param->debug_info()->set_name(param_name); free_param->debug_info()->set_name(param_name);
para_node = free_param; para_node = free_param;
} }
AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node);
AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef);
auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name());
AnfNodePtr ref_key_node = NewValueNode(refkey); AnfNodePtr ref_key_node = NewValueNode(refkey);
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); AnfNodePtr target_type_node = NewValueNode(target_type);
AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node});
w_args.push_back(ref_node); w_args.push_back(ref_node);
} }
} else { } else {
......
...@@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const { ...@@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const {
return buffer.str(); return buffer.str();
} }
AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast,
TypePtr cast_target)
: ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) {
set_type(std::make_shared<RefType>());
auto origin_type = ref_value->BuildType();
if (need_cast && cast_target && origin_type && origin_type->isa<TensorType>()) {
auto tensor_dtype = origin_type->cast<TensorTypePtr>()->element();
if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) {
if (cast_target != tensor_dtype) {
need_cast_ = true;
target_type_ = cast_target;
}
}
}
if (ref_key && ref_key->isa<AbstractRefKey>()) {
ref_key_value_ = ref_key->cast<AbstractRefKeyPtr>()->ref_key_value();
}
}
BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); }
TypePtr AbstractRef::BuildType() const { TypePtr AbstractRef::BuildType() const {
TypePtr subtype = ref_->BuildType(); TypePtr subtype = ref_->BuildType();
TypePtr subtype_origin = ref_origin_->BuildType(); TypePtr subtype_origin = subtype;
if (need_cast_) {
subtype_origin = std::make_shared<TensorType>(target_type_);
}
return std::make_shared<RefType>(subtype, subtype_origin); return std::make_shared<RefType>(subtype, subtype_origin);
} }
bool AbstractRef::operator==(const AbstractRef &other) const { bool AbstractRef::operator==(const AbstractRef &other) const {
return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_); return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) &&
(!need_cast_ || (*target_type_ == *other.target_type_));
// not compare the key for reuse the graph (*ref_key_ == *other.ref_key_);
} }
bool AbstractRef::operator==(const AbstractBase &other) const { bool AbstractRef::operator==(const AbstractBase &other) const {
...@@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const { ...@@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
return false; return false;
} }
AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) {
MS_EXCEPTION_IF_NULL(other);
if (*this == *other) {
auto ret = shared_from_base<AbstractBase>();
return ret;
}
auto value_self = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_self);
ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
if (res_value == value_self) {
auto ret = shared_from_base<AbstractBase>();
return ret;
}
auto ret = std::make_shared<AbstractRefKey>();
ret->set_value(res_value);
return ret;
}
AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) {
auto other_ref = other->cast<AbstractRefPtr>(); auto other_ref = other->cast<AbstractRefPtr>();
if (other_ref == nullptr) { if (other_ref == nullptr) {
auto new_ref = ref_->Join(other); auto new_ref = ref_->Join(other);
return std::make_shared<AbstractRef>(ref_key_, new_ref, ref_origin_); return std::make_shared<AbstractRef>(ref_key_, new_ref);
} }
if (*this == *other) { if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) {
return shared_from_base<AbstractBase>(); return shared_from_base<AbstractBase>();
} }
auto ref_key = ref_key_->Join(other_ref->ref_key_); auto ref_key = ref_key_->Join(other_ref->ref_key_);
auto ref = ref_->Join(other_ref->ref()); auto ref = ref_->Join(other_ref->ref());
auto ref_origin = ref_origin_->Join(other_ref->ref_origin_); return std::make_shared<AbstractRef>(ref_key, ref);
return std::make_shared<AbstractRef>(ref_key, ref, ref_origin);
} }
std::string AbstractRef::ToString() const { std::string AbstractRef::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
buffer << type_name() << "(" buffer << type_name() << "("
<< "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString();
<< " origin_value: " << ref_origin_->ToString(); if (need_cast_) {
buffer << " cast to: " << target_type_->ToString();
}
auto value = GetValueTrack(); auto value = GetValueTrack();
if (value) { if (value) {
buffer << ", value: " << value->ToString(); buffer << ", value: " << value->ToString();
...@@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const { ...@@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const {
ValuePtr AbstractNone::RealBuildValue() const { return kNone; } ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
AbstractBasePtr AbstractRefKey::Broaden() const {
auto refkey = std::make_shared<AbstractRefKey>();
refkey->set_value(kAnyValue);
return refkey;
}
bool AbstractRefKey::operator==(const AbstractRefKey &other) const { bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
ValuePtr value_self = GetValueTrack(); ValuePtr value_self = GetValueTrack();
ValuePtr value_other = other.GetValueTrack(); ValuePtr value_other = other.GetValueTrack();
......
...@@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>; ...@@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
class AbstractRefKey : public AbstractBase { class AbstractRefKey : public AbstractBase {
public: public:
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); } AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared<RefKeyType>()); }
~AbstractRefKey() override = default; ~AbstractRefKey() override = default;
MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) MS_DECLARE_PARENT(AbstractRefKey, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); } TypePtr BuildType() const override { return std::make_shared<RefKeyType>(); }
bool operator==(const AbstractRefKey &other) const; bool operator==(const AbstractRefKey &other) const;
bool operator==(const AbstractBase &other) const override; bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractRefKey>(); } AbstractBasePtr Clone() const override {
auto cloned = std::make_shared<AbstractRefKey>();
cloned->set_value(GetValueTrack());
return cloned;
}
inline void set_value(const ValuePtr &value) {
AbstractBase::set_value(value);
ref_key_value_ = value->cast<RefKeyPtr>();
}
RefKeyPtr ref_key_value() const { return ref_key_value_; }
AbstractBasePtr Join(const AbstractBasePtr &other) override;
AbstractBasePtr Broaden() const override;
std::string ToString() const override; std::string ToString() const override;
private:
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_{nullptr};
}; };
using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>; using AbstractRefKeyPtr = std::shared_ptr<AbstractRefKey>;
class AbstractRef : public AbstractBase { class AbstractRef : public AbstractBase {
public: public:
AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false,
: ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { TypePtr cast_target = nullptr);
set_type(std::make_shared<RefType>());
}
~AbstractRef() override = default; ~AbstractRef() override = default;
MS_DECLARE_PARENT(AbstractRef, AbstractBase) MS_DECLARE_PARENT(AbstractRef, AbstractBase)
TypePtr BuildType() const override; TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
bool operator==(const AbstractRef &other) const; bool operator==(const AbstractRef &other) const;
bool operator==(const AbstractBase &other) const override; bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { AbstractBasePtr Clone() const override {
return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); return std::make_shared<AbstractRef>(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_);
} }
std::string ToString() const override; std::string ToString() const override;
AbstractBasePtr ref() { return ref_; } inline AbstractBasePtr ref() const { return ref_; }
AbstractBasePtr ref_origin() { return ref_origin_; } inline AbstractBasePtr ref_key() const { return ref_key_; }
AbstractBasePtr ref_key() { return ref_key_; } inline RefKeyPtr ref_key_value() const { return ref_key_value_; }
inline TypePtr target_type() const { return target_type_; }
inline bool need_cast() const { return need_cast_; }
AbstractBasePtr Broaden() const override { AbstractBasePtr Broaden() const override {
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_);
} }
AbstractBasePtr Join(const AbstractBasePtr &other) override; AbstractBasePtr Join(const AbstractBasePtr &other) override;
std::size_t hash() const override { std::size_t hash() const override {
return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); return ref_->hash() ^ (std::hash<uint32_t>{}(this->tid()) << 1); // ref_key_->hash() ^
} }
private: private:
AbstractBasePtr ref_key_; AbstractBasePtr ref_key_;
AbstractBasePtr ref_; AbstractBasePtr ref_;
AbstractBasePtr ref_origin_; // For mix presicion, only float type need to cast to float16 of float32
bool need_cast_;
TypePtr target_type_;
// cache for ref_key after build value, when value is null, return nullptr.
RefKeyPtr ref_key_value_;
}; };
using AbstractRefPtr = std::shared_ptr<AbstractRef>; using AbstractRefPtr = std::shared_ptr<AbstractRef>;
......
...@@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const { ...@@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const {
} }
if (arg->isa<AbstractRef>()) { if (arg->isa<AbstractRef>()) {
MS_LOG(DEBUG) << "refkey broaden"; MS_LOG(DEBUG) << "refkey broaden";
auto arg_spec = dyn_cast<AbstractRef>(arg); return arg->Broaden();
auto ret_spec = arg_spec->Broaden();
return ret_spec;
} }
return arg; return arg;
}); });
......
...@@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add"); ...@@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey"); inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key"); inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value"); inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf"); inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward"); inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册