Commit dedecf69 authored by Megvii Engine Team, committed by huangxinda

fix(imperative/utils): fix logical error of replace var

GitOrigin-RevId: 614302552cbeaa66cbc977ee81e5492b6023c1c4
Parent ea70d99b
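The logical error: the old `replace_vars` spliced the replaced var into the replacement var's producer (`owner.outputs[idx] = var`) and then aliased the two nodes by copying `repl_var.__dict__` into `var`. Mutating the producer and clobbering the replaced var's attributes corrupts the graph as soon as the replacement var has other consumers, or when a var is mapped to itself. This commit instead gives every `VarNode` a `users` list (maintained while the graph is loaded in `_add_opr`) and performs replacement edge-wise: each consumer rewrites its `inputs` to point at the replacement var, and producers are left untouched. A minimal sketch of that strategy, using simplified stand-in classes rather than the MegEngine `VarNode`/`OpNode`:

```python
# A minimal sketch of the new rewiring strategy. Var and Op are simplified
# stand-ins for illustration, not the MegEngine VarNode/OpNode classes.

class Var:
    def __init__(self, name):
        self.name = name
        self.users = []  # ops that consume this var (mirrors VarNode.users)

class Op:
    def __init__(self, name, inputs):
        self.name = name
        self.inputs = list(inputs)
        for v in inputs:
            if self not in v.users:  # register once, even for repeated inputs
                v.users.append(self)

def replace_vars(all_vars, repl_dict):
    for var in all_vars:
        repl_var = repl_dict.get(var)
        if repl_var is None or repl_var is var:
            continue  # replacing a var with itself is a no-op
        for op in var.users:
            assert var in op.inputs
            # rewrite every occurrence of `var` in this consumer's inputs
            op.inputs = [repl_var if v is var else v for v in op.inputs]
            if op not in repl_var.users:
                repl_var.users.append(op)
        var.users.clear()  # the replaced var no longer feeds any op

a, b = Var("a"), Var("b")
mul = Op("mul", [a, a])
replace_vars([a], {a: b})
assert mul.inputs == [b, b] and a.users == [] and b.users == [mul]
```

Rewiring consumers keeps self-replacement a no-op and leaves the replacement var fully usable elsewhere in the graph, which is what allows the new `test_splice_network` below to graft one loaded network onto another.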
@@ -519,8 +519,7 @@ def _unwrap(x):
         return type(x)(map(_unwrap, x))
     if isinstance(x, VarNode):
         return x._node
-    else:
-        return x
+    return x


 def apply_normal_varnode(op: OpDef, *args: VarNode):
......
@@ -12,14 +12,16 @@ import itertools
 import pickle
 import re
 from collections import OrderedDict
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List, Optional, Sequence

+from ..core import _imperative_rt
 from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
 from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
 from .network_node import (
+    ConstOpBase,
     Host2DeviceCopy,
     ImmutableTensor,
     NetworkNode,
@@ -37,8 +39,10 @@ class Network:
         self._orig_inputs = []
         self.output_vars = []  # output var of graph
         self._orig_outputs = []
-        self.all_oprs_map = OrderedDict()
-        self.all_vars_map = OrderedDict()
+        self.all_oprs_map = OrderedDict()  # _imperative_rt.graph.OperatorNode.id: OpNode
+        self.all_vars_map = (
+            OrderedDict()
+        )  # _imperative_rt.graph.VarNode.id: VarNode
         self.graph = ComputingGraph()
         self._metadata = None
@@ -101,7 +105,7 @@ class Network:
         self.all_oprs_map = {}
         self.all_vars_map = {}
         for opr in self.all_oprs:
-            if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)):
+            if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
                 opr.compile(self.graph)
             else:
                 opr.compile()
@@ -295,6 +299,9 @@ class Network:
     def add_dep_oprs(self, *vars):
         if len(vars) == 0:
             vars = self.output_vars
+        assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode"
         q = list(vars)
         while len(q) > 0:
             cur = q.pop(0)
@@ -368,11 +375,14 @@ class Network:
         for var in self.all_vars:
             if var in repl_dict:
                 repl_var = repl_dict[var]
-                owner = repl_var.owner
-                idx = owner.outputs.index(repl_var)
-                owner.outputs[idx] = var
-                var.__dict__.update(repl_var.__dict__)
-                var.var = repl_var.var
+                if repl_var is var:
+                    continue
+                for opnode in var.users:
+                    assert var in opnode.inputs
+                    opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
+                    if opnode not in repl_var.users:
+                        repl_var.users.append(opnode)
+                var.users.clear()
         self._compile()

     def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
@@ -473,14 +483,20 @@ class Network:
     def all_oprs_dict(self):
         return self.opr_filter.as_dict()

-    # used for loading and building graph
-    def _add_opr(self, opr):
+    def _add_opr(self, opr) -> Optional[OpNode]:
+        """
+        Used for loading and building graph.
+        """
+        assert isinstance(opr, _imperative_rt.graph.OperatorNode)
         # TODO: use megbrain C++ RTTI to replace type string
         if opr.id not in self.all_oprs_map:
             opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
             self.all_oprs_map[opr.id] = opnode
             for var in opr.inputs:
-                opnode.add_inp_var(self._get_var(var))
+                varnode = self._get_var(var)
+                opnode.add_inp_var(varnode)
+                varnode.users.append(opnode)
             for var in opr.outputs:
                 opnode.add_out_var(self._get_var(var))
             return opnode
@@ -503,7 +519,10 @@ class Network:
             return None

     def _get_var(self, x):
-        # auto convert to VarNode of Network
+        """
+        Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`.
+        """
+        assert isinstance(x, _imperative_rt.graph.VarNode)
         if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
             self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
         return self.all_vars_map[x.id]
......
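The staleness check in `_get_var` matters because `_compile()` rebuilds the underlying `ComputingGraph`, after which a raw node id can refer to a different C++ var than the one a cached wrapper was built around. A sketch of the id-keyed cache with validation, using stand-in types rather than the MegEngine ones:

```python
# Sketch of an id-keyed wrapper cache with a staleness check, as in
# Network._get_var. RawVar/Wrapper are stand-ins, not MegEngine types.
import itertools

class RawVar:
    _ids = itertools.count()
    def __init__(self):
        self.id = next(RawVar._ids)

class Wrapper:
    def __init__(self, raw):
        self.var = raw  # the raw node currently backing this wrapper

cache = {}

def get_var(raw):
    # Rebuild on a miss, or when the cached wrapper points at a raw node
    # other than the one being asked for (stale after a graph rebuild).
    if raw.id not in cache or cache[raw.id].var is not raw:
        cache[raw.id] = Wrapper(raw)
    return cache[raw.id]

r = RawVar()
w = get_var(r)
assert get_var(r) is w  # cache hit returns the same wrapper
r2 = RawVar()
r2.id = r.id  # simulate an id collision after the graph is rebuilt
assert get_var(r2) is not w  # stale entry detected and refreshed
```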
@@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
 class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
     def __init__(self, var=None, *, owner_opr=None, name=None):
         SymbolVar.__init__(self, var)
+        self.users = []  # List[OpNode]
         self.owner = owner_opr
         self.name = name
         self.id = id(self)
@@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode):
     def compile(self, graph):
         if (
             self._opr is None
+            or self._opr.graph != graph
             or self._opr.outputs[0].comp_node != self.device
             or self._opr.outputs[0].shape != self.shape
             or self._opr.outputs[0].dtype != self.dtype
@@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode):
         assert self.outputs[0].owner is self


-class ImmutableTensor(OpNode):
-    type = "ImmutableTensor"
+class ConstOpBase(OpNode):
+    type = "ConstOpBase"

     def __init__(self, data=None, name=None, device=None, graph=None):
+        assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
         super().__init__()
         self.name = name
         self.outputs = []
@@ -254,7 +257,7 @@ class ImmutableTensor(OpNode):
         return self._opr.outputs[0].dtype if self._opr else None

     def numpy(self):
-        return self._opr.outputs[0].value if self._opr else None
+        return self.outputs[0].numpy()

     def set_value(self, data, device=None):
         assert self.graph is not None
@@ -266,7 +269,7 @@ class ImmutableTensor(OpNode):
             data = data.astype(np.float32)
         elif data.dtype == np.int64:
             data = data.astype(np.int32)
-        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
+        varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
         if len(self.outputs) == 0:
             self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = varnode
@@ -291,6 +294,16 @@ class ImmutableTensor(OpNode):
         self.outputs[0].var.name = self.name


+class ImmutableTensor(ConstOpBase):
+    type = "ImmutableTensor"
+    rt_fun = rt.make_const
+
+
+class SharedDeviceTensor(ConstOpBase):
+    type = "SharedDeviceTensor"
+    rt_fun = rt.make_shared
+
+
 class ReadOnlyOpNode(OpNode):
     @classmethod
     def load(cls, opr):
......
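With the split above, `ImmutableTensor` and the new `SharedDeviceTensor` differ only in which runtime constructor `set_value` calls, selected via the `rt_fun` class attribute and invoked as `type(self).rt_fun(...)`. A minimal sketch of this dispatch pattern; `make_const` and `make_shared` below are plain stand-ins, not the MegEngine runtime bindings:

```python
# Dispatch through a class attribute, as ConstOpBase does with rt_fun.
# make_const/make_shared are stand-ins for the runtime constructors.

def make_const(data):
    return ("const", data)

def make_shared(data):
    return ("shared", data)

class ConstOpBase:
    rt_fun = None  # each concrete subclass supplies its constructor

    def __init__(self, data):
        # The base class is abstract, mirroring the assert added above.
        assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
        # Look rt_fun up on the class so the plain function is called
        # without instance binding, like `type(self).rt_fun(...)`.
        self.node = type(self).rt_fun(data)

class ImmutableTensor(ConstOpBase):
    rt_fun = staticmethod(make_const)

class SharedDeviceTensor(ConstOpBase):
    rt_fun = staticmethod(make_shared)

assert ImmutableTensor(1).node == ("const", 1)
assert SharedDeviceTensor(2).node == ("shared", 2)
```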
@@ -130,6 +130,52 @@ def test_replace_opr():
     np.testing.assert_equal(out["o"], [0, 0])


+def test_splice_network():
+    x = F.ones((2,))
+    y = F.ones((2,))
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun1(a, b):
+        return (a + b) * 2
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun2(a):
+        return a * 2 - 1
+
+    model = io.BytesIO()
+    fun1(x, y)
+    fun2(x)
+    fun1.dump(
+        model,
+        arg_names=["net1_i0", "net1_i1"],
+        output_names=["net1_o0"],
+        optimize_for_inference=False,
+    )
+    model.seek(0)
+    net1 = Net.load(model)
+    model.seek(0)
+    fun2.dump(
+        model,
+        arg_names=["net2_i0"],
+        output_names=["net2_o0"],
+        optimize_for_inference=False,
+    )
+    model.seek(0)
+    net2 = Net.load(model)
+    net1.add_output(*net2.output_vars)
+    var = net1.var_filter.name("net1_i0").as_unique()
+    repl_var = net2.var_filter.name("net2_o0").as_unique()
+    net1.replace_vars({var: repl_var})
+    assert "net1_i0" not in [var.name for var in net1.all_vars]
+    assert "net2_i0" in [var.name for var in net1.all_vars]
+    model.seek(0)
+    net1.dump(model, keep_var_name=2, optimize_for_inference=False)
+    model.seek(0)
+    net = Net.load(model)
+    assert "net1_i0" not in [var.name for var in net.all_vars]
+    assert "net2_i0" in [var.name for var in net.all_vars]
+
+
 def test_modify_params():
     a = Tensor([1, 2])
......
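As a worked check of what the splice produces (not an assertion in the test): after `replace_vars`, net1 computes `fun1(fun2(x), y) = ((x * 2 - 1) + y) * 2`, so with `x = y = ones((2,))` each output element is `((1 * 2 - 1) + 1) * 2 = 4`:

```python
# Hand-evaluation of the spliced graph with numpy; mirrors the test inputs.
import numpy as np

x = np.ones(2)
y = np.ones(2)
spliced = ((x * 2 - 1) + y) * 2  # fun1(fun2(x), y)
np.testing.assert_equal(spliced, [4.0, 4.0])
```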