diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index fbcb06629db77abd2dcaa607ddecf7da50193a6c..262cb789cc5f926e844254f49894f9e33c7aff51 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL } FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { - FuncGraphPtr ptrGraph = std::make_shared(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("hyper_map"); + FuncGraphPtr ptr_graph = std::make_shared(); + ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptr_graph->debug_info()->set_name("hyper_map"); AnfNodePtr ptrFnArg = nullptr; std::size_t i = 0; ArgsPairList argmap; ArgsPairList argmap2; if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); + ptrFnArg = ptr_graph->add_parameter(); i = 1; } std::size_t size = args_spec_list.size(); for (; i < size; ++i) { - argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); + argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i])); } - argmap2 = Harmonize(ptrGraph, argmap); - ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); - return ptrGraph; + argmap2 = Harmonize(ptr_graph, argmap); + ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2)); + return ptr_graph; } abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { @@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, inputs.push_back(opsTupleItem); inputs.push_back(cnode); inputs.push_back(NewValueNode(1)); - AnfNodePtr ptrBprop = ret->NewCNode(inputs); + AnfNodePtr ptr_bprop = ret->NewCNode(inputs); - doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); + doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem); return ret; } -void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, +void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights, ValueNodePtr opsTupleItem) { MS_EXCEPTION_IF_NULL(func_graph); - AnfNodePtr ptrBPropArg = nullptr; + AnfNodePtr ptr_bprop_arg = nullptr; if (sens_param_) { - ptrBPropArg = func_graph->add_parameter(); + ptr_bprop_arg = func_graph->add_parameter(); } else { auto ones_like = prim::GetPythonOps("ones_like"); - ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); + ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out}); } - AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); + AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg}); CNodePtr fv_bprop = nullptr; if (get_by_list_) { // python code: grads = hyper_map(F.partial(env_get, env), weights) - AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); + AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)}); AnfNodePtr partial_env_get = func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); MetaFuncGraphPtr hyper_map = std::make_shared(); @@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An CNodePtr inputs_bprop = nullptr; if (get_all_) { - inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); + inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp}); } // Gradients wrt inputs and parameters @@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An } // Gradients wrt first input. - // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input - func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); + // ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input + func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)})); } // Generate the graph. @@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp auto real_fn = dyn_cast(fn); MS_EXCEPTION_IF_NULL(real_fn); - FuncGraphPtr ptrGraph = real_fn->func_graph(); - MS_EXCEPTION_IF_NULL(ptrGraph); - TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); - FuncGraphPtr dfBuilder = std::make_shared(); + FuncGraphPtr ptr_graph = real_fn->func_graph(); + MS_EXCEPTION_IF_NULL(ptr_graph); + TraceManager::DebugTrace(std::make_shared(ptr_graph->debug_info())); + FuncGraphPtr df_builder = std::make_shared(); TraceManager::EndTrace(); - auto nparam = ptrGraph->parameters().size(); + auto nparam = ptr_graph->parameters().size(); std::ostringstream ss; ss << "grad{" << nparam << "}"; - dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); - dfBuilder->debug_info()->set_name(ss.str()); - ParameterPtr param_graph = dfBuilder->add_parameter(); + df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true); + df_builder->debug_info()->set_name(ss.str()); + ParameterPtr param_graph = df_builder->add_parameter(); AnfNodePtr weights = nullptr; if (get_by_list_) { - weights = dfBuilder->add_parameter(); + weights = df_builder->add_parameter(); } std::vector inputs; inputs.push_back(NewValueNode(prim::kPrimJ)); inputs.push_back(param_graph); - auto jf = dfBuilder->NewCNode(inputs); + auto jf = df_builder->NewCNode(inputs); // df is checked in GetGrad - TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); - auto df = GetGrad(jf, weights, ptrGraph->parameters()); + TraceManager::DebugTrace(std::make_shared(ptr_graph->debug_info())); + auto df = GetGrad(jf, weights, ptr_graph->parameters()); TraceManager::EndTrace(); - dfBuilder->set_output(NewValueNode(df)); + df_builder->set_output(NewValueNode(df)); - return dfBuilder; + return df_builder; } REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc index 50be3c5b29a2914738abeae1d4a327a504511cf9..248e42cb5b8aecf7aed546ee4cf6b646e2e27cf9 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -72,10 +72,15 @@ void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_ bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, TypeId *arg_type = nullptr) { if (arg_value->isa()) { - if (is_write) { - arg_value = arg_value->cast()->ref_origin(); - } else { - arg_value = arg_value->cast()->ref(); + auto ref = arg_value->cast(); + arg_value = ref->ref(); + if (!is_write && ref->need_cast()) { + auto tensor_type = ref->target_type(); + *arg_type_id = tensor_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeTensorType; + } + return true; } } if (arg_value->isa()) { @@ -248,6 +253,8 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign if (arg_value->isa() && arg_type_id == it->second) { continue; } + MS_LOG(DEBUG) << "do cast for inputs " << i << " " << (*op_inputs)[i + 1]->ToString() << " " << arg_type_id + << " to " << it->second; (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); } } @@ -289,16 +296,23 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func TypePtr type = args_spec_list[i]->GetTypeTrack(); if (type && type->type_id() == kObjectTypeRef) { + auto ref_abs = args_spec_list[i]->cast(); if (sig == SignatureEnumRW::kRWRead) { - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); + param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); + if (ref_abs && ref_abs->need_cast()) { + auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); + param = NewCNode({NewValueNode(cast), param, NewValueNode(ref_abs->target_type())}, func_graph); + } } else if (sig == SignatureEnumRW::kRWWrite) { - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); + param = NewCNode({NewValueNode(prim::kPrimGetRefValue), param}, func_graph); write_indices.insert(i); } // If sig is SignatureEnumRW::kRWRef, not do anything. } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; } + MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " + << args_spec_list[i]->ToString(); op_inputs.push_back(param); } // process default diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc index 2c9e0b538f344c136694ddbc2fb107831991cdb8..1b6f358edd5f3c44ae5bb7e2e1e10035bc345a40 100644 --- a/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc +++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc @@ -49,13 +49,14 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; } - (void)abstract::CheckArg(op_name, args_spec_list, 0); + // No need to check, check will be done in infer. auto ret_graph = std::make_shared(); ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret_graph->debug_info()->set_name("UnpackCall"); - AnfNodePtr fnNode = ret_graph->add_parameter(); + AnfNodePtr fn_node = ret_graph->add_parameter(); std::vector elems; - elems.push_back(fnNode); + elems.push_back(fn_node); for (size_t index = 1; index < arg_length; index++) { MS_EXCEPTION_IF_NULL(args_spec_list[index]); if (args_spec_list[index]->isa()) { diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc index 4b2a5be482d1ff10ce4d54c3add507451b51e62d..7707dd5a8fb45de343ab2407013363f4c8803fe5 100644 --- a/mindspore/ccsrc/frontend/operator/prim_others.cc +++ b/mindspore/ccsrc/frontend/operator/prim_others.cc @@ -129,16 +129,22 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, const AbstractBasePtrList &args_spec_list) { - // arguments: key, value, original value + // arguments: key, value, target type(None if no target type) if (args_spec_list.size() != 3) { MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() << "."; } TypePtr type = args_spec_list[0]->GetTypeTrack(); + ValuePtr tensor_target_v = args_spec_list[2]->BuildValue(); if (type->type_id() != kObjectTypeRefKey) { MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); } - return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); + auto need_cast = !tensor_target_v->isa(); + if (need_cast && !tensor_target_v->isa()) { + MS_LOG(EXCEPTION) << "Third input of make_ref should be a Type but a " << tensor_target_v->ToString(); + } + TypePtr cast_target = tensor_target_v->cast(); + return std::make_shared(args_spec_list[0], args_spec_list[1], need_cast, cast_target); } AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, @@ -163,25 +169,11 @@ AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitiveP } TypePtr type = args_spec_list[0]->GetTypeTrack(); if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); + return args_spec_list[0]; } return args_spec_list[0]->cast()->ref(); } -AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref_origin(); -} - AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // args: Two objects of a subclass of AbstractBase, key and value. diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index d1d29fcbae1bf41f27ed8b7c13fcc10f64e3905d..b41c3081b48f85d64525c9c568b40eeebf90275b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -95,10 +95,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // Ref eliminate make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); - get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", - {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); + get_ref_param_eliminate_ = + MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", {prim::kPrimGetRefValue}); get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", - {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); + {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", IsValueNode, opt::FORCE_RENORM); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h index b7759daad415dddad2241627283485822d2aaefd..fc859b213e704208315af66bf9465c81063054c2 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h @@ -37,27 +37,23 @@ class MakeRefEliminater : public OptimizerCaller { }; // {prim::kPrimGetRefValue, Parameter} -> Parameter -// {prim::kPrimGetRefOrigin, Parameter} -> Parameter class GetRefParamEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x; MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, x), x); - MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, x), x); return nullptr; } }; // {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X // {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y -// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z class GetMakeRefEliminater : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x, y, z; MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); - MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); return nullptr; } diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index cadb0f61996a19be0017e0ef636aaf13ab3a5e39..be75d6ac2e6179e5aa8105ca52ec2ca87c00bb5b 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -60,6 +60,17 @@ FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mo return func_graph; } +ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { + TypePtr dst_type; + if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { + return kFloat32; + } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { + return kFloat16; + } else { + return kNone; + } +} + // if any mixed precision flag add a cast node after the parameter node. AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { TypePtr dst_type; diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h index afb72ba5c94e8b183c873e1dd12d8d9939d3a5cf..47366b664eddcd82068b21d80703ae35fb3bc29c 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -359,6 +359,7 @@ class ParseAst { bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); +ValuePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 48f3a24652c2d74d7f7c6bb33063d561bc8be43b..9f0a4b495c9a9e46a51e633ce7d571028f11e0f3 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -70,6 +70,7 @@ bool SymbolResolver::Resolve() { } namespace { +// if any mixed precision flag add a cast node after the parameter node. // argument obj should be python Parameter object // it will be converted to Parameter node here AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { @@ -112,11 +113,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object } auto iter = func_graph->make_ref_params().find(para_node); if (iter == func_graph->make_ref_params().end()) { - AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); + ValuePtr target_type = GetMixedPrecisionTargetType(func_graph, para_node); AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); AnfNodePtr ref_key = NewValueNode(std::make_shared(param_name)); - AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); + AnfNodePtr target_type_node = NewValueNode(target_type); + AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, para_node, target_type_node}); func_graph->make_ref_params()[para_node] = ref_node; func_graph->add_parameter_obj_node(ref_node); return ref_node; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 2dd0ba6b4996c0f4e7ec5e160e390d2ce736a0a6..1cd9ecdb3b36f5aaae73f5cac00e7fcfb7f867ea 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -125,7 +125,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimMakeRef, {InferImplMakeRef, true}}, {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, - {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, {prim::kPrimDepend, {InferImplDepend, true}}, {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 703f3dff7ec5045ac3719ae6d81ff7faaf830688..b918de0942445698ef41447dd60a61b6e29e0531 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1117,11 +1117,12 @@ std::vector PynativeExecutor::GetWeightsArgs(const py::object &weigh free_param->debug_info()->set_name(param_name); para_node = free_param; } - AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); + ValuePtr target_type = parse::GetMixedPrecisionTargetType(df_builder_, para_node); AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); auto refkey = std::make_shared(para_node->cast()->name()); AnfNodePtr ref_key_node = NewValueNode(refkey); - AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); + AnfNodePtr target_type_node = NewValueNode(target_type); + AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, para_node, target_type_node}); w_args.push_back(ref_node); } } else { diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index efdf12452b0322202f64d628162ca08ab1116030..dab262bc8970d94329dd18aad8ca83f158891fd6 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -808,14 +808,40 @@ std::string AbstractJTagged::ToString() const { return buffer.str(); } +AbstractRef::AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast, + TypePtr cast_target) + : ref_key_(ref_key), ref_(ref_value), need_cast_(false), target_type_(nullptr), ref_key_value_(nullptr) { + set_type(std::make_shared()); + auto origin_type = ref_value->BuildType(); + if (need_cast && cast_target && origin_type && origin_type->isa()) { + auto tensor_dtype = origin_type->cast()->element(); + if (tensor_dtype && IsSubType(tensor_dtype, kFloat)) { + if (cast_target != tensor_dtype) { + need_cast_ = true; + target_type_ = cast_target; + } + } + } + if (ref_key && ref_key->isa()) { + ref_key_value_ = ref_key->cast()->ref_key_value(); + } +} + +BaseShapePtr AbstractRef::BuildShape() const { return ref_->BuildShape(); } + TypePtr AbstractRef::BuildType() const { TypePtr subtype = ref_->BuildType(); - TypePtr subtype_origin = ref_origin_->BuildType(); + TypePtr subtype_origin = subtype; + if (need_cast_) { + subtype_origin = std::make_shared(target_type_); + } return std::make_shared(subtype, subtype_origin); } bool AbstractRef::operator==(const AbstractRef &other) const { - return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_); + return (*ref_ == *other.ref_) && (need_cast_ == other.need_cast_) && + (!need_cast_ || (*target_type_ == *other.target_type_)); + // not compare the key for reuse the graph (*ref_key_ == *other.ref_key_); } bool AbstractRef::operator==(const AbstractBase &other) const { @@ -826,27 +852,45 @@ bool AbstractRef::operator==(const AbstractBase &other) const { return false; } +AbstractBasePtr AbstractRefKey::Join(const AbstractBasePtr &other) { + MS_EXCEPTION_IF_NULL(other); + if (*this == *other) { + auto ret = shared_from_base(); + return ret; + } + auto value_self = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_self); + ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack()); + if (res_value == value_self) { + auto ret = shared_from_base(); + return ret; + } + auto ret = std::make_shared(); + ret->set_value(res_value); + return ret; +} + AbstractBasePtr AbstractRef::Join(const AbstractBasePtr &other) { auto other_ref = other->cast(); if (other_ref == nullptr) { auto new_ref = ref_->Join(other); - return std::make_shared(ref_key_, new_ref, ref_origin_); + return std::make_shared(ref_key_, new_ref); } - if (*this == *other) { + if ((*this == *other) && (*ref_key_ == *other_ref->ref_key_)) { return shared_from_base(); } auto ref_key = ref_key_->Join(other_ref->ref_key_); auto ref = ref_->Join(other_ref->ref()); - auto ref_origin = ref_origin_->Join(other_ref->ref_origin_); - - return std::make_shared(ref_key, ref, ref_origin); + return std::make_shared(ref_key, ref); } std::string AbstractRef::ToString() const { std::ostringstream buffer; buffer << type_name() << "(" - << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString() - << " origin_value: " << ref_origin_->ToString(); + << "key: " << ref_key_->ToString() << " ref_value: " << ref_->ToString(); + if (need_cast_) { + buffer << " cast to: " << target_type_->ToString(); + } auto value = GetValueTrack(); if (value) { buffer << ", value: " << value->ToString(); @@ -873,6 +917,12 @@ std::string AbstractNone::ToString() const { ValuePtr AbstractNone::RealBuildValue() const { return kNone; } +AbstractBasePtr AbstractRefKey::Broaden() const { + auto refkey = std::make_shared(); + refkey->set_value(kAnyValue); + return refkey; +} + bool AbstractRefKey::operator==(const AbstractRefKey &other) const { ValuePtr value_self = GetValueTrack(); ValuePtr value_other = other.GetValueTrack(); diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index faf80c639b5c615914100ff6887fdda7020594eb..eee0fc670c8a6206d26dcedc5be5bf15ad3dcfca 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -535,50 +535,70 @@ using AbstractEllipsisPtr = std::shared_ptr; class AbstractRefKey : public AbstractBase { public: - AbstractRefKey() : AbstractBase() { set_type(std::make_shared()); } + AbstractRefKey() : AbstractBase(), ref_key_value_(nullptr) { set_type(std::make_shared()); } ~AbstractRefKey() override = default; MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } bool operator==(const AbstractRefKey &other) const; bool operator==(const AbstractBase &other) const override; - AbstractBasePtr Clone() const override { return std::make_shared(); } + AbstractBasePtr Clone() const override { + auto cloned = std::make_shared(); + cloned->set_value(GetValueTrack()); + return cloned; + } + inline void set_value(const ValuePtr &value) { + AbstractBase::set_value(value); + ref_key_value_ = value->cast(); + } + RefKeyPtr ref_key_value() const { return ref_key_value_; } + AbstractBasePtr Join(const AbstractBasePtr &other) override; + AbstractBasePtr Broaden() const override; std::string ToString() const override; + + private: + // cache for ref_key after build value, when value is null, return nullptr. + RefKeyPtr ref_key_value_{nullptr}; }; using AbstractRefKeyPtr = std::shared_ptr; class AbstractRef : public AbstractBase { public: - AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) - : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { - set_type(std::make_shared()); - } + AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, bool need_cast = false, + TypePtr cast_target = nullptr); ~AbstractRef() override = default; MS_DECLARE_PARENT(AbstractRef, AbstractBase) TypePtr BuildType() const override; + BaseShapePtr BuildShape() const override; bool operator==(const AbstractRef &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { - return std::make_shared(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); + return std::make_shared(ref_key_->Clone(), ref_->Clone(), need_cast_, target_type_); } std::string ToString() const override; - AbstractBasePtr ref() { return ref_; } - AbstractBasePtr ref_origin() { return ref_origin_; } - AbstractBasePtr ref_key() { return ref_key_; } + inline AbstractBasePtr ref() const { return ref_; } + inline AbstractBasePtr ref_key() const { return ref_key_; } + inline RefKeyPtr ref_key_value() const { return ref_key_value_; } + inline TypePtr target_type() const { return target_type_; } + inline bool need_cast() const { return need_cast_; } AbstractBasePtr Broaden() const override { - return std::make_shared(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); + return std::make_shared(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_); } AbstractBasePtr Join(const AbstractBasePtr &other) override; std::size_t hash() const override { - return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash{}(this->tid()) << 1); + return ref_->hash() ^ (std::hash{}(this->tid()) << 1); // ref_key_->hash() ^ } private: AbstractBasePtr ref_key_; AbstractBasePtr ref_; - AbstractBasePtr ref_origin_; + // For mix presicion, only float type need to cast to float16 of float32 + bool need_cast_; + TypePtr target_type_; + // cache for ref_key after build value, when value is null, return nullptr. + RefKeyPtr ref_key_value_; }; using AbstractRefPtr = std::shared_ptr; diff --git a/mindspore/core/abstract/analysis_context.cc b/mindspore/core/abstract/analysis_context.cc index 228ddf0f54a23adb00a458f079385e287ebc3ced..2270f3c1b03d5da67d811a4941c1b09b60c4a253 100644 --- a/mindspore/core/abstract/analysis_context.cc +++ b/mindspore/core/abstract/analysis_context.cc @@ -171,9 +171,7 @@ AnalysisContextPtr AnalysisContext::SpecializeKey() const { } if (arg->isa()) { MS_LOG(DEBUG) << "refkey broaden"; - auto arg_spec = dyn_cast(arg); - auto ret_spec = arg_spec->Broaden(); - return ret_spec; + return arg->Broaden(); } return arg; }); diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index a04b983a2ddcb59c19368a104dc8fd4fd94a9578..76bdb4231ede879e233f31793ed409bd6f926de7 100755 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -121,7 +121,6 @@ inline const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); inline const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); -inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType");