Commit f5801d7e authored by pranavm

Adds missing ONNX parser Python bindings

Signed-off-by: pranavm <pranavm@nvidia.com>
Parent 13f7441a
parsers/pyOnnxDoc.h
@@ -14,63 +14,68 @@
* limitations under the License.
*/
// Docstrings for the pyOnnx parser bindings.
#pragma once
namespace tensorrt
{
namespace OnnxParserDoc
{
constexpr const char* descr = R"trtdoc(
This class is used for parsing ONNX models into a TensorRT network definition.
:ivar num_errors: :class:`int` The number of errors that occurred during prior calls to :func:`parse`
)trtdoc";
constexpr const char* init = R"trtdoc(
:arg network: The network definition to which the parser will write.
:arg logger: The logger to use.
)trtdoc";
constexpr const char* parse = R"trtdoc(
Parse a serialized ONNX model into the TensorRT network.
:arg model: The serialized ONNX model.
:arg path: The path to the model file. Only required if the model has externally stored weights.
:returns: true if the model was parsed successfully
)trtdoc";
constexpr const char* parseFromFile = R"trtdoc(
Parse an ONNX model from file into a TensorRT network.
:arg model: The path to an ONNX model.
:returns: true if the model was parsed successfully
)trtdoc";
constexpr const char* supports_model = R"trtdoc(
Check whether TensorRT supports a particular ONNX model.
:arg model: The serialized ONNX model.
:arg path: The path to the model file. Only required if the model has externally stored weights.
:returns: Tuple[bool, List[Tuple[NodeIndices, bool]]]
The first element of the tuple indicates whether the model is supported.
The second indicates subgraphs (by node index) in the model and whether they are supported.
)trtdoc";
constexpr const char* supports_operator = R"trtdoc(
Returns whether the specified operator may be supported by the parser.
Note that a result of true does not guarantee that the operator will be supported in all cases (i.e., this function may return false-positives).
:arg op_name: The name of the ONNX operator to check for support
)trtdoc";
constexpr const char* get_error = R"trtdoc(
Get an error that occurred during prior calls to :func:`parse`
:arg index: Index of the error
)trtdoc";
constexpr const char* clear_errors = R"trtdoc(
Clear errors from prior calls to :func:`parse`
)trtdoc";
constexpr const char* get_refit_map = R"trtdoc(
Get description of all weights that could be refit.
:returns: The names of ONNX weights that can be refitted, along with their corresponding TensorRT layer and weight role.
)trtdoc";
} /* OnnxParserDoc */
namespace ErrorCodeDoc
...
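For orientation, a minimal sketch of how the parser API documented above is typically driven from Python, assuming the bindings ship in the usual tensorrt package; the model path and logger severity are placeholders, not part of this commit.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()
parser = trt.OnnxParser(network, logger)

if not parser.parse_from_file("model.onnx"):
    for i in range(parser.num_errors):
        # ParserError prints as "In node N (func): CODE: desc"
        print(parser.get_error(i))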
parsers/pyOnnx.cpp
@@ -15,106 +15,104 @@
*/
// Implementation of PyBind11 Binding Code for OnnxParser
#include "NvOnnxParser.h"
#include "parsers/pyOnnxDoc.h"
#include "ForwardDeclarations.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "parsers/pyOnnxDoc.h"
#include <pybind11/stl_bind.h>
using namespace nvonnxparser;
namespace tensorrt
{
// Long lambda functions should go here rather than being inlined into the bindings (one-liners are OK).
namespace lambdas
{
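// Maps ErrorCode values to short, readable names; reused below for the Python enum's __str__/__repr__.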
static const auto error_code_str = [] (ErrorCode self) {
switch (self) {
case ErrorCode::kSUCCESS:
return "SUCCESS";
case ErrorCode::kINTERNAL_ERROR:
return "INTERNAL_ERROR";
case ErrorCode::kMEM_ALLOC_FAILED:
return "MEM_ALLOC_FAILED";
case ErrorCode::kMODEL_DESERIALIZE_FAILED:
return "MODEL_DESERIALIZE_FAILED";
case ErrorCode::kINVALID_VALUE:
return "INVALID_VALUE";
case ErrorCode::kINVALID_GRAPH:
return "INVALID_GRAPH";
case ErrorCode::kINVALID_NODE:
return "INVALID_NODE";
case ErrorCode::kUNSUPPORTED_GRAPH:
return "UNSUPPORTED_GRAPH";
case ErrorCode::kUNSUPPORTED_NODE:
return "UNSUPPORTED_NODE";
}
return "UNKNOWN";
};
static const auto parser_error_str = [](IParserError& self) {
return "In node " + std::to_string(self.node()) + " (" + self.func() + "): " + error_code_str(self.code()) + ": " + self.desc();
};
// For ONNX Parser
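// Accepts any object supporting the Python buffer protocol (bytes, bytearray, numpy arrays);
// buffer_info supplies the raw pointer and byte count forwarded to the C++ parser.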
static const auto parse = [](IParser& self, const py::buffer& model, const char* path = nullptr) {
py::buffer_info info = model.request();
return self.parse(info.ptr, info.size * info.itemsize, path);
};
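// Two-call pattern: getRefitMap() with null arguments returns the entry count,
// then the second call fills the preallocated parallel arrays.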
static const auto getRefitMap = [] (IParser& self)
{
int size = self.getRefitMap(nullptr, nullptr, nullptr);
std::vector<const char*> weightNames(size);
std::vector<const char*> layerNames(size);
std::vector<nvinfer1::WeightsRole> roles(size);
self.getRefitMap(weightNames.data(), layerNames.data(), roles.data());
return std::tuple<std::vector<const char*>, std::vector<const char*>, std::vector<nvinfer1::WeightsRole>>{weightNames, layerNames, roles};
};
static const auto parseFromFile
= [](IParser& self, const std::string& model) { return self.parseFromFile(model.c_str(), 0); };
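// Reports overall model support plus a per-subgraph breakdown; SubGraphCollection_t
// is a vector of (node-index list, supported) pairs.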
static const auto supportsModel = [](IParser& self, const py::buffer& model, const char* path = nullptr) {
py::buffer_info info = model.request();
SubGraphCollection_t subgraphs;
const bool supported = self.supportsModel(info.ptr, info.size * info.itemsize, subgraphs, path);
return std::make_pair(supported, subgraphs);
};
} // namespace lambdas
void bindOnnx(py::module& m)
{
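// bind_vector publishes these C++ vector types to Python as list-like classes,
// letting supports_model return subgraph node indices without manual conversion.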
py::bind_vector<std::vector<size_t>>(m, "NodeIndices");
py::bind_vector<SubGraphCollection_t>(m, "SubGraphCollection");
py::class_<IParser, std::unique_ptr<IParser, py::nodelete>>(m, "OnnxParser", OnnxParserDoc::descr)
.def(py::init(&nvonnxparser::createParser), "network"_a, "logger"_a, OnnxParserDoc::init)
.def("parse", lambdas::parse, "model"_a, OnnxParserDoc::parse)
.def("parse_from_file", lambdas::parseFromFile, "model"_a, OnnxParserDoc::parseFromFile)
.def("parse", lambdas::parse, "model"_a, "path"_a = nullptr, OnnxParserDoc::parse,
py::call_guard<py::gil_scoped_release>{})
.def("parse_from_file", lambdas::parseFromFile, "model"_a, OnnxParserDoc::parseFromFile,
py::call_guard<py::gil_scoped_release>{})
.def("supports_operator", &IParser::supportsOperator, "op_name"_a, OnnxParserDoc::supports_operator)
.def("supports_model", lambdas::supportsModel, "model"_a, "path"_a = nullptr,
OnnxParserDoc::supports_model)
.def_property_readonly("num_errors", &IParser::getNbErrors)
.def("get_error", &IParser::getError, "index"_a, OnnxParserDoc::get_error)
.def("clear_errors", &IParser::clearErrors, OnnxParserDoc::clear_errors)
.def("get_refit_map", lambdas::getRefitMap, OnnxParserDoc::get_refit_map)
.def("__del__", &IParser::destroy)
;
.def("__del__", &IParser::destroy);
py::enum_<ErrorCode>(m, "ErrorCode", ErrorCodeDoc::descr)
.value("SUCCESS", ErrorCode::kSUCCESS)
.value("INTERNAL_ERROR", ErrorCode::kINTERNAL_ERROR)
.value("MEM_ALLOC_FAILED", ErrorCode::kMEM_ALLOC_FAILED)
.value("MODEL_DESERIALIZE_FAILED", ErrorCode::kMODEL_DESERIALIZE_FAILED)
.value("INVALID_VALUE", ErrorCode::kINVALID_VALUE)
.value("INVALID_GRAPH", ErrorCode::kINVALID_GRAPH)
.value("INVALID_NODE", ErrorCode::kINVALID_NODE)
.value("UNSUPPORTED_GRAPH", ErrorCode::kUNSUPPORTED_GRAPH)
.value("UNSUPPORTED_NODE", ErrorCode::kUNSUPPORTED_NODE)
.def("__str__", lambdas::error_code_str)
.def("__repr__", lambdas::error_code_str);
py::class_<IParserError, std::unique_ptr<IParserError, py::nodelete>>(m, "ParserError")
.def("code", &IParserError::code, ParserErrorDoc::code)
.def("desc", &IParserError::desc, ParserErrorDoc::desc)
.def("file", &IParserError::file, ParserErrorDoc::file)
.def("line", &IParserError::line, ParserErrorDoc::line)
.def("func", &IParserError::func, ParserErrorDoc::func)
.def("node", &IParserError::node, ParserErrorDoc::node)
.def("__str__", lambdas::parser_error_str)
.def("__repr__", lambdas::parser_error_str);
// Free functions.
m.def("get_nv_onnx_parser_version", &getNvOnnxParserVersion, get_nv_onnx_parser_version);
}
} // namespace tensorrt
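To close, a hedged sketch of the two bindings this commit adds, supports_model and get_refit_map, continuing the setup from the earlier sketch; the model path remains a placeholder.

# Continuing the sketch above (parser and "model.onnx" as before).
with open("model.onnx", "rb") as f:
    data = f.read()

supported, subgraphs = parser.supports_model(data)
for node_indices, ok in subgraphs:
    # Each entry pairs a NodeIndices list with a per-subgraph support flag.
    print(list(node_indices), ok)

# Three parallel sequences: ONNX weight names, matching TensorRT layer names,
# and their corresponding weight roles.
weight_names, layer_names, roles = parser.get_refit_map()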