external.py 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
import numpy as np

12 13 14 15 16
from ..functional.external import (
    atlas_runtime_opr,
    cambricon_runtime_opr,
    tensorrt_runtime_opr,
)
17 18 19 20 21 22 23 24 25
from .module import Module


class TensorrtRuntimeSubgraph(Module):
    r"""Load a serialized TensorrtRuntime subgraph.

    See :func:`~.tensorrt_runtime_opr` for more details.

    Args:
        data: the serialized subgraph, passed as the ``data`` keyword of
            :func:`~.tensorrt_runtime_opr` in :meth:`forward`.
    """

    def __init__(self, data, **kwargs):
        super(TensorrtRuntimeSubgraph, self).__init__(**kwargs)
        # Stored exactly as given. Only the ``data`` setter below
        # re-wraps values, so construction and later assignment can leave
        # ``_data`` with different types.
        # NOTE(review): confirm which form tensorrt_runtime_opr expects.
        self._data = data

    @property
    def data(self):
        """The serialized subgraph currently held by this module."""
        return self._data

    @data.setter
    def data(self, val):
        # Expose the raw buffer as a 1-D uint8 numpy array (no copy is made
        # by np.frombuffer; the result shares memory with ``val``).
        self._data = np.frombuffer(val, dtype=np.uint8)

    def forward(self, *inputs):
        """Run the subgraph on ``inputs`` (collected as a tuple) and return
        whatever :func:`~.tensorrt_runtime_opr` returns."""
        return tensorrt_runtime_opr(inputs, data=self._data)


class CambriconRuntimeSubgraph(Module):
    r"""Load a serialized CambriconRuntime subgraph.

    See :func:`~.cambricon_runtime_opr` for more details.

    Args:
        data: the serialized subgraph, stored unchanged by the constructor.
        symbol: forwarded positionally to :func:`~.cambricon_runtime_opr`.
        tensor_dim_mutable: forwarded positionally to
            :func:`~.cambricon_runtime_opr`.
    """

    def __init__(self, data, symbol, tensor_dim_mutable, **kwargs):
        super().__init__(**kwargs)
        # Raw value kept as-is; only assignments through the ``data``
        # property are converted with np.frombuffer.
        self._data = data
        self.symbol = symbol
        self.tensor_dim_mutable = tensor_dim_mutable

    @property
    def data(self):
        """The serialized subgraph currently held by this module."""
        return self._data

    @data.setter
    def data(self, value):
        # Reinterpret the incoming buffer as a flat uint8 numpy array.
        self._data = np.frombuffer(value, np.uint8)

    def forward(self, *inputs):
        """Execute the subgraph on the tuple of ``inputs`` and return the
        result of :func:`~.cambricon_runtime_opr`."""
        return cambricon_runtime_opr(
            inputs, self._data, self.symbol, self.tensor_dim_mutable
        )


class AtlasRuntimeSubgraph(Module):
    r"""Load a serialized AtlasRuntime subgraph.

    See :func:`~.atlas_runtime_opr` for more details.

    Args:
        data: the serialized subgraph, passed as the ``data`` keyword of
            :func:`~.atlas_runtime_opr` in :meth:`forward`.
    """

    def __init__(self, data, **kwargs):
        super().__init__(**kwargs)
        # Raw value kept as-is; only assignments through the ``data``
        # property are converted with np.frombuffer.
        self._data = data

    @property
    def data(self):
        """The serialized subgraph currently held by this module."""
        return self._data

    @data.setter
    def data(self, value):
        # Reinterpret the incoming buffer as a flat uint8 numpy array.
        self._data = np.frombuffer(value, np.uint8)

    def forward(self, *inputs):
        """Execute the subgraph on the tuple of ``inputs`` and return the
        result of :func:`~.atlas_runtime_opr`."""
        serialized = self._data
        return atlas_runtime_opr(inputs, data=serialized)