Unverified commit 55f0d840, authored by Abhinav Arora, committed by GitHub

Fix Cpplint Issues in fluid/inference/tensorrt/ (#10318)

* Fix CPPLint issues in fluid/inference/tensorrt/

* Fix compile errors
Parent 0bc44c18
@@ -65,7 +65,7 @@ class TensorRTEngine : public EngineBase {
   // Initialize the inference network, so that TensorRT layers can add to this
   // network.
   void InitNetwork() {
-    infer_builder_.reset(createInferBuilder(logger_));
+    infer_builder_.reset(createInferBuilder(&logger_));
     infer_network_.reset(infer_builder_->createNetwork());
   }
   // After finishing adding ops, freeze this network and create the execution
......
@@ -46,13 +46,13 @@ const int kDataTypeSize[] = {
 // The following two APIs are implemented in TensorRT's header file and cannot
 // be loaded from the dynamic library, so create our own implementations that
 // directly trigger the methods from the dynamic library.
-static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }

 // A logger for creating the TensorRT infer builder.
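The pointer signatures above are cpplint's runtime/references rule at work: a parameter that a function may store or mutate is passed as a pointer rather than a non-const reference, so the indirection is visible at the call site. A minimal, self-contained sketch of the rule, assuming hypothetical Config/Rename names that are not part of this patch:

#include <string>

struct Config {
  std::string name;
};

// cpplint (runtime/references) would flag the reference version:
//   void Rename(Config& cfg);   // non-const reference parameter
// The pointer version is what this patch switches to; the caller must
// write Rename(&cfg), which makes the possible mutation visible.
void Rename(Config* cfg) { cfg->name = "renamed"; }

int main() {
  Config cfg{"original"};
  Rename(&cfg);  // '&' at the call site, just like createInferBuilder(&logger_)
  return 0;
}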
@@ -80,7 +80,7 @@ class NaiveLogger : public nvinfer1::ILogger {
     return *x;
   }

-  virtual ~NaiveLogger() override {}
+  ~NaiveLogger() override {}
 };

 }  // namespace tensorrt
......
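The NaiveLogger destructor edit addresses cpplint's readability/inheritance check: override already implies the function is virtual, so writing both keywords is redundant. A tiny illustration with hypothetical Base/Derived types:

struct Base {
  virtual ~Base() = default;
};

struct Derived : public Base {
  // cpplint (readability/inheritance) flags the doubled form:
  //   virtual ~Derived() override {}
  // "override" alone already guarantees the destructor is virtual and
  // actually overrides a base-class declaration.
  ~Derived() override {}
};

int main() {
  Derived d;  // destroyed virtually through ~Derived() at end of scope
  return 0;
}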
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <cuda.h>
+#include <cuda_runtime_api.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include "NvInfer.h"
-#include "cuda.h"
-#include "cuda_runtime_api.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"

 namespace dy = paddle::platform::dynload;
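The include edits follow cpplint's build/include conventions: toolkit headers such as cuda.h come from outside the project, so they belong in angle brackets alongside the other external headers, while quoted includes are reserved for project files named by their full repository path. A sketch of the resulting layout for a .cc file; the exact grouping order is a style-configuration detail and is an assumption here, not something the patch states:

// External toolkit and library headers, in angle brackets.
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

// Third-party header shipped without a path prefix.
#include "NvInfer.h"

// Project headers, quoted, with the full path from the repository root.
#include "paddle/fluid/platform/dynload/tensorrt.h"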
@@ -43,7 +43,7 @@ class Logger : public nvinfer1::ILogger {

 class ScopedWeights {
  public:
-  ScopedWeights(float value) : value_(value) {
+  explicit ScopedWeights(float value) : value_(value) {
     w.type = nvinfer1::DataType::kFLOAT;
     w.values = &value_;
     w.count = 1;
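Adding explicit to the one-argument constructor satisfies cpplint's runtime/explicit check and stops the compiler from silently converting a bare float into a ScopedWeights. A small sketch of what the keyword rules out; the Weights/Consume names are hypothetical stand-ins:

struct Weights {
  explicit Weights(float value) : value_(value) {}
  float value_;
};

void Consume(const Weights& w) { (void)w; }

int main() {
  Consume(Weights(2.0f));  // fine: the conversion is written out
  // Consume(2.0f);        // rejected thanks to "explicit"; without the
  //                       // keyword this would compile as an implicit
  //                       // float -> Weights conversion.
  return 0;
}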
@@ -58,13 +58,13 @@ class ScopedWeights {
 // The following two APIs are implemented in TensorRT's header file and cannot
 // be loaded from the dynamic library, so create our own implementations that
 // directly trigger the methods from the dynamic library.
-nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IBuilder*>(
-      dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferBuilder_INTERNAL(logger, NV_TENSORRT_VERSION));
 }
-nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) {
   return static_cast<nvinfer1::IRuntime*>(
-      dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+      dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION));
 }

 const char* kInputTensor = "input";
@@ -74,7 +74,7 @@ const char* kOutputTensor = "output";

 nvinfer1::IHostMemory* CreateNetwork() {
   Logger logger;
   // Create the engine.
-  nvinfer1::IBuilder* builder = createInferBuilder(logger);
+  nvinfer1::IBuilder* builder = createInferBuilder(&logger);
   ScopedWeights weights(2.);
   ScopedWeights bias(3.);
@@ -103,9 +103,9 @@ nvinfer1::IHostMemory* CreateNetwork() {
   return model;
 }

-void Execute(nvinfer1::IExecutionContext& context, const float* input,
+void Execute(nvinfer1::IExecutionContext* context, const float* input,
              float* output) {
-  const nvinfer1::ICudaEngine& engine = context.getEngine();
+  const nvinfer1::ICudaEngine& engine = context->getEngine();
   // Two bindings: input and output.
   ASSERT_EQ(engine.getNbBindings(), 2);
   const int input_index = engine.getBindingIndex(kInputTensor);
@@ -119,7 +119,7 @@ void Execute(nvinfer1::IExecutionContext& context, const float* input,
   // Copy the input to the GPU, execute the network, and copy the output back.
   ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                                cudaMemcpyHostToDevice, stream));
-  context.enqueue(1, buffers, stream, nullptr);
+  context->enqueue(1, buffers, stream, nullptr);
   ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                                cudaMemcpyDeviceToHost, stream));
   cudaStreamSynchronize(stream);
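For context, the hunk above is the standard single-stream CUDA round trip: queue the host-to-device copy, the kernel work (here the TensorRT enqueue), and the device-to-host copy on one stream, then synchronize before the host reads the result. A stripped-down sketch of that discipline using only the CUDA runtime API, with no TensorRT and made-up buffer names:

#include <cuda_runtime_api.h>

// Push one float through the device and back on a single stream.
// All queued operations run in stream order, so the final
// synchronize is the only barrier the host needs.
float RoundTrip(float input) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  float* device_buf = nullptr;
  cudaMalloc(&device_buf, sizeof(float));
  float output = 0.f;
  cudaMemcpyAsync(device_buf, &input, sizeof(float),
                  cudaMemcpyHostToDevice, stream);
  // In the test, context->enqueue(1, buffers, stream, nullptr) sits here.
  cudaMemcpyAsync(&output, device_buf, sizeof(float),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  cudaFree(device_buf);
  cudaStreamDestroy(stream);
  return output;
}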
@@ -136,7 +136,7 @@ TEST(TensorrtTest, BasicFunction) {
   // Use the model to create an engine and an execution context.
   Logger logger;
-  nvinfer1::IRuntime* runtime = createInferRuntime(logger);
+  nvinfer1::IRuntime* runtime = createInferRuntime(&logger);
   nvinfer1::ICudaEngine* engine =
       runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
   model->destroy();
@@ -145,7 +145,7 @@ TEST(TensorrtTest, BasicFunction) {
   // Execute the network.
   float input = 1234;
   float output;
-  Execute(*context, &input, &output);
+  Execute(context, &input, &output);
   EXPECT_EQ(output, input * 2 + 3);

   // Destroy the engine.
......
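The final hunk is the call-site payoff of the new Execute signature: the test now passes the context pointer directly, Execute(context, ...), instead of dereferencing it with Execute(*context, ...), so a reader can tell from the call site that the context may be used mutably. The expectation itself checks the one-neuron network assembled earlier from weight 2 and bias 3: for input = 1234, the network computes 2 * 1234 + 3 = 2471, which is the value EXPECT_EQ requires in output.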