未验证 提交 55f0d840 编写于 作者: A Abhinav Arora 提交者: GitHub

Fix Cpplint Issues in fluid/inference/tensorrt/ (#10318)

* Fix CPPLint issues in fluid/inference/tensorrt/

* Fix compile errors
上级 0bc44c18
...@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase { ...@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
// Initialize the inference network, so that TensorRT layers can add to this // Initialize the inference network, so that TensorRT layers can add to this
// network. // network.
void InitNetwork() { void InitNetwork() {
infer_builder_.reset(createInferBuilder(logger_)); infer_builder_.reset(createInferBuilder(&logger_));
infer_network_.reset(infer_builder_->createNetwork()); infer_network_.reset(infer_builder_->createNetwork());
} }
// After finishing adding ops, freeze this network and creates the executation // After finishing adding ops, freeze this network and creates the executation
......
...@@ -46,13 +46,13 @@ const int kDataTypeSize[] = { ...@@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
// The following two API are implemented in TensorRT's header file, cannot load // The following two API are implemented in TensorRT's header file, cannot load
// from the dynamic library. So create our own implementation and directly // from the dynamic library. So create our own implementation and directly
// trigger the method from the dynamic library. // trigger the method from the dynamic library.
static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) { static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IBuilder*>( return static_cast<nvinfer1::IBuilder*>(
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION)); dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
} }
static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) { static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IRuntime*>( return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
} }
// A logger for create TensorRT infer builder. // A logger for create TensorRT infer builder.
...@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger { ...@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
return *x; return *x;
} }
virtual ~NaiveLogger() override {} ~NaiveLogger() override {}
}; };
} // namespace tensorrt } // namespace tensorrt
......
...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "NvInfer.h" #include "NvInfer.h"
#include "cuda.h"
#include "cuda_runtime_api.h"
#include "paddle/fluid/platform/dynload/tensorrt.h" #include "paddle/fluid/platform/dynload/tensorrt.h"
namespace dy = paddle::platform::dynload; namespace dy = paddle::platform::dynload;
...@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger { ...@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {
class ScopedWeights { class ScopedWeights {
public: public:
ScopedWeights(float value) : value_(value) { explicit ScopedWeights(float value) : value_(value) {
w.type = nvinfer1::DataType::kFLOAT; w.type = nvinfer1::DataType::kFLOAT;
w.values = &value_; w.values = &value_;
w.count = 1; w.count = 1;
...@@ -58,13 +58,13 @@ class ScopedWeights { ...@@ -58,13 +58,13 @@ class ScopedWeights {
// The following two API are implemented in TensorRT's header file, cannot load // The following two API are implemented in TensorRT's header file, cannot load
// from the dynamic library. So create our own implementation and directly // from the dynamic library. So create our own implementation and directly
// trigger the method from the dynamic library. // trigger the method from the dynamic library.
nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) { nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IBuilder*>( return static_cast<nvinfer1::IBuilder*>(
dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION)); dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
} }
nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) { nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
return static_cast<nvinfer1::IRuntime*>( return static_cast<nvinfer1::IRuntime*>(
dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION)); dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
} }
const char* kInputTensor = "input"; const char* kInputTensor = "input";
...@@ -74,7 +74,7 @@ const char* kOutputTensor = "output"; ...@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";
nvinfer1::IHostMemory* CreateNetwork() { nvinfer1::IHostMemory* CreateNetwork() {
Logger logger; Logger logger;
// Create the engine. // Create the engine.
nvinfer1::IBuilder* builder = createInferBuilder(logger); nvinfer1::IBuilder* builder = createInferBuilder(&logger);
ScopedWeights weights(2.); ScopedWeights weights(2.);
ScopedWeights bias(3.); ScopedWeights bias(3.);
...@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() { ...@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
return model; return model;
} }
void Execute(nvinfer1::IExecutionContext& context, const float* input, void Execute(nvinfer1::IExecutionContext* context, const float* input,
float* output) { float* output) {
const nvinfer1::ICudaEngine& engine = context.getEngine(); const nvinfer1::ICudaEngine& engine = context->getEngine();
// Two binds, input and output // Two binds, input and output
ASSERT_EQ(engine.getNbBindings(), 2); ASSERT_EQ(engine.getNbBindings(), 2);
const int input_index = engine.getBindingIndex(kInputTensor); const int input_index = engine.getBindingIndex(kInputTensor);
...@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input, ...@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
// Copy the input to the GPU, execute the network, and copy the output back. // Copy the input to the GPU, execute the network, and copy the output back.
ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float), ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
cudaMemcpyHostToDevice, stream)); cudaMemcpyHostToDevice, stream));
context.enqueue(1, buffers, stream, nullptr); context->enqueue(1, buffers, stream, nullptr);
ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float), ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
cudaMemcpyDeviceToHost, stream)); cudaMemcpyDeviceToHost, stream));
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
...@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) { ...@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
// Use the model to create an engine and an execution context. // Use the model to create an engine and an execution context.
Logger logger; Logger logger;
nvinfer1::IRuntime* runtime = createInferRuntime(logger); nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
nvinfer1::ICudaEngine* engine = nvinfer1::ICudaEngine* engine =
runtime->deserializeCudaEngine(model->data(), model->size(), nullptr); runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
model->destroy(); model->destroy();
...@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) { ...@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
// Execute the network. // Execute the network.
float input = 1234; float input = 1234;
float output; float output;
Execute(*context, &input, &output); Execute(context, &input, &output);
EXPECT_EQ(output, input * 2 + 3); EXPECT_EQ(output, input * 2 + 3);
// Destroy the engine. // Destroy the engine.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册