From 034ba1c291a5339ccf5c5dd109c0d59e0bb4511a Mon Sep 17 00:00:00 2001 From: nhzlx Date: Thu, 14 Feb 2019 07:11:47 +0000 Subject: [PATCH] add static model load for trt 1. bind trt input and output to fluid tensors --- .../ir_passes/tensorrt_subgraph_pass.cc | 175 +++++++++++------- paddle/fluid/inference/engine.h | 5 - .../inference/tensorrt/convert/conv2d_op.cc | 19 +- .../inference/tensorrt/convert/ut_helper.h | 69 ++++--- paddle/fluid/inference/tensorrt/engine.cc | 117 +----------- paddle/fluid/inference/tensorrt/engine.h | 41 +--- .../fluid/inference/tensorrt/test_engine.cc | 132 ++++++++----- .../operators/tensorrt/tensorrt_engine_op.h | 99 ++++++---- 8 files changed, 313 insertions(+), 344 deletions(-) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 69a9caec030..d91f62a12f9 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -33,6 +33,14 @@ using framework::ir::Node; std::vector ExtractParameters( const std::unordered_set &nodes); +void RenameAndGetOutputs( + const std::vector &subgraph_nodes, + framework::BlockDesc *block_desc, + const std::set &input_names_with_id, + std::set *output_names_with_id, + std::set *output_names, + std::unordered_map *output_name_map); + std::unique_ptr analysis::TensorRtSubgraphPass::ApplyImpl( std::unique_ptr graph) const { @@ -120,9 +128,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, input_names.insert(x->Name()); input_names_with_id.insert(x->Name() + std::to_string(x->id())); } - op_desc->SetInput( - "Xs", std::vector(input_names.begin(), input_names.end())); - std::set output_names; std::set output_names_with_id; for (auto *x : node->outputs) { @@ -130,11 +135,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, output_names_with_id.insert(x->Name() + std::to_string(x->id())); } - op_desc->SetOutput( - "Ys", std::vector(output_names.begin(), output_names.end())); - op_desc->SetType("tensorrt_engine"); - std::unordered_map output_name_map; + auto &subgraph_nodes = *Agent(node).subgraph(); // The following procedure is used to rename all the intermediate // variables and the output variables of the subgraph. @@ -148,61 +150,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, // input of a OP, but also the output of a Op, there will be problems. // So we have to rename the variable in the subgraph to make sure // it is either an OP's input or an OP's output. 
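The renaming rule applied below is small: every argument of an op inside the subgraph gets its ir::Node id appended to its name, except arguments that are external inputs of the whole subgraph, which keep their original names so the engine op can still find them in the scope. A minimal standalone sketch of that rule, with hypothetical names (the real logic lives in RenameAndGetOutputs, added further down in this file):

    #include <set>
    #include <string>

    // Returns the name a variable should carry inside the TRT subgraph.
    // `node_id` is the variable's ir::Node id; `input_names_with_id` holds the
    // "<name><id>" keys of the subgraph's external inputs.
    std::string SubgraphUniqueName(const std::string &arg, int node_id,
                                   const std::set<std::string> &input_names_with_id) {
      std::string with_id = arg + std::to_string(node_id);
      // External inputs keep their original name; every other variable is
      // disambiguated with its node id, so the same name can no longer appear as
      // both one op's input and another op's output.
      return input_names_with_id.count(with_id) ? arg : with_id;
    }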
- - auto &subgraph_nodes = *Agent(node).subgraph(); - for (size_t index = 0; index < block_desc.OpSize(); ++index) { - framework::proto::OpDesc *op = block_desc.Op(index)->Proto(); - auto correspond_node = subgraph_nodes[index]; - PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type()); - - std::unordered_map var2id; - for (auto *in_var : correspond_node->inputs) { - var2id[in_var->Name()] = in_var->id(); - } - // rename for the input variables of op inside subgraph - for (int i = 0; i < op->inputs_size(); i++) { - // one input - auto *in_var = op->mutable_inputs(i); - std::vector replaced_names; - for (int k = 0; k < in_var->arguments_size(); k++) { // all the arguments - std::string arg_value = in_var->arguments(k); - std::string arg_value_with_id = - arg_value + std::to_string(var2id[arg_value]); - if (input_names_with_id.count(arg_value_with_id)) { - replaced_names.push_back(arg_value); - } else { - replaced_names.push_back(arg_value_with_id); - } - } - in_var->clear_arguments(); - for (size_t k = 0; k < replaced_names.size(); k++) { - in_var->add_arguments(replaced_names[k]); - } - } - var2id.clear(); - for (auto out_var : correspond_node->outputs) { - var2id[out_var->Name()] = out_var->id(); - } - - // rename for the output variables of op inside subgraph - for (int i = 0; i < op->outputs_size(); i++) { - framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i); - std::vector replaced_names; - for (int k = 0; k < out_var->arguments_size(); k++) { - std::string arg_value = out_var->arguments(k); - std::string arg_value_with_id = - arg_value + std::to_string(var2id[arg_value]); - if (output_names_with_id.count(arg_value_with_id)) { - output_name_map[arg_value] = arg_value_with_id; - } - replaced_names.push_back(arg_value_with_id); - } - out_var->clear_arguments(); - for (size_t k = 0; k < replaced_names.size(); k++) { - out_var->add_arguments(replaced_names[k]); - } - } - } + RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id, + &output_names_with_id, &output_names, &output_name_map); // When tensorrt engine runs at the end of the operation, // output_mapping help us copy the data from the renamed ITensor @@ -222,6 +171,14 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(), "the block has no var-desc"); + + op_desc->SetInput( + "Xs", std::vector(input_names.begin(), input_names.end())); + + op_desc->SetOutput( + "Ys", std::vector(output_names.begin(), output_names.end())); + op_desc->SetType("tensorrt_engine"); + PADDLE_ENFORCE(!output_mapping.empty()); op_desc->SetBlockAttr("sub_block", new_block); SetAttr(op_desc->Proto(), "subgraph", @@ -236,6 +193,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id); + // Get "" when there is no cached calibration table data. std::string calibration_data = GetTrtCalibTableData( Get("model_opt_cache_dir"), engine_key, enable_int8); SetAttr(op_desc->Proto(), "calibration_data", calibration_data); @@ -272,6 +230,99 @@ std::vector ExtractParameters( return parameters; } +void RenameAndGetOutputs( + const std::vector &subgraph_nodes, + framework::BlockDesc *block_desc, + const std::set &input_names_with_id, + std::set *output_names_with_id, + std::set *output_names, + std::unordered_map *output_name_map) { + //// In the normal case, the paddle-trt exists bug when runing the googlenet. 
+ // When there are more than two convolutions of 1 * 1 with the same input, the + // paddle-tensorrt will do the merging optimization, which fuse those conv + // into one conv, and then trigger bug. So, We should use strategy to avoid + // this optimization for the time being. This bug will be fixed in the future. + std::unordered_map + same_hierarchy_conv2d_num_map; + + for (size_t index = 0; index < block_desc->OpSize(); ++index) { + framework::proto::OpDesc *op = block_desc->Op(index)->Proto(); + framework::OpDesc op_desc(*op, nullptr); + auto correspond_node = subgraph_nodes[index]; + PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type()); + + std::unordered_map var2id; + std::unordered_map in_vars; + for (auto *in_var : correspond_node->inputs) { + var2id[in_var->Name()] = in_var->id(); + in_vars[in_var->Name()] = in_var; + } + // rename for the input variables of op inside subgraph + for (int i = 0; i < op->inputs_size(); i++) { + // one input + auto *in_var = op->mutable_inputs(i); + std::vector replaced_names; + for (int k = 0; k < in_var->arguments_size(); k++) { // all the arguments + std::string arg_value = in_var->arguments(k); + std::string arg_value_with_id = + arg_value + std::to_string(var2id[arg_value]); + if (input_names_with_id.count(arg_value_with_id)) { + replaced_names.push_back(arg_value); + } else { + replaced_names.push_back(arg_value_with_id); + } + } + in_var->clear_arguments(); + for (size_t k = 0; k < replaced_names.size(); k++) { + in_var->add_arguments(replaced_names[k]); + } + } + var2id.clear(); + for (auto out_var : correspond_node->outputs) { + var2id[out_var->Name()] = out_var->id(); + } + + if (op_desc.Type() == "conv2d") { + auto input_var_name = op_desc.Input("Input").front(); + auto filter_var_name = op_desc.Input("Filter").front(); + auto out_var_name = op_desc.Output("Output").front(); + auto filter_shape = in_vars[filter_var_name]->Var()->GetShape(); + const std::vector strides = + boost::get>(op_desc.GetAttr("strides")); + const std::vector paddings = + boost::get>(op_desc.GetAttr("paddings")); + if (same_hierarchy_conv2d_num_map[input_var_name] > 0) { + (*output_names_with_id) + .insert(out_var_name + std::to_string(var2id[out_var_name])); + (*output_names).insert(out_var_name); + } else if (filter_shape[2] == 1 && filter_shape[3] == 1 && + strides[0] == 1 && strides[1] == 1 && paddings[0] == 0 && + paddings[1] == 0) { + same_hierarchy_conv2d_num_map[input_var_name] += 1; + } + } + + // rename for the output variables of op inside subgraph + for (int i = 0; i < op->outputs_size(); i++) { + framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i); + std::vector replaced_names; + for (int k = 0; k < out_var->arguments_size(); k++) { + std::string arg_value = out_var->arguments(k); + std::string arg_value_with_id = + arg_value + std::to_string(var2id[arg_value]); + if (output_names_with_id->count(arg_value_with_id)) { + (*output_name_map)[arg_value] = arg_value_with_id; + } + replaced_names.push_back(arg_value_with_id); + } + out_var->clear_arguments(); + for (size_t k = 0; k < replaced_names.size(); k++) { + out_var->add_arguments(replaced_names[k]); + } + } + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h index ce2b8161715..1a13ba51038 100644 --- a/paddle/fluid/inference/engine.h +++ b/paddle/fluid/inference/engine.h @@ -49,11 +49,6 @@ class EngineBase { // Execute the engine, that will run the inference network. 
virtual void Execute(int batch_size) = 0; - // Return the IO buffer that allocated in engine. One can read/write directly - // on the buffer. If the buffer's buffer is nullptr, one can also allocate - // memory and maintain it outside the engine. - virtual Buffer& buffer(const std::string& name) = 0; - virtual ~EngineBase() {} }; // class EngineBase diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 7900f56c9ce..ae1849f4353 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -18,21 +18,6 @@ namespace paddle { namespace inference { namespace tensorrt { -bool to_skip_merging_optimize(TensorRTEngine* engine, - const std::vector& filters, - const std::vector& strides, - const std::vector& paddings, - std::string input_name) { - if (engine->itensor_quote_num[input_name] > 0) { - return true; - } - if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && - strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) - engine->itensor_quote_num[input_name] += 1; - - return false; -} - template void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode, @@ -100,9 +85,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, layer->getOutput(0)->setName(output_name.c_str()); engine->SetITensor(output_name, layer->getOutput(0)); - if (test_mode || - to_skip_merging_optimize(engine, {filter_h, filter_w}, strides, paddings, - op_desc.Input("Input").front())) { + if (test_mode) { engine->DeclareOutput(output_name); } } diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index e83961f3d7b..3298a103a28 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -146,19 +146,6 @@ class TRTConvertValidation { // Declare outputs. op_desc_.reset(new framework::OpDesc(desc, nullptr)); - - // Set Inputs. - for (const auto& input : op_desc_->InputArgumentNames()) { - if (parameters_.count(input)) continue; - auto* var = scope_.FindVar(input); - PADDLE_ENFORCE(var); - auto tensor = var->GetMutable(); - - engine_->SetInputFromGPU( - input, static_cast(tensor->data()), - sizeof(float) * - analysis::AccuDims(tensor->dims(), tensor->dims().size())); - } } // We use the set 'neglected_output' here, because some Ops like batch norm, @@ -171,34 +158,64 @@ class TRTConvertValidation { platform::CUDAPlace place; platform::CUDADeviceContext ctx(place); op_->Run(scope_, place); + + std::vector input_output_names; + + // Note: we need filter the parameter + for (const auto& input : op_desc_->InputArgumentNames()) { + if (parameters_.count(input)) continue; + input_output_names.push_back(input); + } + + // Collect the fluid outputs. + std::vector> fluid_outs; + for (const auto& output : op_desc_->OutputArgumentNames()) { + if (neglected_output.count(output)) continue; + input_output_names.push_back(output); + std::vector fluid_out; + auto* var = scope_.FindVar(output); + auto* tensor = var->GetMutable(); + framework::TensorToVector(*tensor, ctx, &fluid_out); + fluid_outs.push_back(fluid_out); + } + + // Bind input and output for TRT. 
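+      // NOTE: enqueue() consumes the buffers array by binding slot, so each device
+      // pointer must sit at the index returned by getBindingIndex(name), not in the
+      // order the names were collected above.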
+ const int num_bindings = input_output_names.size(); + std::vector buffers(num_bindings); + + for (const std::string& name : input_output_names) { + auto* var = scope_.FindVar(name); + auto* tensor = var->GetMutable(); + const int bind_index = engine_->engine()->getBindingIndex(name.c_str()); + buffers[bind_index] = + static_cast(tensor->mutable_data(place)); + } + // Execute TRT. - engine_->Execute(batch_size); + engine_->Execute(batch_size, buffers); + cudaStreamSynchronize(engine_->stream()); ASSERT_FALSE(op_desc_->OutputArgumentNames().empty()); - const size_t output_space_size = 3000; + int index = 0; for (const auto& output : op_desc_->OutputArgumentNames()) { if (neglected_output.count(output)) continue; - std::vector fluid_out; - std::vector trt_out(output_space_size); - engine_->GetOutputInCPU(output, &trt_out[0], output_space_size); - cudaStreamSynchronize(engine_->stream()); - + std::vector trt_out; auto* var = scope_.FindVar(output); - auto tensor = var->GetMutable(); - framework::TensorToVector(*tensor, ctx, &fluid_out); + auto* tensor = var->GetMutable(); + framework::TensorToVector(*tensor, ctx, &trt_out); - size_t fluid_out_size = fluid_out.size(); + size_t fluid_out_size = fluid_outs[index].size(); if (if_add_batch_ == true) { fluid_out_size = batch_size * (framework::product(tensor->dims()) / max_batch_size_); } - // Compare two output - ASSERT_FALSE(fluid_out.empty()); + for (size_t i = 0; i < fluid_out_size; i++) { // Loose the threshold for CI in different machine model. - EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 2e-5); + EXPECT_LT(std::abs(fluid_outs[index][i] - trt_out[i]), 2e-5); } + index += 1; } } diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 10f48462cfa..1d07b373dad 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -32,8 +32,14 @@ void TensorRTEngine::Build(const DescType &paddle_model) { PADDLE_ENFORCE(false, "not implemented"); } +void TensorRTEngine::Execute(int batch_size, std::vector &buffers) { + batch_size_ = batch_size; + infer_context_->enqueue(batch_size, buffers.data(), stream_, nullptr); + cudaStreamSynchronize(stream_); + SetRuntimeBatch(batch_size); +} + void TensorRTEngine::Execute(int batch_size) { - freshDeviceId(); batch_size_ = batch_size; std::vector buffers; for (auto &buf : buffers_) { @@ -61,7 +67,6 @@ TensorRTEngine::~TensorRTEngine() { void TensorRTEngine::FreezeNetwork() { VLOG(3) << "TRT to freeze network"; - freshDeviceId(); PADDLE_ENFORCE(infer_builder_ != nullptr, "Call InitNetwork first to initialize network."); PADDLE_ENFORCE(infer_network_ != nullptr, @@ -81,30 +86,6 @@ void TensorRTEngine::FreezeNetwork() { PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); infer_context_.reset(infer_engine_->createExecutionContext()); - - // allocate GPU buffers. - buffers_.resize(buffer_sizes_.size()); - for (auto &item : buffer_sizes_) { - // The output buffers are not set in the network building phrase, need to - // infer from the TesorRT network. 
- if (item.second == 0) { - auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str()); - auto dims = infer_engine_->getBindingDimensions(slot_offset); - item.second = kDataTypeSize[static_cast( - infer_engine_->getBindingDataType(slot_offset))] * - analysis::AccuDims(dims.d, dims.nbDims) * max_batch_; - PADDLE_ENFORCE_GT(item.second, 0); - } - - auto &buf = buffer(item.first); - buf.max_size = item.second * max_batch_; - CHECK(buf.buffer == nullptr); // buffer should be allocated only once. - - PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_)); - buf.size = 0; - PADDLE_ENFORCE_LE(buf.max_size, 1 << 30); // 10G - buf.device = DeviceType::GPU; - } } nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, @@ -158,83 +139,6 @@ void TensorRTEngine::DeclareOutput(const std::string &name) { buffer_sizes_[name] = 0; } -void *TensorRTEngine::GetOutputInGPU(const std::string &name) { - return buffer(name).buffer; -} - -void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst, - size_t max_size) { - // determine data size - auto *output = TensorRTEngine::GetITensor(name); - nvinfer1::Dims dims = output->getDimensions(); - auto dim_size = analysis::AccuDims(dims.d, dims.nbDims); - size_t dst_size = dim_size * runtime_batch_ * - kDataTypeSize[static_cast(output->getType())]; - - auto it = buffer_sizes_.find(name); - PADDLE_ENFORCE(it != buffer_sizes_.end()); - PADDLE_ENFORCE_GT(it->second, 0); - PADDLE_ENFORCE_LE(dst_size, it->second); - PADDLE_ENFORCE_GE(max_size, dst_size); - auto &buf = buffer(name); - PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before"); - PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size, - cudaMemcpyDeviceToDevice, stream_), - 0); -} - -void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst, - size_t max_size) { - // determine data size - - auto *output = TensorRTEngine::GetITensor(name); - nvinfer1::Dims dims = output->getDimensions(); - auto dim_size = analysis::AccuDims(dims.d, dims.nbDims); - size_t dst_size = dim_size * runtime_batch_ * - kDataTypeSize[static_cast(output->getType())]; - auto it = buffer_sizes_.find(name); - PADDLE_ENFORCE(it != buffer_sizes_.end()); - PADDLE_ENFORCE_GT(it->second, 0); - PADDLE_ENFORCE_LE(dst_size, it->second); - PADDLE_ENFORCE_GE(max_size, dst_size); - auto &buf = buffer(name); - PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before"); - PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size, - cudaMemcpyDeviceToHost, stream_)); -} - -Buffer &TensorRTEngine::buffer(const std::string &name) { - PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); - auto it = buffer_sizes_.find(name); - PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s", - name); - auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); - return buffers_[slot_offset]; -} - -void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data, - size_t size) { - auto &buf = buffer(name); - PADDLE_ENFORCE_NOT_NULL(buf.buffer); - PADDLE_ENFORCE_NOT_NULL(data); - PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small"); - PADDLE_ENFORCE(buf.device == DeviceType::GPU); - buf.size = size; - PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size, - cudaMemcpyHostToDevice, stream_)); -} - -void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data, - size_t size) { - auto &buf = buffer(name); - buf.size = size; - PADDLE_ENFORCE_NOT_NULL(buf.buffer); - 
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small"); - PADDLE_ENFORCE(buf.device == DeviceType::GPU); - PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size, - cudaMemcpyDeviceToDevice, stream_)); -} - void TensorRTEngine::SetITensor(const std::string &name, nvinfer1::ITensor *tensor) { PADDLE_ENFORCE(tensor != nullptr); @@ -254,13 +158,6 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } -void TensorRTEngine::freshDeviceId() { - int count; - cudaGetDeviceCount(&count); - PADDLE_ENFORCE_LT(device_, count); - cudaSetDevice(device_); -} - nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin( nvinfer1::ITensor *const *inputs, int num_inputs, plugin::PluginTensorRT *plugin) { diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index cdfe09b5a7f..39559836581 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -57,13 +57,12 @@ class TensorRTEngine : public EngineBase { }; TensorRTEngine(int max_batch, int max_workspace, cudaStream_t stream, - int device = 0, bool enable_int8 = false, + bool enable_int8 = false, TRTInt8Calibrator* calibrator = nullptr, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), stream_(stream), - device_(device), enable_int8_(enable_int8), calibrator_(calibrator), logger_(logger) {} @@ -74,6 +73,7 @@ class TensorRTEngine : public EngineBase { void Build(const DescType& paddle_model) override; void Execute(int batch_size) override; + void Execute(int batch_size, std::vector& buffers); // Initialize the inference network, so that TensorRT layers can add to this // network. @@ -98,28 +98,8 @@ class TensorRTEngine : public EngineBase { // Check if the ITensor has been declared bool HasDeclared(const std::string& name); - // GPU memory address for an ITensor with specific name. One can operate on - // these memory directly for acceleration, for example, output the converted - // data directly to the buffer to save data copy overhead. - // NOTE this should be used after calling `FreezeNetwork`. - Buffer& buffer(const std::string& name) override; - cudaStream_t stream() { return stream_; } - // Fill an input from CPU memory with name and size. - void SetInputFromCPU(const std::string& name, const void* data, size_t size); - // TODO(Superjomn) is this method necessary given that buffer(xxx) can be - // accessed directly. Fill an input from GPU memory with name and size. - void SetInputFromGPU(const std::string& name, const void* data, size_t size); - // Get an output called name, the output of tensorrt is in GPU, so this method - // Return the output's GPU memory address without copy. - void* GetOutputInGPU(const std::string& name); - // Copy data into dst inside the GPU device. - void GetOutputInGPU(const std::string& name, void* dst, size_t max_size); - // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU - // to CPU. - void GetOutputInCPU(const std::string& name, void* dst, size_t max_size); - // Fill an ITensor into map itensor_map_. void SetITensor(const std::string& name, nvinfer1::ITensor* tensor); // Get an ITensor called name. 
nvinfer1::ITensor* GetITensor(const std::string& name); @@ -128,7 +108,6 @@ class TensorRTEngine : public EngineBase { nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } void SetRuntimeBatch(size_t batch_size); int GetRuntimeBatch(); - int GetDevice() { return device_; } nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs, int num_inputs, plugin::PluginTensorRT*); @@ -140,16 +119,6 @@ class TensorRTEngine : public EngineBase { std::unordered_map> weight_map; - // TODO(NHZLX) - // In the normal case, the paddle-trt exists bug when runing the googlenet. - // When there are more than two convolutions of 1 * 1 with the same input, the - // paddle-tensorrt will do the merging optimization, which fuse those conv - // into one conv, and then trigger bug. So, We should use strategy to avoid - // this - // optimization for the time being. This bug will be fixed in the future. - std::unordered_map - itensor_quote_num; - private: // the max batch size int max_batch_; @@ -159,8 +128,6 @@ class TensorRTEngine : public EngineBase { int max_workspace_; cudaStream_t stream_; - // The specific GPU id that the TensorRTEngine bounded to. - int device_; bool enable_int8_; TRTInt8Calibrator* calibrator_; @@ -192,10 +159,6 @@ class TensorRTEngine : public EngineBase { infer_ptr infer_network_; infer_ptr infer_engine_; infer_ptr infer_context_; - // Each ICudaEngine object is bound to a specific GPU when it is instantiated, - // ensure that the thread is associated with the correct device by calling - // freshDeviceId(). - void freshDeviceId(); }; // class TensorRTEngine // Add an layer__ into engine__ with args ARGS. diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc index 9eed0f6ee9c..961b24960bd 100644 --- a/paddle/fluid/inference/tensorrt/test_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_engine.cc @@ -17,6 +17,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/platform/enforce.h" @@ -27,19 +29,29 @@ namespace tensorrt { class TensorRTEngineTest : public ::testing::Test { protected: void SetUp() override { - ASSERT_EQ(0, cudaStreamCreate(&stream_)); - engine_ = new TensorRTEngine(10, 1 << 10, stream_); + ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0)); + + engine_ = new TensorRTEngine(10, 1 << 10, ctx_->stream()); engine_->InitNetwork(); } - void TearDown() override { - delete engine_; - cudaStreamDestroy(stream_); + void TearDown() override { delete engine_; } + + void PrepareInputOutput(const std::vector &input, + std::vector output_shape) { + TensorFromVector(input, *ctx_, &input_); + output_.Resize(framework::make_ddim(output_shape)); + } + + void GetOutput(std::vector *output) { + TensorToVector(output_, *ctx_, output); } protected: - TensorRTEngine* engine_; - cudaStream_t stream_; + framework::Tensor input_; + framework::Tensor output_; + TensorRTEngine *engine_; + platform::CUDADeviceContext *ctx_; }; TEST_F(TensorRTEngineTest, add_layer) { @@ -48,12 +60,14 @@ TEST_F(TensorRTEngineTest, add_layer) { float raw_weight[size] = {2.}; // Weight in CPU memory. 
float raw_bias[size] = {3.}; + std::vector buffers(2); // TRT binded inputs + LOG(INFO) << "create weights"; TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size); TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size); - auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, + auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW{1, 1, 1}); - auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size, + auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size, weight.get(), bias.get()); PADDLE_ENFORCE(fc_layer != nullptr); @@ -63,18 +77,24 @@ TEST_F(TensorRTEngineTest, add_layer) { ASSERT_EQ(engine_->engine()->getNbBindings(), 2); // fill in real data - float x_v = 1234; - engine_->SetInputFromCPU("x", reinterpret_cast(&x_v), - 1 * sizeof(float)); + std::vector x_v = {1234}; + std::vector y_cpu; + PrepareInputOutput(x_v, {1}); + + auto *x_v_gpu_data = input_.mutable_data(ctx_->GetPlace()); + auto *y_gpu_data = output_.mutable_data(ctx_->GetPlace()); + + buffers[0] = reinterpret_cast(x_v_gpu_data); + buffers[1] = reinterpret_cast(y_gpu_data); + LOG(INFO) << "to execute"; - engine_->Execute(1); + engine_->Execute(1, buffers); LOG(INFO) << "to get output"; - float y_cpu; - engine_->GetOutputInCPU("y", &y_cpu, 1 * sizeof(float)); + GetOutput(&y_cpu); LOG(INFO) << "to checkout output"; - ASSERT_EQ(y_cpu, x_v * 2 + 3); + ASSERT_EQ(y_cpu[0], x_v[0] * 2 + 3); } TEST_F(TensorRTEngineTest, add_layer_multi_dim) { @@ -83,12 +103,13 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) { // instead of row-major, which is [[1.0, 1.1], [3.3, 4.4]] float raw_weight[4] = {1.0, 1.1, 3.3, 4.4}; float raw_bias[2] = {1.3, 2.4}; + std::vector buffers(2); // TRT binded inputs TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 4); TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 2); - auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, + auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::DimsCHW{1, 2, 1}); - auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, 2, + auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, 2, weight.get(), bias.get()); PADDLE_ENFORCE(fc_layer != nullptr); @@ -96,19 +117,27 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) { engine_->FreezeNetwork(); ASSERT_EQ(engine_->engine()->getNbBindings(), 2); - float x_v[2] = {1.0, 2.0}; - engine_->SetInputFromCPU("x", reinterpret_cast(&x_v), - 2 * sizeof(float)); - engine_->Execute(1); + // fill in real data + std::vector x_v = {1.0, 2.0}; + std::vector y_cpu; + PrepareInputOutput(x_v, {2}); + + auto *x_v_gpu_data = input_.mutable_data(ctx_->GetPlace()); + auto *y_gpu_data = output_.mutable_data(ctx_->GetPlace()); + + buffers[0] = reinterpret_cast(x_v_gpu_data); + buffers[1] = reinterpret_cast(y_gpu_data); + + engine_->Execute(1, buffers); LOG(INFO) << "to get output"; - float y_cpu[2] = {-1., -1.}; + GetOutput(&y_cpu); auto dims = engine_->GetITensor("y")->getDimensions(); ASSERT_EQ(dims.nbDims, 3); ASSERT_EQ(dims.d[0], 2); ASSERT_EQ(dims.d[1], 1); - engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float)); + ASSERT_EQ(y_cpu[0], 4.5); ASSERT_EQ(y_cpu[1], 14.5); } @@ -117,12 +146,13 @@ TEST_F(TensorRTEngineTest, test_conv2d) { // Weight in CPU memory. 
float raw_weight[9] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; float raw_bias[1] = {0}; + std::vector buffers(2); // TRT binded inputs TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 9); TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 1); - auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, + auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{1, 3, 3}); - auto* conv_layer = + auto *conv_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *x, 1, nvinfer1::DimsHW{3, 3}, weight.get(), bias.get()); PADDLE_ENFORCE(conv_layer != nullptr); @@ -133,28 +163,37 @@ TEST_F(TensorRTEngineTest, test_conv2d) { engine_->FreezeNetwork(); ASSERT_EQ(engine_->engine()->getNbBindings(), 2); - float x_v[18] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; - engine_->SetInputFromCPU("x", reinterpret_cast(&x_v), - 18 * sizeof(float)); - engine_->Execute(2); + // fill in real data + std::vector x_v = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; + std::vector y_cpu; + PrepareInputOutput(x_v, {18}); + + auto *x_v_gpu_data = input_.mutable_data(ctx_->GetPlace()); + auto *y_gpu_data = output_.mutable_data(ctx_->GetPlace()); + + buffers[0] = reinterpret_cast(x_v_gpu_data); + buffers[1] = reinterpret_cast(y_gpu_data); + + engine_->Execute(2, buffers); LOG(INFO) << "to get output"; - float* y_cpu = new float[18]; - engine_->GetOutputInCPU("y", &y_cpu[0], 18 * sizeof(float)); + GetOutput(&y_cpu); + ASSERT_EQ(y_cpu[0], 4.0); ASSERT_EQ(y_cpu[1], 6.0); } TEST_F(TensorRTEngineTest, test_pool2d) { // Weight in CPU memory. - auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, + auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims3{1, 2, 2}); + std::vector buffers(2); // TRT binded inputs nvinfer1::PoolingType pool_t = nvinfer1::PoolingType::kAVERAGE; - auto* pool_layer = - TRT_ENGINE_ADD_LAYER(engine_, Pooling, *const_cast(x), - pool_t, nvinfer1::DimsHW{2, 2}); + auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, + *const_cast(x), + pool_t, nvinfer1::DimsHW{2, 2}); PADDLE_ENFORCE(pool_layer != nullptr); pool_layer->setStride(nvinfer1::DimsHW{1, 1}); @@ -164,14 +203,21 @@ TEST_F(TensorRTEngineTest, test_pool2d) { engine_->FreezeNetwork(); ASSERT_EQ(engine_->engine()->getNbBindings(), 2); - float x_v[8] = {1.0, 2.0, 5.0, 0.0, 2.0, 3.0, 5.0, 10.0}; - engine_->SetInputFromCPU("x", reinterpret_cast(&x_v), - 8 * sizeof(float)); - engine_->Execute(2); + // fill in real data + std::vector x_v = {1.0, 2.0, 5.0, 0.0, 2.0, 3.0, 5.0, 10.0}; + std::vector y_cpu; + PrepareInputOutput(x_v, {2}); + + auto *x_v_gpu_data = input_.mutable_data(ctx_->GetPlace()); + auto *y_gpu_data = output_.mutable_data(ctx_->GetPlace()); + + buffers[0] = reinterpret_cast(x_v_gpu_data); + buffers[1] = reinterpret_cast(y_gpu_data); + + engine_->Execute(2, buffers); LOG(INFO) << "to get output"; - float* y_cpu = new float[2]; - engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float)); + GetOutput(&y_cpu); ASSERT_EQ(y_cpu[0], 2.0); ASSERT_EQ(y_cpu[1], 5.0); diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 2ff35c7c6ac..d3efea28120 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -106,6 +106,11 @@ class TensorRTEngineOp : public framework::OperatorBase { if (enable_int8_ && 
calibration_data_.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); } + + // we will create an engine here. + if (!calibration_mode_) { + // trt_engine_.reset(); + } } protected: @@ -125,7 +130,8 @@ class TensorRTEngineOp : public framework::OperatorBase { RunCalibration(scope, dev_place); return; } - RunTrt(scope, dev_place); + auto trt_engine = GetEngine(scope, dev_place); + RunTrt(scope, dev_place, trt_engine); } void RunCalibration(const framework::Scope &scope, @@ -155,10 +161,9 @@ class TensorRTEngineOp : public framework::OperatorBase { calib_res->calib_.reset(new TRTInt8Calibrator( calib_buffers, runtime_batch, engine_key_, dev_place)); calib_res->thr_.reset(new std::thread([&]() { - calib_res->engine_.reset(new TensorRTEngine( - max_batch_size_, workspace_size_, stream, - boost::get(dev_place).device, enable_int8_, - calib_res->calib_.get())); + calib_res->engine_.reset( + new TensorRTEngine(max_batch_size_, workspace_size_, stream, + enable_int8_, calib_res->calib_.get())); VLOG(3) << "start the calib trt engine thread"; Prepare(scope, dev_place, calib_res->engine_.get()); })); @@ -180,28 +185,30 @@ class TensorRTEngineOp : public framework::OperatorBase { RunNativeImpl(scope, dev_place); } - void RunTrt(const framework::Scope &scope, - const platform::Place &dev_place) const { + void RunTrt(const framework::Scope &scope, const platform::Place &dev_place, + TensorRTEngine *engine) const { int runtime_batch = 1; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(dev_place); auto stream = reinterpret_cast(dev_ctx).stream(); - if (trt_engine_.get() == nullptr) { - trt_engine_.reset( - new TensorRTEngine(max_batch_size_, workspace_size_, stream, - boost::get(dev_place).device, - enable_int8_, calibrator_.get())); - Prepare(scope, dev_place, trt_engine_.get()); - } - auto *engine = trt_engine_.get(); + // auto *engine = trt_engine_.get(); PADDLE_ENFORCE(!input_names_.empty(), "should pass more than one inputs"); std::vector output_maps = Attr>("output_name_mapping"); - // Convert input tensor from fluid to engine. + int num_inputs = 0; + + for (const auto &x : Inputs("Xs")) { + if (param_names_.count(x)) continue; + num_inputs += 1; + } + const int num_bindings = num_inputs + Outputs("Ys").size(); + std::vector buffers(num_bindings); + + // Bind input tensor to TRT. for (const auto &x : Inputs("Xs")) { if (param_names_.count(x)) continue; // convert input and copy to TRT engine's buffer @@ -209,26 +216,17 @@ class TensorRTEngineOp : public framework::OperatorBase { inference::analysis::GetFromScope(scope, x); auto t_shape = framework::vectorize(t.dims()); runtime_batch = t_shape[0]; - if (platform::is_cpu_place(t.place())) { - engine->SetInputFromCPU(x, static_cast(t.data()), - t.memory_size()); - } else { - engine->SetInputFromGPU(x, static_cast(t.data()), - t.memory_size()); - } - } - cudaStreamSynchronize(stream); - PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_); - // Execute the engine. - engine->Execute(runtime_batch); + const int bind_index = engine->engine()->getBindingIndex(x.c_str()); + PADDLE_ENFORCE(bind_index < num_bindings, + "The bind index should be less than num_bindings"); + buffers[bind_index] = static_cast(t.data()); + } - // Convert output tensor from engine to fluid + // Bind output tensor to TRT. int output_index = 0; VLOG(4) << "TensorRT Engine Op Outputs:"; for (const auto &y : Outputs("Ys")) { - VLOG(4) << y; - // convert output and copy to fluid. 
nvinfer1::ITensor *trt_t = engine->GetITensor(output_maps[output_index]); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. @@ -238,27 +236,46 @@ class TensorRTEngineOp : public framework::OperatorBase { for (int i = 0; i < dims.nbDims; i++) { ddim.push_back(dims.d[i]); } - auto *fluid_v = scope.FindVar(y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); auto *fluid_t = fluid_v->GetMutable(); - fluid_t->Resize(framework::make_ddim(ddim)); - // TODO(Superjomn) change this float to dtype size. - auto size = - inference::analysis::AccuDims(dims.d, dims.nbDims) * runtime_batch; - engine->GetOutputInGPU( - output_maps[output_index], - fluid_t->mutable_data(platform::CUDAPlace( - boost::get(dev_place).device)), - size * sizeof(float)); + const int bind_index = + engine->engine()->getBindingIndex(output_maps[output_index].c_str()); + PADDLE_ENFORCE(bind_index < num_bindings, + "The bind index should be less than num_bindings"); + buffers[bind_index] = static_cast(fluid_t->mutable_data( + boost::get(dev_place))); + output_index += 1; } + PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_); + // Execute the engine. + engine->Execute(runtime_batch, buffers); cudaStreamSynchronize(stream); } + TensorRTEngine *GetEngine(const framework::Scope &scope, + const platform::Place &dev_place) const { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + auto stream = + reinterpret_cast(dev_ctx).stream(); + if (trt_engine_.get() == nullptr) { + trt_engine_.reset(new TensorRTEngine(max_batch_size_, workspace_size_, + stream, enable_int8_, + calibrator_.get())); + if (true) { + Prepare(scope, dev_place, trt_engine_.get()); + } else { + // create static engine + } + } + return trt_engine_.get(); + } + void Prepare(const framework::Scope &scope, const platform::Place &dev_place, TensorRTEngine *engine) const { LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP " -- GitLab
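Taken together, the unit tests and TensorRTEngineOp now drive the engine through the same zero-copy pattern: each fluid tensor lends its GPU buffer to TensorRT at the slot reported by getBindingIndex(), and Execute(batch_size, buffers) enqueues the whole vector on the engine's stream. A minimal sketch of that pattern, assuming an already-frozen engine with one input named "x" and one output named "y" (the names, the float dtype, and the helper itself are illustrative, not part of the patch):

    #include <vector>

    #include "paddle/fluid/framework/tensor.h"
    #include "paddle/fluid/inference/tensorrt/engine.h"
    #include "paddle/fluid/platform/place.h"

    namespace trt = paddle::inference::tensorrt;

    void RunEngineOnce(trt::TensorRTEngine *engine, paddle::framework::Tensor *x,
                       paddle::framework::Tensor *y,
                       const paddle::platform::CUDAPlace &place, int batch_size) {
      // One slot per binding; the order is defined by the ICudaEngine, not by us.
      std::vector<void *> buffers(engine->engine()->getNbBindings(), nullptr);
      buffers[engine->engine()->getBindingIndex("x")] =
          static_cast<void *>(x->mutable_data<float>(place));
      buffers[engine->engine()->getBindingIndex("y")] =
          static_cast<void *>(y->mutable_data<float>(place));
      // Execute() enqueues on the engine's stream and synchronizes before returning,
      // so y already holds the result on the GPU when this call comes back.
      engine->Execute(batch_size, buffers);
    }

Because the tensors' own device memory is handed to TensorRT directly, the SetInputFromCPU/SetInputFromGPU and GetOutputInGPU/GetOutputInCPU copies removed above are no longer needed.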