diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 18c32fa09199003f17183207828cdfe4e627ae1a..f40d471cbfcb765895d911c7b36680271e2f0df0 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -23,7 +23,7 @@ namespace paddle { namespace inference { -DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size"); +DEFINE_int32(tensorrt_max_batchsize, 1, "TensorRT maximum batch size"); DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size"); namespace analysis { @@ -52,7 +52,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { bool DataFlowGraphToFluidPass::Finalize() { return true; } void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { - FilterRedundantOutputOfSubGraph(graph); LOG(INFO) << "graph.inputs " << graph->inputs.size(); for (auto &node : GraphTraits(graph).nodes_in_TS()) { if (node.deleted()) continue; diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.cc b/paddle/fluid/inference/analysis/subgraph_splitter.cc index 80809d4c43ca08298bad25cf614dcb4117d3f99a..9146c0e45e77b5f120d3be622f74e3008bca2b6f 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter.cc @@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { inlink_or_outlink_cleaner(o->inlinks); } } + FilterRedundantOutputOfSubGraph(graph_); } } // namespace analysis diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index dba1d50b2d1c487ced8e6ca51f2d257641ad5fc7..3d390c8ad3b8897721b634b373b84e0993fc897b 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter { auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); auto* Y_t = Y_v->GetMutable(); - auto* weight_data = Y_t->mutable_data(platform::CPUPlace()); - PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL); - const int n_output = Y_t->dims()[0]; - const int filter_h = Y_t->dims()[2]; - const int filter_w = Y_t->dims()[3]; + platform::CPUPlace cpu_place; + framework::LoDTensor* weight_tensor = new framework::LoDTensor(); + weight_tensor->Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, weight_tensor); + engine_->weight_map[op_desc.Input("Filter").front()] = + std::move(std::unique_ptr(weight_tensor)); + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + + PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); + const int n_output = weight_tensor->dims()[0]; + const int filter_h = weight_tensor->dims()[2]; + const int filter_w = weight_tensor->dims()[3]; const int groups = boost::get(op_desc.GetAttr("groups")); const std::vector dilations = @@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter { TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), - Y_t->memory_size() / sizeof(float)}; + weight_tensor->memory_size() / sizeof(float)}; TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; auto* layer = TRT_ENGINE_ADD_LAYER( diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 3744550f60a1696aedd8a3ecd24f1b21d22325b9..066e5de373d15d50e5583e6ed7c5a07ae8abe791 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" namespace paddle { @@ -40,10 +39,19 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); auto* Y_t = Y_v->GetMutable(); - auto* weight_data = Y_t->mutable_data(platform::CPUPlace()); + + platform::CPUPlace cpu_place; + framework::LoDTensor* weight_tensor = new framework::LoDTensor(); + weight_tensor->Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, weight_tensor); + engine_->weight_map[op_desc.Input("Y").front()] = + std::move(std::unique_ptr(weight_tensor)); + + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - std::vector dims_y = framework::vectorize2int(Y_t->dims()); + std::vector dims_y = framework::vectorize2int(weight_tensor->dims()); if (static_cast(dims_y.size()) == dims_x.nbDims + 1) { if (dims_y[0] == 1) dims_y.erase(dims_y.begin()); } @@ -70,9 +78,9 @@ class ElementwiseWeightOpConverter : public OpConverter { PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!"); } - TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - Y_t->memory_size() / sizeof(float)}; + TensorRTEngine::Weight shift_weights{ + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + weight_tensor->memory_size() / sizeof(float)}; TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 39fe1f609d7b94638506877fc301f19ef33ec8ac..653ddb0ccae178008869e069617d7db2764f3866 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/platform/place.h" namespace paddle { namespace inference { @@ -73,19 +68,28 @@ class FcOpConverter : public OpConverter { auto* Y_t = Y_v->GetMutable(); // This may trigger a GPU->CPU copy, because TRT's weight can only be // assigned from CPU memory, that can't be avoided. - auto* weight_data = Y_t->mutable_data(platform::CPUPlace()); - PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix - size_t n_output = Y_t->dims()[1]; + platform::CPUPlace cpu_place; + framework::LoDTensor weight_tensor; + weight_tensor.Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, &weight_tensor); - framework::LoDTensor tmp; - tmp.Resize(Y_t->dims()); - memcpy(tmp.mutable_data(platform::CPUPlace()), weight_data, + auto* weight_data = weight_tensor.mutable_data(platform::CPUPlace()); + + PADDLE_ENFORCE_EQ(weight_tensor.dims().size(), 2UL); // a matrix + size_t n_output = weight_tensor.dims()[1]; + + framework::LoDTensor* tmp = new framework::LoDTensor(); + tmp->Resize(weight_tensor.dims()); + engine_->weight_map[op_desc.Input("Y").front()] = + std::move(std::unique_ptr(tmp)); + + memcpy(tmp->mutable_data(platform::CPUPlace()), weight_data, Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float)); TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), Y_t->memory_size() / sizeof(float)}; TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT, - static_cast(tmp.data()), + static_cast(tmp->data()), Y_t->memory_size() / sizeof(float)); weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]}); tmp_weight.dims = weight.dims; diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index d6651a5b244ba31a01220e6299cb2016ae61fe64..01d7f700da9cc67d0ebbd3d9649e3823f58a8811 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) { auto* x = scope.Var("conv2d-Y"); auto* x_tensor = x->GetMutable(); x_tensor->Resize(framework::make_ddim(dim_vec)); + x_tensor->mutable_data(platform::CUDAPlace(0)); OpConverter converter; converter.ConvertBlock(*block->Proto(), {"conv2d-Y"}, scope, diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 4265f33f28fe36b1745baf4761c3c85e3a281d6b..35ecfd02f43ca0715793a9b6b3997dbc943b5769 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" @@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place, auto dims = tensor->dims(); size_t num_elements = analysis::AccuDims(dims, dims.size()); PADDLE_ENFORCE_GT(num_elements, 0); - auto* data = tensor->mutable_data(place); + + platform::CPUPlace cpu_place; + framework::LoDTensor temp_tensor; + temp_tensor.Resize(dims); + auto* temp_data = temp_tensor.mutable_data(cpu_place); for (size_t i = 0; i < num_elements; i++) { - *(data + i) = random(0., 1.); + *(temp_data + i) = random(0., 1.); } + + TensorCopySync(temp_tensor, place, tensor); } /* @@ -101,8 +108,8 @@ class TRTConvertValidation { } void DeclVar(const std::string& name, const std::vector dim_vec) { - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); auto* x = scope_.Var(name); auto* x_tensor = x->GetMutable(); @@ -141,7 +148,7 @@ class TRTConvertValidation { PADDLE_ENFORCE(var); auto tensor = var->GetMutable(); - engine_->SetInputFromCPU( + engine_->SetInputFromGPU( input, static_cast(tensor->data()), sizeof(float) * analysis::AccuDims(tensor->dims(), tensor->dims().size())); @@ -151,8 +158,8 @@ class TRTConvertValidation { void Execute(int batch_size) { // Execute Fluid Op PADDLE_ENFORCE_LE(batch_size, max_batch_size_); - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); op_->Run(scope_, place); // Execute TRT. engine_->Execute(batch_size); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index b821c3d0bf425c46fae634fbf53f7ee63100ca5c..14e9e14d33d637ee68e37593cc48721e5169499f 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) { } void TensorRTEngine::Execute(int batch_size) { + freshDeviceId(); batch_size_ = batch_size; std::vector buffers; for (auto &buf : buffers_) { @@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() { } void TensorRTEngine::FreezeNetwork() { + freshDeviceId(); PADDLE_ENFORCE(infer_builder_ != nullptr, "Call InitNetwork first to initialize network."); PADDLE_ENFORCE(infer_network_ != nullptr, @@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } +void TensorRTEngine::freshDeviceId() { + int count; + cudaGetDeviceCount(&count); + PADDLE_ENFORCE_LT(device_, count); + cudaSetDevice(device_); +} + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 694468c419c20089de1cdecff1a903ad0cc6e99f..bd3ba4cea6551a7f6651e311e2649de191a6faa1 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/utils/singleton.h" @@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase { }; TensorRTEngine(int max_batch, int max_workspace, - cudaStream_t* stream = nullptr, + cudaStream_t* stream = nullptr, int device = 0, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), stream_(stream ? stream : &default_stream_), - logger_(logger) { - cudaStreamCreate(&default_stream_); + logger_(logger), + device_(device) { + freshDeviceId(); + cudaStreamCreate(stream_); } virtual ~TensorRTEngine(); @@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase { nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } void SetRuntimeBatch(size_t batch_size); int GetRuntimeBatch(); + int GetDevice() { return device_; } + + // A pointer to CPU memory is needed of the TRT weight. + // Before TRT runs, fluid loads weight into GPU storage. + // so we need to copy the weights from GPU to CPU in our op converter. + // We use a map to store these weights for the weight memory is not released + // in advance, which affecting the construction of TRT Op. + std::unordered_map> + weight_map; private: // the max batch size @@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase { std::unordered_map buffer_sizes_; std::unordered_map itensor_map_; + // The specific GPU id that the TensorRTEngine bounded to. + int device_; // TensorRT related internal members template @@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase { infer_ptr infer_network_; infer_ptr infer_engine_; infer_ptr infer_context_; + // Each ICudaEngine object is bound to a specific GPU when it is instantiated, + // ensure that the thread is associated with the correct device by calling + // freshDeviceId(). + void freshDeviceId(); }; // class TensorRTEngine // Add an layer__ into engine__ with args ARGS. @@ -188,8 +206,8 @@ class TRT_EngineManager { // Create or get an engine called `name` TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream, - const std::string& name) { - auto* p = new TensorRTEngine(max_batch, max_workspace, stream); + const std::string& name, int gpu_device = 0) { + auto* p = new TensorRTEngine(max_batch, max_workspace, stream, gpu_device); engines_[name].reset(p); return p; } diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc index dc03702990587bf5e65d28da662d10df4d882110..da1f6535cb3b2476cd475797861d6d2bb6d88856 100644 --- a/paddle/fluid/inference/tensorrt/test_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_engine.cc @@ -27,7 +27,7 @@ namespace tensorrt { class TensorRTEngineTest : public ::testing::Test { protected: void SetUp() override { - ASSERT_EQ(0, cudaStreamCreate(&stream_)); + // ASSERT_EQ(0, cudaStreamCreate(&stream_)); engine_ = new TensorRTEngine(10, 1 << 10, &stream_); engine_->InitNetwork(); } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ff0e989464e76b0f7cb163bd95997b82303036d6..13dfc1d99cc2058ee5ba2d8121b7f1f7bc1343e1 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -100,7 +100,8 @@ function(op_library TARGET) endif() # Define operators that don't need pybind here. - foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") + foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" +"tensor_array_read_write_op" "tensorrt_engine_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() @@ -245,6 +246,7 @@ op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) if (WITH_GPU AND TENSORRT_FOUND) op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n") nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc DEPS tensorrt_engine_op analysis) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index ee3078876c15b06a887064f08dc0c05d450b5f77..4d930e9cec208dca2c0e8f0ff690a1b4730a156a 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -17,10 +17,6 @@ #include #include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/operators/tensorrt_engine_op.h" namespace paddle { @@ -29,100 +25,6 @@ DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT"); namespace operators { -using inference::Singleton; -using inference::tensorrt::TRT_EngineManager; - -using FluidDT = framework::proto::VarType_Type; -using TRT_DT = nvinfer1::DataType; - -namespace { - -TRT_DT FluidDataType2TRT(FluidDT type) { - switch (type) { - case FluidDT::VarType_Type_FP32: - return TRT_DT::kFLOAT; - case FluidDT::VarType_Type_INT32: - return TRT_DT::kINT32; - default: - return TRT_DT::kINT32; - } - PADDLE_THROW("unkown type"); - return TRT_DT::kINT32; -} - -nvinfer1::Dims Vec2TRT_Dims(const std::vector &shape) { - PADDLE_ENFORCE_GT(shape.size(), 1UL, - "TensorRT' tensor input requires at least 2 dimensions"); - PADDLE_ENFORCE_LE(shape.size(), 4UL, - "TensorRT' tensor input requires at most 4 dimensions"); - PADDLE_ENFORCE_EQ(shape.size(), 4UL); - return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); -} - -} // namespace - -template -void TensorRTEngineKernel::Prepare( - const framework::ExecutionContext &context) const { - VLOG(4) << "Prepare engine"; - // Get the ProgramDesc and pass to convert. - framework::proto::BlockDesc block_desc; - block_desc.ParseFromString(context.Attr("subgraph")); - int max_batch = context.Attr("max_batch"); - auto max_workspace = context.Attr("max_workspace"); - auto params = context.Attr>("parameters"); - std::unordered_set parameters; - for (const auto ¶m : params) { - parameters.insert(param); - } - - std::vector output_maps = - context.Attr>("output_name_mapping"); - - // TODO(Superjomn) replace this with a different stream - auto *engine = Singleton::Global().Create( - max_batch, max_workspace, nullptr /*engine hold its own stream*/, - context.Attr("engine_uniq_key")); - engine->InitNetwork(); - - framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); - VLOG(4) << "parsed var size " << block.AllVars().size(); - // Add inputs - VLOG(4) << "declare inputs"; - for (auto &input : context.Inputs("Xs")) { - if (parameters.count(input)) continue; - VLOG(4) << "declare input " << input; - auto *var = block.FindVar(input); - // TensorRT engine need to create parameters. The parameter's description - // should be set in - PADDLE_ENFORCE(var, "no variable called %s", input); - PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, - "TensorRT engine only takes LoDTensor as input"); - auto shape = var->GetShape(); - // For the special batch_size placeholder -1, drop it and pass the real - // shape of data. - // TODO(Superjomn) fix this with batch broadcast, or it can't handle - // variational batch size. - if (shape[0] == -1) { - shape[0] = FLAGS_tensorrt_engine_batch_size; - } - engine->DeclareInput( - input, FluidDataType2TRT( - var->Proto()->type().lod_tensor().tensor().data_type()), - Vec2TRT_Dims(shape)); - } - - inference::Singleton::Global().ConvertBlock( - block_desc, parameters, context.scope(), engine); - - // Add outputs - for (auto &output : output_maps) { - engine->DeclareOutput(output); - } - - engine->FreezeNetwork(); -} - class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -150,11 +52,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker); -REGISTER_OP_CPU_KERNEL( - tensorrt_engine, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel); - #endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/tensorrt_engine_op.cu.cc b/paddle/fluid/operators/tensorrt_engine_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..e1ddfde6d51ef719ca0b89cf286b176195ee682a --- /dev/null +++ b/paddle/fluid/operators/tensorrt_engine_op.cu.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/tensorrt_engine_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + tensorrt_engine, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel); diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 2cbe1213a2f428a3ce56b06f97636baeb4b66c26..f2ec7f066aa514d1107af17f9155f2c630b0f3b3 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -19,8 +19,10 @@ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/inference/analysis/helper.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { @@ -29,6 +31,35 @@ DECLARE_int32(tensorrt_engine_batch_size); namespace operators { +using FluidDT = framework::proto::VarType_Type; +using TRT_DT = nvinfer1::DataType; + +namespace { + +TRT_DT FluidDataType2TRT(FluidDT type) { + switch (type) { + case FluidDT::VarType_Type_FP32: + return TRT_DT::kFLOAT; + case FluidDT::VarType_Type_INT32: + return TRT_DT::kINT32; + default: + return TRT_DT::kINT32; + } + PADDLE_THROW("unkown type"); + return TRT_DT::kINT32; +} + +nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { + PADDLE_ENFORCE_GT(shape.size(), 1UL, + "TensorRT' tensor input requires at least 2 dimensions"); + PADDLE_ENFORCE_LE(shape.size(), 4UL, + "TensorRT' tensor input requires at most 4 dimensions"); + PADDLE_ENFORCE_EQ(shape.size(), 4UL); + return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); +} + +} // namespace + using inference::Singleton; using inference::tensorrt::TRT_EngineManager; @@ -47,7 +78,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel { .FindVar(input0) ->GetMutable() ->type()), - platform::CPUPlace()); + ctx.GetPlace()); return kt; } }; @@ -94,7 +125,9 @@ class TensorRTEngineKernel : public framework::OpKernel { // Convert output tensor from engine to fluid int output_index = 0; + VLOG(4) << "TensorRT Engine Op Outputs:"; for (const auto& y : context.Outputs("Ys")) { + VLOG(4) << y; // convert output and copy to fluid. nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]); auto dims = trt_t->getDimensions(); @@ -113,9 +146,11 @@ class TensorRTEngineKernel : public framework::OpKernel { // TODO(Superjomn) change this float to dtype size. auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) * FLAGS_tensorrt_engine_batch_size; - engine->GetOutputInCPU(output_maps[output_index], - fluid_t->mutable_data(platform::CPUPlace()), - size * sizeof(float)); + engine->GetOutputInGPU( + output_maps[output_index], + fluid_t->mutable_data(platform::CUDAPlace( + boost::get(context.GetPlace()).device)), + size * sizeof(float)); //} else { // engine->GetOutputInGPU( // y, fluid_t->mutable_data(platform::CUDAPlace()), @@ -128,8 +163,67 @@ class TensorRTEngineKernel : public framework::OpKernel { } protected: - // Build the engine. - void Prepare(const framework::ExecutionContext& context) const; + void Prepare(const framework::ExecutionContext& context) const { + VLOG(4) << "Prepare engine"; + // Get the ProgramDesc and pass to convert. + framework::proto::BlockDesc block_desc; + block_desc.ParseFromString(context.Attr("subgraph")); + int max_batch = context.Attr("max_batch"); + auto max_workspace = context.Attr("max_workspace"); + auto params = context.Attr>("parameters"); + std::unordered_set parameters; + for (const auto& param : params) { + parameters.insert(param); + } + + std::vector output_maps = + context.Attr>("output_name_mapping"); + + // TODO(Superjomn) replace this with a different stream + auto* engine = Singleton::Global().Create( + max_batch, max_workspace, nullptr /*engine hold its own stream*/, + context.Attr("engine_uniq_key"), + boost::get(context.GetPlace()).device); + + engine->InitNetwork(); + + framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); + VLOG(4) << "parsed var size " << block.AllVars().size(); + // Add inputs + VLOG(4) << "declare inputs"; + for (auto& input : context.Inputs("Xs")) { + if (parameters.count(input)) continue; + VLOG(4) << "declare input " << input; + auto* var = block.FindVar(input); + // TensorRT engine need to create parameters. The parameter's description + // should be set in + PADDLE_ENFORCE(var, "no variable called %s", input); + PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, + "TensorRT engine only takes LoDTensor as input"); + auto shape = var->GetShape(); + // For the special batch_size placeholder -1, drop it and pass the real + // shape of data. + // TODO(Superjomn) fix this with batch broadcast, or it can't handle + // variational batch size. + if (shape[0] == -1) { + shape[0] = FLAGS_tensorrt_engine_batch_size; + } + engine->DeclareInput( + input, FluidDataType2TRT( + var->Proto()->type().lod_tensor().tensor().data_type()), + Vec2TRT_Dims(shape)); + } + + inference::Singleton::Global() + .ConvertBlock(block_desc, parameters, context.scope(), engine); + + // Add outputs + for (auto& output : output_maps) { + engine->DeclareOutput(output); + } + + engine->FreezeNetwork(); + } }; } // namespace operators diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 37657fa0b0498986fe67027415279af1775e58b9..97c375361f42a61e8a2a626bcec2a91263ee2014 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -23,20 +23,20 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" -USE_CPU_ONLY_OP(tensorrt_engine); +USE_CUDA_ONLY_OP(tensorrt_engine); namespace paddle { namespace operators { namespace { -void CreateCPUTensor(framework::Scope* scope, const std::string& name, - const std::vector& shape) { +void CreateCUDATensor(framework::Scope* scope, const std::string& name, + const std::vector& shape) { auto* var = scope->Var(name); auto* tensor = var->GetMutable(); auto dims = framework::make_ddim(shape); tensor->Resize(dims); - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); inference::tensorrt::RandomizeTensor(tensor, place, ctx); } @@ -112,15 +112,15 @@ TEST(TensorRTEngineOp, manual) { LOG(INFO) << "engine_op " << engine_op.get(); framework::Scope scope; - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); // Prepare variables. - CreateCPUTensor(&scope, "x", std::vector({2, 4})); - CreateCPUTensor(&scope, "y", std::vector({4, 6})); - CreateCPUTensor(&scope, "z", std::vector({2, 6})); + CreateCUDATensor(&scope, "x", std::vector({2, 4})); + CreateCUDATensor(&scope, "y", std::vector({4, 6})); + CreateCUDATensor(&scope, "z", std::vector({2, 6})); - CreateCPUTensor(&scope, "y0", std::vector({6, 8})); - CreateCPUTensor(&scope, "z0", std::vector({2, 8})); + CreateCUDATensor(&scope, "y0", std::vector({6, 8})); + CreateCUDATensor(&scope, "z0", std::vector({2, 8})); // Execute them. LOG(INFO) << "engine_op run"; @@ -130,8 +130,8 @@ TEST(TensorRTEngineOp, manual) { void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { framework::ProgramDesc program; framework::Scope scope; - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); auto* block_ = program.Proto()->add_blocks(); block_->set_idx(0); @@ -165,10 +165,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { // Prepare variables. if (!x_created) { - CreateCPUTensor(&scope, x_name, std::vector(x_shape)); + CreateCUDATensor(&scope, x_name, std::vector(x_shape)); } - CreateCPUTensor(&scope, y_name, std::vector(y_shape)); - CreateCPUTensor(&scope, z_name, std::vector(z_shape)); + CreateCUDATensor(&scope, y_name, std::vector(y_shape)); + CreateCUDATensor(&scope, z_name, std::vector(z_shape)); // It is wired, need to copy manually. *block_->add_ops() = *fc->Proto();