提交 abe3c165 编写于 作者: M Megvii Engine Team

feat(mge): tensorrt runtime opr

GitOrigin-RevId: 2fdd00adcbbe272e79d53ba0ac879d2eedf471d0
上级 1b1ad56a
......@@ -5,4 +5,4 @@ dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary
dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text
sdk/c-opr-loaders/mc40/example/sinopec_nv12_extra.neu filter=lfs diff=lfs merge=lfs -text
*.caffemodel filter=lfs diff=lfs merge=lfs -text
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from typing import Sequence
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
def tensorrt_runtime_opr(inputs, *, data: bytes = None):
    r"""Run a serialized TensorRT engine on the given input tensors.

    :param inputs: sequence of input tensors fed to the TensorRT subgraph.
    :param data: serialized engine buffer; ``None`` denotes an empty model.
    :return: the sequence of output tensors produced by the runtime op, or
        ``None`` when ``data`` is ``None`` (empty model gives no result).
    """
    if data is None:
        # Empty model: nothing to execute.
        return None
    # Build the runtime op from the raw buffer and apply it to the inputs.
    return apply(builtin.TensorRTRuntime(data, len(data)), *inputs)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
import numpy as np
from ..functional.external import tensorrt_runtime_opr
from .module import Module
class TensorrtRuntimeSubgraph(Module):
    r"""Load a serialized TensorrtRuntime subgraph.

    See :func:`~.tensorrt_runtime_opr` for more details.
    """

    def __init__(self, data):
        super().__init__()
        # NOTE(review): the constructor stores the buffer as-is, while the
        # setter below converts it through np.frombuffer — presumably both
        # forms are accepted downstream; confirm this asymmetry is intended.
        self._data = data

    @property
    def data(self):
        # Serialized engine buffer backing this subgraph.
        return self._data

    @data.setter
    def data(self, val):
        # Reinterpret the raw bytes as a uint8 array (zero-copy view).
        self._data = np.frombuffer(val, dtype=np.uint8)

    def forward(self, *inputs):
        return tensorrt_runtime_opr(inputs, data=self._data)
......@@ -6,14 +6,20 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import os
import platform
import numpy as np
import pytest
import megengine as mge
import megengine.utils.comp_graph_tools as cgtools
from megengine import Tensor
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
from megengine.module import Module
from megengine.module.external import TensorrtRuntimeSubgraph
class MyModule(Module):
......@@ -44,3 +50,5 @@ def test_cambricon_module():
return pred
pred = inference([inp])
/**
* \file imperative/src/impl/ops/tensorrt_runtime.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_runtime_opr.h"
namespace mgb::imperative {
namespace { namespace tensorrt_runtime {
// Lower the imperative TensorRTRuntime OpDef to a graph-level
// TensorRTRuntimeOpr node, forwarding the serialized engine buffer.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const TensorRTRuntime&>(def);
    // Wrap the raw VarNodes as SymbolVars, as the opr factory expects.
    SymbolVarArray sinputs(inputs.begin(), inputs.end());
    // op.buf carries the serialized engine bytes; op.buf_size its length.
    return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
}
// Register only the var-node lowering; every other trait falls back to
// the default implementation.
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // tensorrt_runtime
} // namespace mgb::imperative
#endif // MGB_ENABLE_TENSOR_RT
......@@ -241,4 +241,11 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
// Op carrying a serialized TensorRT engine: `buf` is the raw engine bytes
// and `buf_size` its length, consumed by TensorRTRuntimeOpr::make.
// NOTE(review): `MgbSizeTAddr` looks like a typo for `MgbSizeTAttr`
// (cf. `MgbStringAttr` on the line above) — confirm against the attribute
// class definitions before relying on it.
def TensorRTRuntime: MgbHashableOp<"TensorRTRuntime"> {
  let extraArguments = (ins
    MgbStringAttr:$buf,
    MgbSizeTAddr:$buf_size
  );
}
#endif // MGB_OPS
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册