diff --git a/python_module/megengine/functional/external.py b/python_module/megengine/functional/external.py index badede8fb0a06324e752ad06d72a496dbff703fe..6c93d217ff247f3f5607d7a2a55cc63333d2381d 100644 --- a/python_module/megengine/functional/external.py +++ b/python_module/megengine/functional/external.py @@ -34,6 +34,17 @@ def cambricon_subgraph( ) +@wrap_io_tensor +def atlas_subgraph(inputs: List[Tensor], data: bytes) -> List[Tensor]: + """Load a serialized Atlas subgraph (i.e. om model) and + execute the operations defined in the subgraph. + + :param inputs: List of input tensors of the subgraph. + :param data: The serialized subgraph. + """ + return mgb.opr.atlas_runtime(tuple(map(lambda x: x._symvar, inputs)), data) + + @wrap_io_tensor def extern_opr_subgraph( inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, diff --git a/python_module/megengine/module/external.py b/python_module/megengine/module/external.py index a5da2a140c7e281d612c633eee395d234c223d9e..962754e8d482bc9a1ac8340fc2ff7cfd2c48bf3b 100644 --- a/python_module/megengine/module/external.py +++ b/python_module/megengine/module/external.py @@ -8,7 +8,11 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np -from ..functional.external import cambricon_subgraph, extern_opr_subgraph +from ..functional.external import ( + atlas_subgraph, + cambricon_subgraph, + extern_opr_subgraph, +) from .module import Module @@ -41,6 +45,29 @@ class CambriconSubgraph(Module): return outputs +class AtlasSubgraph(Module): + r"""Load a serialized Atlas subgraph. + + See :func:`~.atlas_subgraph` for more details. + """ + + def __init__(self, data): + super(AtlasSubgraph, self).__init__() + self._data = data + + @property + def data(self): + return self._data.tobytes() + + @data.setter + def data(self, val): + self._data = np.frombuffer(val, dtype=np.uint8) + + def forward(self, inputs): + outputs = atlas_subgraph(inputs, self._data) + return outputs + + class ExternOprSubgraph(Module): r"""Load a serialized extern opr subgraph. """ diff --git a/python_module/test/unit/module/AtlasRuntimeOprTest.basic.om b/python_module/test/unit/module/AtlasRuntimeOprTest.basic.om new file mode 100644 index 0000000000000000000000000000000000000000..942fe2edff071b681433616ea0de43b30aa9cb29 Binary files /dev/null and b/python_module/test/unit/module/AtlasRuntimeOprTest.basic.om differ diff --git a/python_module/test/unit/module/test_external.py b/python_module/test/unit/module/test_external.py index 3a4e6d7f644b726f39b3aa0fdff72271f933f8b8..44f5cf21fbaa1cd544b1e06ed6c43ea24c7adf85 100644 --- a/python_module/test/unit/module/test_external.py +++ b/python_module/test/unit/module/test_external.py @@ -13,10 +13,10 @@ import numpy as np import megengine as mge from megengine import tensor from megengine.module import Module -from megengine.module.external import CambriconSubgraph +from megengine.module.external import AtlasSubgraph, CambriconSubgraph -class MyModule(Module): +class CambriconModule(Module): def __init__(self, data): super().__init__() self.cambricon = CambriconSubgraph(data, "subnet0", True) @@ -31,7 +31,7 @@ def test_cambricon_module(): model = os.path.join(os.path.dirname(__file__), model) with open(model, "rb") as f: data = f.read() - m = MyModule(data) + m = CambriconModule(data) inputs = [] inputs.append(tensor(dtype=np.float16, device="cambricon0")) inputs[0].set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16)) @@ -41,3 +41,30 @@ def test_cambricon_module(): return pred pred = inference(inputs) + + +class AtlasModule(Module): + def __init__(self, data): + super().__init__() + self.atlas = AtlasSubgraph(data) + + def forward(self, inputs): + out = self.atlas(inputs) + return out + + +def test_atlas_module(): + model = "AtlasRuntimeOprTest.basic.om" + model = os.path.join(os.path.dirname(__file__), model) + with open(model, "rb") as f: + data = f.read() + m = AtlasModule(data) + inputs = [] + inputs.append(tensor(dtype=np.float32, device="atlas0")) + inputs[0].set_value(np.random.normal(size=(4, 3, 16, 16)).astype(np.float32)) + + def inference(inps): + pred = m(inps) + return pred + + pred = inference(inputs)