提交 3dbac4f4 编写于 作者: M Megvii Engine Team

feat(mge): add atlas_subgraph module

GitOrigin-RevId: 11530383c0a31f4648ed89d3070b2dab178ea5b2
上级 00ef6772
...@@ -34,6 +34,17 @@ def cambricon_subgraph( ...@@ -34,6 +34,17 @@ def cambricon_subgraph(
) )
@wrap_io_tensor
def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]:
"""Load a serialized Atlas subgraph (i.e. om model) and
execute the operations defined in the subgraph.
:param inputs: List of input tensors of the subgraph.
:param data: The serialized subgraph.
"""
return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data)
@wrap_io_tensor @wrap_io_tensor
def extern_opr_subgraph( def extern_opr_subgraph(
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
......
...@@ -8,7 +8,11 @@ ...@@ -8,7 +8,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np import numpy as np
from ..functional.external import cambricon_subgraph, extern_opr_subgraph from ..functional.external import (
atlas_subgraph,
cambricon_subgraph,
extern_opr_subgraph,
)
from .module import Module from .module import Module
...@@ -41,6 +45,29 @@ class CambriconSubgraph(Module): ...@@ -41,6 +45,29 @@ class CambriconSubgraph(Module):
return outputs return outputs
class AtlasSubgraph(Module):
r"""Load a serialized Atlas subgraph.
See :func:`~.atlas_subgraph` for more details.
"""
def __init__(self, data):
super(AtlasSubgraph, self).__init__()
self._data = data
@property
def data(self):
return self._data.tobytes()
@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)
def forward(self, inputs):
outputs = atlas_subgraph(inputs, self._data)
return outputs
class ExternOprSubgraph(Module): class ExternOprSubgraph(Module):
r"""Load a serialized extern opr subgraph. r"""Load a serialized extern opr subgraph.
""" """
......
...@@ -13,10 +13,10 @@ import numpy as np ...@@ -13,10 +13,10 @@ import numpy as np
import megengine as mge import megengine as mge
from megengine import tensor from megengine import tensor
from megengine.module import Module from megengine.module import Module
from megengine.module.external import CambriconSubgraph from megengine.module.external import AtlasSubgraph, CambriconSubgraph
class MyModule(Module): class CambriconModule(Module):
def __init__(self, data): def __init__(self, data):
super().__init__() super().__init__()
self.cambricon = CambriconSubgraph(data, "subnet0", True) self.cambricon = CambriconSubgraph(data, "subnet0", True)
...@@ -31,7 +31,7 @@ def test_cambricon_module(): ...@@ -31,7 +31,7 @@ def test_cambricon_module():
model = os.path.join(os.path.dirname(__file__), model) model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f: with open(model, "rb") as f:
data = f.read() data = f.read()
m = MyModule(data) m = CambriconModule(data)
inputs = [] inputs = []
inputs.append(tensor(dtype=np.float16, device="cambricon0")) inputs.append(tensor(dtype=np.float16, device="cambricon0"))
inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
...@@ -41,3 +41,30 @@ def test_cambricon_module(): ...@@ -41,3 +41,30 @@ def test_cambricon_module():
return pred return pred
pred = inference(inputs) pred = inference(inputs)
class AtlasModule(Module):
def __init__(self, data):
super().__init__()
self.atlas = AtlasSubgraph(data)
def forward(self, inputs):
out = self.atlas(inputs)
return out
def test_atlas_module():
model = "AtlasRuntimeOprTest.basic.om"
model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f:
data = f.read()
m = AtlasModule(data)
inputs = []
inputs.append(tensor(dtype=np.float32, device="atlas0"))
inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32))
def inference(inps):
pred = m(inps)
return pred
pred = inference(inputs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册