提交 3dbac4f4 编写于 作者: M Megvii Engine Team

feat(mge): add atlas_subgraph module

GitOrigin-RevId: 11530383c0a31f4648ed89d3070b2dab178ea5b2
上级 00ef6772
......@@ -34,6 +34,17 @@ def cambricon_subgraph(
)
@wrap_io_tensor
def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]:
"""Load a serialized Atlas subgraph (i.e. om model) and
execute the operations defined in the subgraph.
:param inputs: List of input tensors of the subgraph.
:param data: The serialized subgraph.
"""
return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data)
@wrap_io_tensor
def extern_opr_subgraph(
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes,
......
......@@ -8,7 +8,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from ..functional.external import cambricon_subgraph, extern_opr_subgraph
from ..functional.external import (
atlas_subgraph,
cambricon_subgraph,
extern_opr_subgraph,
)
from .module import Module
......@@ -41,6 +45,29 @@ class CambriconSubgraph(Module):
return outputs
class AtlasSubgraph(Module):
r"""Load a serialized Atlas subgraph.
See :func:`~.atlas_subgraph` for more details.
"""
def __init__(self, data):
super(AtlasSubgraph, self).__init__()
self._data = data
@property
def data(self):
return self._data.tobytes()
@data.setter
def data(self, val):
self._data = np.frombuffer(val, dtype=np.uint8)
def forward(self, inputs):
outputs = atlas_subgraph(inputs, self._data)
return outputs
class ExternOprSubgraph(Module):
r"""Load a serialized extern opr subgraph.
"""
......
......@@ -13,10 +13,10 @@ import numpy as np
import megengine as mge
from megengine import tensor
from megengine.module import Module
from megengine.module.external import CambriconSubgraph
from megengine.module.external import AtlasSubgraph, CambriconSubgraph
class MyModule(Module):
class CambriconModule(Module):
def __init__(self, data):
super().__init__()
self.cambricon = CambriconSubgraph(data, "subnet0", True)
......@@ -31,7 +31,7 @@ def test_cambricon_module():
model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f:
data = f.read()
m = MyModule(data)
m = CambriconModule(data)
inputs = []
inputs.append(tensor(dtype=np.float16, device="cambricon0"))
inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
......@@ -41,3 +41,30 @@ def test_cambricon_module():
return pred
pred = inference(inputs)
class AtlasModule(Module):
def __init__(self, data):
super().__init__()
self.atlas = AtlasSubgraph(data)
def forward(self, inputs):
out = self.atlas(inputs)
return out
def test_atlas_module():
model = "AtlasRuntimeOprTest.basic.om"
model = os.path.join(os.path.dirname(__file__), model)
with open(model, "rb") as f:
data = f.read()
m = AtlasModule(data)
inputs = []
inputs.append(tensor(dtype=np.float32, device="atlas0"))
inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32))
def inference(inps):
pred = m(inps)
return pred
pred = inference(inputs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册