提交 5f2bfb56 编写于 作者: G guohongzilong

trans const to variable in assign case

上级 d1b452cf
......@@ -1154,6 +1154,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) {
}
}
const std::vector<std::string> trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd),
string(kNameAssignSub)};
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
OperatorPtr src = Convert(node);
auto &inputs = node->inputs();
......@@ -1166,6 +1169,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
if (IsValueNode<None>(pred)) {
continue;
}
// transform "Const" op to "Variable" op when the next node is "Assign" op.
std::string c_name = GetCNodeFuncName(node);
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
auto op_itor = op_cache_.find(pred.get());
if (op_itor == op_cache_.end()) {
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
}
if (op_itor->second != nullptr &&
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
vars_.find(name) != vars_.end()) {
auto variable = std::make_shared<Variable>(name);
auto desc = vars_[name]->GetOutputDesc("y");
(void)variable->update_output_desc_y(desc);
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
op_itor->second = variable; // replace parameter with variable
vars_[name] = variable;
}
}
// find in out_hadnle_cache_ first
auto it = out_handle_cache_.find(pred.get());
if (it != out_handle_cache_.end()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册