diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc old mode 100755 new mode 100644 index 3b05dbf3ec1fdba85bf3e8da14f91e0cdb5afa3a..e057b26f0262a7708fe75fcbeda2ec4c043753c6 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -1155,6 +1155,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { } } +const std::vector trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd), + string(kNameAssignSub)}; + void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { OperatorPtr src = Convert(node); auto &inputs = node->inputs(); @@ -1167,6 +1170,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node if (IsValueNode(pred)) { continue; } + // transform "Const" op to "Variable" op when the next node is "Assign" op. + std::string c_name = GetCNodeFuncName(node); + auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); + if (!training_ && pos != trans_var_list.end() && pred->isa()) { + std::string name = std::static_pointer_cast(pred)->name(); + auto op_itor = op_cache_.find(pred.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; + } + if (op_itor->second != nullptr && + (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && + vars_.find(name) != vars_.end()) { + auto variable = std::make_shared(name); + auto desc = vars_[name]->GetOutputDesc("y"); + (void)variable->update_output_desc_y(desc); + MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; + } + } // find in out_hadnle_cache_ first auto it = out_handle_cache_.find(pred.get()); if (it != out_handle_cache_.end()) {