diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 9318f1089781b30468cf4d3c7151d0dd26e50a9c..e51d6cfeb931d50a9a573df29c916ebd3da403d1 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -24,7 +24,7 @@ namespace paddle { -DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true, +DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false, "Enable subgraph to TensorRT engine for acceleration"); DEFINE_string(inference_analysis_graphviz_log_root, "./", @@ -44,7 +44,8 @@ class DfgPassManagerImpl final : public DfgPassManager { if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( - {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"}); + {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax", + "depthwise_conv2d", "batch_norm", "concat"}); if (!node->IsFunction()) return false; const auto* func = static_cast(node); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 18c32fa09199003f17183207828cdfe4e627ae1a..ce0639a6162da6347ed130ecb1586c9a2d4071d5 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -23,9 +23,6 @@ namespace paddle { namespace inference { -DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size"); -DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size"); - namespace analysis { using framework::proto::ProgramDesc; @@ -52,7 +49,6 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { bool DataFlowGraphToFluidPass::Finalize() { return true; } void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { - FilterRedundantOutputOfSubGraph(graph); LOG(INFO) << "graph.inputs " << graph->inputs.size(); for (auto &node : GraphTraits(graph).nodes_in_TS()) { if (node.deleted()) continue; @@ -191,8 +187,6 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, // Set attrs SetAttr(desc.Proto(), "subgraph", block->SerializeAsString()); SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); - SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize); - SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size); SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); SetAttr(desc.Proto(), "output_name_mapping", output_mapping); node->SetPbMsg(desc.Proto()->SerializeAsString()); diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h index 59c47365aa6c8ad5886c4515850d264f69cc4670..0c9a8a0b7cae17bf2eaa714348ea1c9b5e43611b 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h @@ -27,9 +27,6 @@ namespace paddle { namespace inference { -DECLARE_int32(tensorrt_max_batchsize); -DECLARE_int32(tensorrt_workspace_size); - namespace analysis { class DataFlowGraphToFluidPass final : public DataFlowGraphPass { public: diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc index 511631d3e067f14bc1230d9e4b4d92dbe604e1d4..16d82b5aa1acaf87d1cd78ad5b79faa65143ad7d 100644 --- 
a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.cc @@ -92,6 +92,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k))); in->outlinks.push_back(o); o->inlinks.push_back(in); + unique_written_vars.insert(in); } } for (int j = 0; j < op.outputs_size(); j++) { @@ -112,7 +113,6 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { } out->inlinks.push_back(o); o->outlinks.push_back(out); - unique_written_vars.insert(out); } } } diff --git a/paddle/fluid/inference/analysis/subgraph_splitter.cc b/paddle/fluid/inference/analysis/subgraph_splitter.cc index 80809d4c43ca08298bad25cf614dcb4117d3f99a..9146c0e45e77b5f120d3be622f74e3008bca2b6f 100644 --- a/paddle/fluid/inference/analysis/subgraph_splitter.cc +++ b/paddle/fluid/inference/analysis/subgraph_splitter.cc @@ -153,6 +153,7 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { inlink_or_outlink_cleaner(o->inlinks); } } + FilterRedundantOutputOfSubGraph(graph_); } } // namespace analysis diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc index 45b5a7638b7dc6a54bbd905766fd5c284cb6aea1..5967402055c61dd480055497640a16ab7e94a746 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/operators/tensorrt_engine_op.h" @@ -32,7 +33,9 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor { bool Init(const std::shared_ptr& parent_scope) { VLOG(3) << "Predictor::init()"; - + FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true; + FLAGS_tensorrt_max_batch_size = config_.max_batch_size; + FLAGS_tensorrt_workspace_size = config_.workspace_size; if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); } else { @@ -150,3 +153,13 @@ CreatePaddlePredictor( } } // namespace paddle + +USE_TRT_CONVERTER(elementwise_add_weight); +USE_TRT_CONVERTER(mul); +USE_TRT_CONVERTER(conv2d); +USE_TRT_CONVERTER(relu); +USE_TRT_CONVERTER(fc); +USE_TRT_CONVERTER(pool2d); +USE_TRT_CONVERTER(softmax); +USE_TRT_CONVERTER(batch_norm); +USE_TRT_CONVERTER(concat); diff --git a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc index fcbf9b89d608e7961e3ef81ac1c70e083dae1cc0..b3892b5dd89fef055f87f9170a0e1f24cc93d0a0 100644 --- a/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc +++ b/paddle/fluid/inference/api/api_tensorrt_subgraph_engine_tester.cc @@ -37,6 +37,7 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) { config1.use_gpu = true; config1.fraction_of_gpu_memory = 0.3; config1.device = 0; + config1.max_batch_size = 10; auto predictor0 = CreatePaddlePredictor(config0); diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 794534467be066e91db2b4c204913ab2cf12dbfd..da6c2cfc21809aa08bba79880dc898bb056d16a0 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -137,6 
+137,14 @@ struct AnakinConfig : public PaddlePredictor::Config { struct TensorRTConfig : public NativeConfig { // Determine whether a subgraph will be executed by TRT. int min_subgraph_size{1}; + // While TensorRT allows an engine optimized for a given max batch size + // to run at any smaller size, the performance for those smaller + // sizes may not be as well-optimized. Therefore, the max batch size is best + // set equal to the runtime batch size. + int max_batch_size{1}; + // For workspace_size, refer to: + // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting + int workspace_size{1 << 30}; }; // A factory to help create different predictors. diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 6863b035d8cd9dfb21aed3947226a796778912a4..9d7be2d03cf7bb12afe7e52d9630f184d689dc25 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,7 +1,7 @@ # Add TRT tests nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc -activation_op.cc softmax_op.cc +batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc DEPS tensorrt_engine operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS @@ -18,9 +18,12 @@ nv_test(test_trt_conv_op SRCS test_conv2d_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine conv_op SERIAL) nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine pool_op SERIAL) - nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL) - nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL) +nv_test(test_trt_batch_norm_op SRCS test_batch_norm_op.cc batch_norm_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine batch_norm_op SERIAL) + +nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc + DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL) diff --git a/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..94f8b0ae5606d39a722ffe28501645c9b6fc5d2e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc @@ -0,0 +1,136 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
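[Note on the new TensorRTConfig fields] The two fields above are consumed by TensorRTSubgraphPredictor::Init() earlier in this patch, which copies them into FLAGS_tensorrt_max_batch_size and FLAGS_tensorrt_workspace_size before the analyzer runs. A minimal usage sketch, assuming the kAutoMixedTensorRT engine kind used by the subgraph predictor and a hypothetical FeedInputs() helper that is not part of this patch:

#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Hypothetical helper that fills feed tensors; not part of this patch.
std::vector<paddle::PaddleTensor> FeedInputs();

void RunWithTensorRTSubgraph(const std::string& model_dir) {
  paddle::TensorRTConfig config;
  config.model_dir = model_dir;
  config.use_gpu = true;
  config.device = 0;
  config.fraction_of_gpu_memory = 0.3;
  config.max_batch_size = 8;        // best kept equal to the runtime batch size
  config.workspace_size = 1 << 30;  // scratch space TRT may use while building

  auto predictor = paddle::CreatePaddlePredictor<
      paddle::TensorRTConfig, paddle::PaddleEngineKind::kAutoMixedTensorRT>(
      config);

  std::vector<paddle::PaddleTensor> outputs;
  predictor->Run(FeedInputs(), &outputs);
}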
*/ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class BatchNormOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + LOG(INFO) << "convert a fluid batch norm op to tensorrt batch_norm"; + + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1); + PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight + PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight + PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight + PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(), + 1); // Variance is a weight + PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1); + + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + // Declare weights + auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front()); + auto* Mean_v = scope.FindVar(op_desc.Input("Mean").front()); + auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front()); + auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front()); + const float eps = boost::get(op_desc.GetAttr("epsilon")); + + PADDLE_ENFORCE_NOT_NULL(Bias_v); + PADDLE_ENFORCE_NOT_NULL(Mean_v); + PADDLE_ENFORCE_NOT_NULL(Scale_v); + PADDLE_ENFORCE_NOT_NULL(Variance_v); + + // get tensor + auto* Bias_t = Bias_v->GetMutable(); + auto* Mean_t = Mean_v->GetMutable(); + auto* Scale_t = Scale_v->GetMutable(); + auto* Variance_t = Variance_v->GetMutable(); + + // create temp tensor for weights + framework::LoDTensor bias_tensor; + framework::LoDTensor mean_tensor; + framework::LoDTensor scale_tensor; + framework::LoDTensor variance_tensor; + + bias_tensor.Resize(Bias_t->dims()); + mean_tensor.Resize(Mean_t->dims()); + scale_tensor.Resize(Scale_t->dims()); + variance_tensor.Resize(Variance_t->dims()); + + platform::CPUPlace cpu_place; + // copy data from gpu to cpu + TensorCopySync((*Bias_t), cpu_place, &bias_tensor); + TensorCopySync((*Mean_t), cpu_place, &mean_tensor); + TensorCopySync((*Scale_t), cpu_place, &scale_tensor); + TensorCopySync((*Variance_t), cpu_place, &variance_tensor); + + auto* bias_data = bias_tensor.mutable_data(platform::CPUPlace()); + auto* mean_data = mean_tensor.mutable_data(platform::CPUPlace()); + auto* scale_data = scale_tensor.mutable_data(platform::CPUPlace()); + auto* variance_data = + variance_tensor.mutable_data(platform::CPUPlace()); + + std::unique_ptr combile_scale_tensor( + new framework::LoDTensor()); + std::unique_ptr combile_bias_tensor( + new framework::LoDTensor()); + + combile_scale_tensor->Resize(scale_tensor.dims()); + combile_bias_tensor->Resize(bias_tensor.dims()); + + auto* combile_scale_data = + combile_scale_tensor->mutable_data(platform::CPUPlace()); + auto* combile_bias_data = + combile_bias_tensor->mutable_data(platform::CPUPlace()); + + size_t ele_num = combile_scale_tensor->memory_size() / sizeof(float); + + for (size_t i = 0; i < ele_num; i++) { + float scale = scale_data[i]; + float bias = bias_data[i]; + float mean = mean_data[i]; + float variance = variance_data[i]; + combile_scale_data[i] = scale / sqrtf(variance + eps); + combile_bias_data[i] = bias - mean * combile_scale_data[i]; + } + + TensorRTEngine::Weight scale_weights{ + nvinfer1::DataType::kFLOAT, static_cast(combile_scale_data), + combile_scale_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight shift_weights{ + nvinfer1::DataType::kFLOAT, 
static_cast(combile_bias_data), + combile_bias_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + + nvinfer1::IScaleLayer* layer = + TRT_ENGINE_ADD_LAYER(engine_, Scale, *const_cast(X), + nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(), + scale_weights.get(), power_weights.get()); + + auto output_name = op_desc.Output("Y").front(); + engine_->weight_map[op_desc.Input("Bias").front()] = + std::move(combile_bias_tensor); + engine_->weight_map[op_desc.Input("Scale").front()] = + std::move(combile_scale_tensor); + + engine_->SetITensor(output_name, layer->getOutput(0)); + + if (test_mode) { + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(batch_norm, BatchNormOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/concat_op.cc b/paddle/fluid/inference/tensorrt/convert/concat_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..bb9627bf957b63993b2c8d23e7ec8122eb004eaf --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/concat_op.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * ConcatOp, IConcatenationLayer in TRT. This Layer doesn't have weights. + */ +class ConcatOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid concat op to tensorrt concat layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + std::vector itensors; + for (auto& input_name : op_desc.Input("X")) { + itensors.push_back(engine_->GetITensor(input_name)); + } + int axis = boost::get(op_desc.GetAttr("axis")); + PADDLE_ENFORCE(axis > 0, + "The axis attr of Concat op should be larger than 0 for trt"); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, itensors.data(), + itensors.size()); + axis = axis - 1; // Remove batch dim + layer->setAxis(axis); + auto output_name = op_desc.Output("Out")[0]; + engine_->SetITensor(output_name, layer->getOutput(0)); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside.
+ engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(concat, ConcatOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index dba1d50b2d1c487ced8e6ca51f2d257641ad5fc7..841a95db38ce7cf0cb5961ff04cb569ee2633e6f 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -35,12 +35,20 @@ class Conv2dOpConverter : public OpConverter { auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); auto* Y_t = Y_v->GetMutable(); - auto* weight_data = Y_t->mutable_data(platform::CPUPlace()); - PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL); - const int n_output = Y_t->dims()[0]; - const int filter_h = Y_t->dims()[2]; - const int filter_w = Y_t->dims()[3]; + platform::CPUPlace cpu_place; + std::unique_ptr weight_tensor( + new framework::LoDTensor()); + weight_tensor->Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); + + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + + PADDLE_ENFORCE_EQ(weight_tensor->dims().size(), 4UL); + const int n_output = weight_tensor->dims()[0]; + const int filter_h = weight_tensor->dims()[2]; + const int filter_w = weight_tensor->dims()[3]; const int groups = boost::get(op_desc.GetAttr("groups")); const std::vector dilations = @@ -57,7 +65,7 @@ class Conv2dOpConverter : public OpConverter { TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), - Y_t->memory_size() / sizeof(float)}; + weight_tensor->memory_size() / sizeof(float)}; TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; auto* layer = TRT_ENGINE_ADD_LAYER( @@ -70,6 +78,8 @@ class Conv2dOpConverter : public OpConverter { layer->setNbGroups(groups); auto output_name = op_desc.Output("Output").front(); + engine_->weight_map[op_desc.Input("Filter").front()] = + std::move(weight_tensor); engine_->SetITensor(output_name, layer->getOutput(0)); if (test_mode) { engine_->DeclareOutput(output_name); diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 3744550f60a1696aedd8a3ecd24f1b21d22325b9..60a72b4eb5c75b5cd12305f13763a9a1a567213f 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" namespace paddle { @@ -40,10 +39,17 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); auto* Y_t = Y_v->GetMutable(); - auto* weight_data = Y_t->mutable_data(platform::CPUPlace()); + + platform::CPUPlace cpu_place; + std::unique_ptr weight_tensor( + new framework::LoDTensor()); + weight_tensor->Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, weight_tensor.get()); + auto* weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - std::vector dims_y = framework::vectorize2int(Y_t->dims()); + std::vector dims_y = framework::vectorize2int(weight_tensor->dims()); if (static_cast(dims_y.size()) == dims_x.nbDims + 1) { if (dims_y[0] == 1) dims_y.erase(dims_y.begin()); } @@ -70,9 +76,9 @@ class ElementwiseWeightOpConverter : public OpConverter { PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!"); } - TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - Y_t->memory_size() / sizeof(float)}; + TensorRTEngine::Weight shift_weights{ + nvinfer1::DataType::kFLOAT, static_cast(weight_data), + weight_tensor->memory_size() / sizeof(float)}; TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, @@ -82,6 +88,8 @@ class ElementwiseWeightOpConverter : public OpConverter { engine_, Scale, *const_cast(X), scale_mode, shift_weights.get(), scale_weights.get(), power_weights.get()); auto output_name = op_desc.Output("Out")[0]; + + engine_->weight_map[op_desc.Input("Y").front()] = std::move(weight_tensor); engine_->SetITensor(output_name, layer->getOutput(0)); if (test_mode) { // the test framework can not determine which is the // output, so place the declaration inside. diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 39fe1f609d7b94638506877fc301f19ef33ec8ac..ad98d85aae9cf594922aca00c43718ccfbce2278 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -12,12 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/platform/place.h" namespace paddle { namespace inference { @@ -73,19 +68,26 @@ class FcOpConverter : public OpConverter { auto* Y_t = Y_v->GetMutable(); // This may trigger a GPU->CPU copy, because TRT's weight can only be // assigned from CPU memory, that can't be avoided. 
- auto* weight_data = Y_t->mutable_data(platform::CPUPlace()); - PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix - size_t n_output = Y_t->dims()[1]; + platform::CPUPlace cpu_place; + framework::LoDTensor weight_tensor; + weight_tensor.Resize(Y_t->dims()); + TensorCopySync((*Y_t), cpu_place, &weight_tensor); - framework::LoDTensor tmp; - tmp.Resize(Y_t->dims()); - memcpy(tmp.mutable_data(platform::CPUPlace()), weight_data, + auto* weight_data = weight_tensor.mutable_data(platform::CPUPlace()); + + PADDLE_ENFORCE_EQ(weight_tensor.dims().size(), 2UL); // a matrix + size_t n_output = weight_tensor.dims()[1]; + + std::unique_ptr tmp(new framework::LoDTensor()); + tmp->Resize(weight_tensor.dims()); + + memcpy(tmp->mutable_data(platform::CPUPlace()), weight_data, Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float)); TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), Y_t->memory_size() / sizeof(float)}; TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT, - static_cast(tmp.data()), + static_cast(tmp->data()), Y_t->memory_size() / sizeof(float)); weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]}); tmp_weight.dims = weight.dims; @@ -106,6 +108,7 @@ class FcOpConverter : public OpConverter { auto output_name = op_desc.Output("Out").front(); engine_->SetITensor(output_name, layer->getOutput(0)); + engine_->weight_map[op_desc.Input("Y").front()] = std::move(tmp); if (test_mode) { engine_->DeclareOutput(output_name); } diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 41faaf7212accaaec238062b1340e8da8fa6be33..d309d94c560f2b484fac6b6cd40cc2704d641069 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -79,6 +79,14 @@ class OpConverter { it = Registry::Lookup("elementwise_" + op_type + "_tensor"); } + PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", + op_desc.Type()); + } + + if (op_desc.Type() == "depthwise_conv2d") { + it = Registry::Lookup("conv2d"); + PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", + op_desc.Type()); } if (!it) { diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 11cad95361867476c6f775af778015da37f1cfb1..73f1b28ddf73403862e55d102a259d7b6cf67b1f 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + bool global_pooling = boost::get(op_desc.GetAttr("global_pooling")); std::string pool_type = boost::get(op_desc.GetAttr("pooling_type")); std::vector ksize = @@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter { std::vector paddings = boost::get>(op_desc.GetAttr("paddings")); - const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + if (global_pooling == true) { + nvinfer1::Dims input_shape = input1->getDimensions(); + int nbDims = input_shape.nbDims; + nv_ksize.d[0] = input_shape.d[nbDims - 2]; + nv_ksize.d[1] = input_shape.d[nbDims - 1]; + } const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); diff --git a/paddle/fluid/inference/tensorrt/convert/test_batch_norm_op.cc 
b/paddle/fluid/inference/tensorrt/convert/test_batch_norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..41412cb079540da72760558379b158b6538aa6a8 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_batch_norm_op.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(batch_norm_op, test) { + std::unordered_set parameters( + {"batch_norm_scale", "batch_norm_bias", "batch_norm_mean", + "batch_norm_variance"}); + framework::Scope scope; + TRTConvertValidation validator(5, parameters, scope, 1 << 15); + std::vector param_shape{2}; + + validator.DeclInputVar("batch_norm_X", nvinfer1::DimsCHW(2, 5, 5)); + validator.DeclParamVar("batch_norm_scale", param_shape); + validator.DeclParamVar("batch_norm_bias", param_shape); + validator.DeclParamVar("batch_norm_mean", param_shape); + validator.DeclParamVar("batch_norm_variance", param_shape); + validator.DeclOutputVar("batch_norm_Y", nvinfer1::DimsCHW(2, 5, 5)); + validator.DeclOutputVar("batch_norm_save_mean", param_shape); + validator.DeclOutputVar("batch_norm_save_variance", param_shape); + + // Prepare Op description + framework::OpDesc desc; + + desc.SetType("batch_norm"); + desc.SetInput("X", {"batch_norm_X"}); + desc.SetInput("Scale", {"batch_norm_scale"}); + desc.SetInput("Bias", {"batch_norm_bias"}); + desc.SetInput("Mean", {"batch_norm_mean"}); + desc.SetInput("Variance", {"batch_norm_variance"}); + desc.SetOutput("Y", {"batch_norm_Y"}); + desc.SetOutput("MeanOut", {"batch_norm_mean"}); + desc.SetOutput("VarianceOut", {"batch_norm_variance"}); + desc.SetOutput("SavedMean", {"batch_norm_save_mean"}); + desc.SetOutput("SavedVariance", {"batch_norm_save_variance"}); + + float eps = 1e-5f; + bool is_test = true; + desc.SetAttr("epsilon", eps); + desc.SetAttr("is_test", is_test); + + validator.SetOp(*desc.Proto()); + + std::unordered_set neglected_output = { + "batch_norm_save_mean", "batch_norm_save_variance", "batch_norm_mean", + "batch_norm_variance"}; + validator.Execute(3, neglected_output); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle +USE_OP(batch_norm); diff --git a/paddle/fluid/inference/tensorrt/convert/test_concat_op.cc b/paddle/fluid/inference/tensorrt/convert/test_concat_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f284a4db5758e072915d7fd0f16115b8a36ba8b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_concat_op.cc @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
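[Note on the batch-norm fold] The converter exercised by the test above works because inference-time batch norm is a per-channel affine transform: y = scale * (x - mean) / sqrt(variance + eps) + bias, which rearranges to y = combined_scale * x + combined_bias with combined_scale = scale / sqrt(variance + eps) and combined_bias = bias - mean * combined_scale. A standalone sketch of the fold, using plain vectors instead of LoDTensors:

#include <cmath>
#include <vector>

// Fold per-channel batch-norm parameters into one scale/shift pair, mirroring
// the loop in BatchNormOpConverter; the result feeds a kCHANNEL IScaleLayer.
void FoldBatchNorm(const std::vector<float>& scale,
                   const std::vector<float>& bias,
                   const std::vector<float>& mean,
                   const std::vector<float>& variance, float eps,
                   std::vector<float>* combined_scale,
                   std::vector<float>* combined_bias) {
  combined_scale->resize(scale.size());
  combined_bias->resize(scale.size());
  for (size_t i = 0; i < scale.size(); ++i) {
    (*combined_scale)[i] = scale[i] / std::sqrt(variance[i] + eps);
    (*combined_bias)[i] = bias[i] - mean[i] * (*combined_scale)[i];
  }
}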
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(concat_op, test) { + std::unordered_set parameters({""}); + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("concat_x1", nvinfer1::DimsCHW(10, 3, 1)); + validator.DeclInputVar("concat_x2", nvinfer1::DimsCHW(3, 3, 1)); + validator.DeclInputVar("concat_x3", nvinfer1::DimsCHW(7, 3, 1)); + validator.DeclOutputVar("concat_out", nvinfer1::DimsCHW(20, 3, 1)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("concat"); + desc.SetInput("X", {"concat_x1", "concat_x2", "concat_x3"}); + desc.SetOutput("Out", {"concat_out"}); + + int axis = 1; + desc.SetAttr("axis", axis); + + validator.SetOp(*desc.Proto()); + + validator.Execute(5); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle +USE_OP(concat); diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc index d6651a5b244ba31a01220e6299cb2016ae61fe64..01d7f700da9cc67d0ebbd3d9649e3823f58a8811 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc @@ -57,6 +57,7 @@ TEST(OpConverter, ConvertBlock) { auto* x = scope.Var("conv2d-Y"); auto* x_tensor = x->GetMutable(); x_tensor->Resize(framework::make_ddim(dim_vec)); + x_tensor->mutable_data(platform::CUDAPlace(0)); OpConverter converter; converter.ConvertBlock(*block->Proto(), {"conv2d-Y"}, scope, diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index c5dddbc8cd37b9fb1ba39382af2da5ad045f3af2..aedd6b62df040eeee4e48f628128511cd8bf4439 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -20,7 +20,7 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(Pool2dOpConverter, main) { +void test_pool2d(bool global_pooling) { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(5, parameters, scope, 1 << 15); @@ -28,7 +28,10 @@ TEST(Pool2dOpConverter, main) { // The ITensor's Dims should not contain the batch size. // So, the ITensor's Dims of input and output should be C * H * W. 
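[Note on axis handling] That batch-free C * H * W layout is also why the concat converter earlier in this patch subtracts one from the fluid `axis` attribute before calling setAxis. A small sketch of the resulting shape arithmetic, matching the concat test above, where (10,3,1), (3,3,1) and (7,3,1) concatenated along fluid axis 1 give (20,3,1):

#include <cassert>
#include <vector>

// Compute the batch-free output shape of a TRT concatenation, mirroring what
// ConcatOpConverter asks TensorRT to do.
std::vector<int> ConcatOutputShape(const std::vector<std::vector<int>>& inputs,
                                   int fluid_axis) {
  assert(fluid_axis > 0);               // concat along the batch dim is rejected
  const int trt_axis = fluid_axis - 1;  // ITensor dims do not include the batch
  std::vector<int> out = inputs.front();
  for (size_t i = 1; i < inputs.size(); ++i) {
    out[trt_axis] += inputs[i][trt_axis];
  }
  return out;  // {10,3,1}, {3,3,1}, {7,3,1} with fluid axis 1 -> {20,3,1}
}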
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); + if (global_pooling) + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); + else + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); // Prepare Op description framework::OpDesc desc; @@ -45,6 +48,7 @@ TEST(Pool2dOpConverter, main) { desc.SetAttr("ksize", ksize); desc.SetAttr("strides", strides); desc.SetAttr("paddings", paddings); + desc.SetAttr("global_pooling", global_pooling); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); @@ -53,6 +57,10 @@ TEST(Pool2dOpConverter, main) { validator.Execute(3); } +TEST(Pool2dOpConverter, normal) { test_pool2d(false); } + +TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); } + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 4265f33f28fe36b1745baf4761c3c85e3a281d6b..0a6f171fc40a838fd81d6a51aca0430d5526f188 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" @@ -48,11 +49,17 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place, auto dims = tensor->dims(); size_t num_elements = analysis::AccuDims(dims, dims.size()); PADDLE_ENFORCE_GT(num_elements, 0); - auto* data = tensor->mutable_data(place); + + platform::CPUPlace cpu_place; + framework::LoDTensor temp_tensor; + temp_tensor.Resize(dims); + auto* temp_data = temp_tensor.mutable_data(cpu_place); for (size_t i = 0; i < num_elements; i++) { - *(data + i) = random(0., 1.); + *(temp_data + i) = random(0., 1.); } + + TensorCopySync(temp_tensor, place, tensor); } /* @@ -91,18 +98,26 @@ class TRTConvertValidation { engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims); } + void DeclParamVar(const std::string& name, const std::vector dim_vec) { + DeclVar(name, dim_vec); + } + // Declare a parameter varaible in the scope. void DeclParamVar(const std::string& name, const nvinfer1::Dims& dims) { DeclVar(name, dims, true); } + void DeclOutputVar(const std::string& name, const std::vector dim_vec) { + DeclVar(name, dim_vec); + } + void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) { DeclVar(name, dims); } void DeclVar(const std::string& name, const std::vector dim_vec) { - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); auto* x = scope_.Var(name); auto* x_tensor = x->GetMutable(); @@ -141,18 +156,22 @@ class TRTConvertValidation { PADDLE_ENFORCE(var); auto tensor = var->GetMutable(); - engine_->SetInputFromCPU( + engine_->SetInputFromGPU( input, static_cast(tensor->data()), sizeof(float) * analysis::AccuDims(tensor->dims(), tensor->dims().size())); } } - void Execute(int batch_size) { + // We use the set 'neglected_output' here, because some Ops like batch norm, + // the outputs specified in the op des are only used during training, + // so we should neglect those output during inference. 
+ void Execute(int batch_size, + std::unordered_set neglected_output = {}) { // Execute Fluid Op PADDLE_ENFORCE_LE(batch_size, max_batch_size_); - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); op_->Run(scope_, place); // Execute TRT. engine_->Execute(batch_size); @@ -161,6 +180,7 @@ class TRTConvertValidation { ASSERT_FALSE(op_desc_->OutputArgumentNames().empty()); const size_t output_space_size = 3000; for (const auto& output : op_desc_->OutputArgumentNames()) { + if (neglected_output.count(output)) continue; std::vector fluid_out; std::vector trt_out(output_space_size); engine_->GetOutputInCPU(output, &trt_out[0], output_space_size); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index b821c3d0bf425c46fae634fbf53f7ee63100ca5c..14e9e14d33d637ee68e37593cc48721e5169499f 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -33,6 +33,7 @@ void TensorRTEngine::Build(const DescType &paddle_model) { } void TensorRTEngine::Execute(int batch_size) { + freshDeviceId(); batch_size_ = batch_size; std::vector buffers; for (auto &buf : buffers_) { @@ -60,6 +61,7 @@ TensorRTEngine::~TensorRTEngine() { } void TensorRTEngine::FreezeNetwork() { + freshDeviceId(); PADDLE_ENFORCE(infer_builder_ != nullptr, "Call InitNetwork first to initialize network."); PADDLE_ENFORCE(infer_network_ != nullptr, @@ -241,6 +243,13 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } +void TensorRTEngine::freshDeviceId() { + int count; + cudaGetDeviceCount(&count); + PADDLE_ENFORCE_LT(device_, count); + cudaSetDevice(device_); +} + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 694468c419c20089de1cdecff1a903ad0cc6e99f..bd3ba4cea6551a7f6651e311e2649de191a6faa1 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/inference/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/utils/singleton.h" @@ -52,13 +53,15 @@ class TensorRTEngine : public EngineBase { }; TensorRTEngine(int max_batch, int max_workspace, - cudaStream_t* stream = nullptr, + cudaStream_t* stream = nullptr, int device = 0, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), stream_(stream ? stream : &default_stream_), - logger_(logger) { - cudaStreamCreate(&default_stream_); + logger_(logger), + device_(device) { + freshDeviceId(); + cudaStreamCreate(stream_); } virtual ~TensorRTEngine(); @@ -119,6 +122,15 @@ class TensorRTEngine : public EngineBase { nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } void SetRuntimeBatch(size_t batch_size); int GetRuntimeBatch(); + int GetDevice() { return device_; } + + // A pointer to CPU memory is needed of the TRT weight. + // Before TRT runs, fluid loads weight into GPU storage. + // so we need to copy the weights from GPU to CPU in our op converter. + // We use a map to store these weights for the weight memory is not released + // in advance, which affecting the construction of TRT Op. 
+ std::unordered_map> + weight_map; private: // the max batch size @@ -140,6 +152,8 @@ class TensorRTEngine : public EngineBase { std::unordered_map buffer_sizes_; std::unordered_map itensor_map_; + // The specific GPU id that the TensorRTEngine bounded to. + int device_; // TensorRT related internal members template @@ -156,6 +170,10 @@ class TensorRTEngine : public EngineBase { infer_ptr infer_network_; infer_ptr infer_engine_; infer_ptr infer_context_; + // Each ICudaEngine object is bound to a specific GPU when it is instantiated, + // ensure that the thread is associated with the correct device by calling + // freshDeviceId(). + void freshDeviceId(); }; // class TensorRTEngine // Add an layer__ into engine__ with args ARGS. @@ -188,8 +206,8 @@ class TRT_EngineManager { // Create or get an engine called `name` TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream, - const std::string& name) { - auto* p = new TensorRTEngine(max_batch, max_workspace, stream); + const std::string& name, int gpu_device = 0) { + auto* p = new TensorRTEngine(max_batch, max_workspace, stream, gpu_device); engines_[name].reset(p); return p; } diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc index dc03702990587bf5e65d28da662d10df4d882110..da1f6535cb3b2476cd475797861d6d2bb6d88856 100644 --- a/paddle/fluid/inference/tensorrt/test_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_engine.cc @@ -27,7 +27,7 @@ namespace tensorrt { class TensorRTEngineTest : public ::testing::Test { protected: void SetUp() override { - ASSERT_EQ(0, cudaStreamCreate(&stream_)); + // ASSERT_EQ(0, cudaStreamCreate(&stream_)); engine_ = new TensorRTEngine(10, 1 << 10, &stream_); engine_->InitNetwork(); } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e8b5dec9d49f5613cec92441d19ab7dc1a1ad90c..e29fe2a42bd1aaee1ea8c01159e331cf47ca6b72 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -100,7 +100,8 @@ function(op_library TARGET) endif() # Define operators that don't need pybind here. 
- foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op") + foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" +"tensor_array_read_write_op" "tensorrt_engine_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() @@ -248,6 +249,7 @@ op_library(softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax) if (WITH_GPU AND TENSORRT_FOUND) op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(tensorrt_engine);\n") nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc DEPS tensorrt_engine_op analysis) diff --git a/paddle/fluid/operators/tensorrt_engine_op.cc b/paddle/fluid/operators/tensorrt_engine_op.cc index ee3078876c15b06a887064f08dc0c05d450b5f77..1048d3017140c9e31426a1580b2862667116a024 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.cc +++ b/paddle/fluid/operators/tensorrt_engine_op.cc @@ -17,112 +17,16 @@ #include #include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/operators/tensorrt_engine_op.h" namespace paddle { DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT"); +DEFINE_int32(tensorrt_max_batch_size, 1, "TensorRT maximum batch size"); +DEFINE_int32(tensorrt_workspace_size, 16 << 20, "TensorRT workspace size"); namespace operators { -using inference::Singleton; -using inference::tensorrt::TRT_EngineManager; - -using FluidDT = framework::proto::VarType_Type; -using TRT_DT = nvinfer1::DataType; - -namespace { - -TRT_DT FluidDataType2TRT(FluidDT type) { - switch (type) { - case FluidDT::VarType_Type_FP32: - return TRT_DT::kFLOAT; - case FluidDT::VarType_Type_INT32: - return TRT_DT::kINT32; - default: - return TRT_DT::kINT32; - } - PADDLE_THROW("unkown type"); - return TRT_DT::kINT32; -} - -nvinfer1::Dims Vec2TRT_Dims(const std::vector &shape) { - PADDLE_ENFORCE_GT(shape.size(), 1UL, - "TensorRT' tensor input requires at least 2 dimensions"); - PADDLE_ENFORCE_LE(shape.size(), 4UL, - "TensorRT' tensor input requires at most 4 dimensions"); - PADDLE_ENFORCE_EQ(shape.size(), 4UL); - return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); -} - -} // namespace - -template -void TensorRTEngineKernel::Prepare( - const framework::ExecutionContext &context) const { - VLOG(4) << "Prepare engine"; - // Get the ProgramDesc and pass to convert. 
- framework::proto::BlockDesc block_desc; - block_desc.ParseFromString(context.Attr("subgraph")); - int max_batch = context.Attr("max_batch"); - auto max_workspace = context.Attr("max_workspace"); - auto params = context.Attr>("parameters"); - std::unordered_set parameters; - for (const auto ¶m : params) { - parameters.insert(param); - } - - std::vector output_maps = - context.Attr>("output_name_mapping"); - - // TODO(Superjomn) replace this with a different stream - auto *engine = Singleton::Global().Create( - max_batch, max_workspace, nullptr /*engine hold its own stream*/, - context.Attr("engine_uniq_key")); - engine->InitNetwork(); - - framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); - VLOG(4) << "parsed var size " << block.AllVars().size(); - // Add inputs - VLOG(4) << "declare inputs"; - for (auto &input : context.Inputs("Xs")) { - if (parameters.count(input)) continue; - VLOG(4) << "declare input " << input; - auto *var = block.FindVar(input); - // TensorRT engine need to create parameters. The parameter's description - // should be set in - PADDLE_ENFORCE(var, "no variable called %s", input); - PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, - "TensorRT engine only takes LoDTensor as input"); - auto shape = var->GetShape(); - // For the special batch_size placeholder -1, drop it and pass the real - // shape of data. - // TODO(Superjomn) fix this with batch broadcast, or it can't handle - // variational batch size. - if (shape[0] == -1) { - shape[0] = FLAGS_tensorrt_engine_batch_size; - } - engine->DeclareInput( - input, FluidDataType2TRT( - var->Proto()->type().lod_tensor().tensor().data_type()), - Vec2TRT_Dims(shape)); - } - - inference::Singleton::Global().ConvertBlock( - block_desc, parameters, context.scope(), engine); - - // Add outputs - for (auto &output : output_maps) { - engine->DeclareOutput(output); - } - - engine->FreezeNetwork(); -} - class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -130,8 +34,6 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Ys", "A list of outputs").AsDuplicable(); AddAttr("subgraph", "the subgraph."); AddAttr("engine_uniq_key", "unique key for the TRT engine."); - AddAttr("max_batch", "the maximum batch size."); - AddAttr("max_workspace", "the maximum batch size."); AddComment("TensorRT engine operator."); } }; @@ -150,11 +52,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(tensorrt_engine, ops::TensorRTEngineOp, ops::TensorRTEngineOpMaker, ops::TensorRTEngineOpMaker); -REGISTER_OP_CPU_KERNEL( - tensorrt_engine, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel, - ops::TensorRTEngineKernel); - #endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/tensorrt_engine_op.cu.cc b/paddle/fluid/operators/tensorrt_engine_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..e1ddfde6d51ef719ca0b89cf286b176195ee682a --- /dev/null +++ b/paddle/fluid/operators/tensorrt_engine_op.cu.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/tensorrt_engine_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + tensorrt_engine, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel, + ops::TensorRTEngineKernel); diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index 2cbe1213a2f428a3ce56b06f97636baeb4b66c26..bc556ab3643cefa3e45d2a8a3835937753af723f 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -19,16 +19,51 @@ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/inference/analysis/helper.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { DECLARE_int32(tensorrt_engine_batch_size); +DECLARE_int32(tensorrt_max_batch_size); +DECLARE_int32(tensorrt_workspace_size); namespace operators { +using FluidDT = framework::proto::VarType_Type; +using TRT_DT = nvinfer1::DataType; + +namespace { + +TRT_DT FluidDataType2TRT(FluidDT type) { + switch (type) { + case FluidDT::VarType_Type_FP32: + return TRT_DT::kFLOAT; + case FluidDT::VarType_Type_INT32: + return TRT_DT::kINT32; + default: + return TRT_DT::kINT32; + } + PADDLE_THROW("unkown type"); + return TRT_DT::kINT32; +} + +nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { + PADDLE_ENFORCE_GT(shape.size(), 1UL, + "TensorRT' tensor input requires at least 2 dimensions"); + PADDLE_ENFORCE_LE(shape.size(), 4UL, + "TensorRT' tensor input requires at most 4 dimensions"); + PADDLE_ENFORCE(shape.size() == 4UL || shape.size() == 2UL); + if (shape.size() == 4UL) + return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); + return nvinfer1::DimsCHW(shape[1], 1, 1); +} + +} // namespace + using inference::Singleton; using inference::tensorrt::TRT_EngineManager; @@ -47,7 +82,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel { .FindVar(input0) ->GetMutable() ->type()), - platform::CPUPlace()); + ctx.GetPlace()); return kt; } }; @@ -64,7 +99,7 @@ class TensorRTEngineKernel : public framework::OpKernel { auto input_names = context.op().Inputs("Xs"); PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, - context.Attr("max_batch")); + FLAGS_tensorrt_max_batch_size); std::vector output_maps = context.Attr>("output_name_mapping"); @@ -94,12 +129,19 @@ class TensorRTEngineKernel : public framework::OpKernel { // Convert output tensor from engine to fluid int output_index = 0; + VLOG(4) << "TensorRT Engine Op Outputs:"; for (const auto& y : context.Outputs("Ys")) { + VLOG(4) << y; // convert output and copy to fluid. nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]); auto dims = trt_t->getDimensions(); // Use the output ITensor's dims to reshape the Fluid Tensor. - std::vector ddim(dims.d, dims.d + dims.nbDims); + // The ITensor doesn't contain the batch size dim. 
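[Note on output shapes] Because the ITensor dims exclude the batch dimension, the kernel below rebuilds the fluid output shape by prepending the runtime batch size to whatever TensorRT reports. A standalone sketch of that reconstruction:

#include <cstdint>
#include <vector>

// Rebuild a fluid tensor shape from TensorRT output dims by prepending the
// runtime batch size, as the loop below does with FLAGS_tensorrt_engine_batch_size.
std::vector<int64_t> TrtDimsToFluidShape(const std::vector<int>& trt_dims,
                                         int runtime_batch) {
  std::vector<int64_t> shape;
  shape.reserve(trt_dims.size() + 1);
  shape.push_back(runtime_batch);
  for (int d : trt_dims) {
    shape.push_back(d);
  }
  return shape;  // e.g. TRT dims {6} with batch size 2 -> fluid shape {2, 6}
}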
+ std::vector ddim; + ddim.push_back(FLAGS_tensorrt_engine_batch_size); + for (int i = 0; i < dims.nbDims; i++) { + ddim.push_back(dims.d[i]); + } auto* fluid_v = context.scope().FindVar(y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); @@ -113,9 +155,11 @@ class TensorRTEngineKernel : public framework::OpKernel { // TODO(Superjomn) change this float to dtype size. auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) * FLAGS_tensorrt_engine_batch_size; - engine->GetOutputInCPU(output_maps[output_index], - fluid_t->mutable_data(platform::CPUPlace()), - size * sizeof(float)); + engine->GetOutputInGPU( + output_maps[output_index], + fluid_t->mutable_data(platform::CUDAPlace( + boost::get(context.GetPlace()).device)), + size * sizeof(float)); //} else { // engine->GetOutputInGPU( // y, fluid_t->mutable_data(platform::CUDAPlace()), @@ -128,8 +172,67 @@ class TensorRTEngineKernel : public framework::OpKernel { } protected: - // Build the engine. - void Prepare(const framework::ExecutionContext& context) const; + void Prepare(const framework::ExecutionContext& context) const { + VLOG(4) << "Prepare engine"; + // Get the ProgramDesc and pass to convert. + framework::proto::BlockDesc block_desc; + block_desc.ParseFromString(context.Attr("subgraph")); + int max_batch = FLAGS_tensorrt_max_batch_size; + auto max_workspace = FLAGS_tensorrt_workspace_size; + auto params = context.Attr>("parameters"); + std::unordered_set parameters; + for (const auto& param : params) { + parameters.insert(param); + } + + std::vector output_maps = + context.Attr>("output_name_mapping"); + + // TODO(Superjomn) replace this with a different stream + auto* engine = Singleton::Global().Create( + max_batch, max_workspace, nullptr /*engine hold its own stream*/, + context.Attr("engine_uniq_key"), + boost::get(context.GetPlace()).device); + + engine->InitNetwork(); + + framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); + VLOG(4) << "parsed var size " << block.AllVars().size(); + // Add inputs + VLOG(4) << "declare inputs"; + for (auto& input : context.Inputs("Xs")) { + if (parameters.count(input)) continue; + VLOG(4) << "declare input " << input; + auto* var = block.FindVar(input); + // TensorRT engine need to create parameters. The parameter's description + // should be set in + PADDLE_ENFORCE(var, "no variable called %s", input); + PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, + "TensorRT engine only takes LoDTensor as input"); + auto shape = var->GetShape(); + // For the special batch_size placeholder -1, drop it and pass the real + // shape of data. + // TODO(Superjomn) fix this with batch broadcast, or it can't handle + // variational batch size. 
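[Note on input shapes] On the input side, the -1 batch placeholder mentioned above is replaced with FLAGS_tensorrt_engine_batch_size just below, and Vec2TRT_Dims (now defined in tensorrt_engine_op.h) maps the remaining dims to CHW. A simplified mirror of that mapping, covering only the rank-2 and rank-4 cases the enforce allows:

#include <cassert>
#include <cstdint>
#include <vector>

struct Chw {
  int c, h, w;
};

// Simplified mirror of Vec2TRT_Dims: rank-4 fluid shapes map to CHW, rank-2
// (batch x features) shapes map to CHW with H = W = 1; the batch dim is dropped.
Chw FluidShapeToChw(const std::vector<int64_t>& shape) {
  assert(shape.size() == 4 || shape.size() == 2);
  if (shape.size() == 4) {
    return {static_cast<int>(shape[1]), static_cast<int>(shape[2]),
            static_cast<int>(shape[3])};
  }
  return {static_cast<int>(shape[1]), 1, 1};  // e.g. {batch, 784} -> (784, 1, 1)
}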
+ if (shape[0] == -1) { + shape[0] = FLAGS_tensorrt_engine_batch_size; + } + engine->DeclareInput( + input, FluidDataType2TRT( + var->Proto()->type().lod_tensor().tensor().data_type()), + Vec2TRT_Dims(shape)); + } + + inference::Singleton::Global() + .ConvertBlock(block_desc, parameters, context.scope(), engine); + + // Add outputs + for (auto& output : output_maps) { + engine->DeclareOutput(output); + } + + engine->FreezeNetwork(); + } }; } // namespace operators diff --git a/paddle/fluid/operators/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt_engine_op_test.cc index 37657fa0b0498986fe67027415279af1775e58b9..27c1d29762b3de5e57f877b271aae52e71eb7cf9 100644 --- a/paddle/fluid/operators/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt_engine_op_test.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/tensorrt_engine_op.h" #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -23,20 +24,20 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" -USE_CPU_ONLY_OP(tensorrt_engine); +USE_CUDA_ONLY_OP(tensorrt_engine); namespace paddle { namespace operators { namespace { -void CreateCPUTensor(framework::Scope* scope, const std::string& name, - const std::vector& shape) { +void CreateCUDATensor(framework::Scope* scope, const std::string& name, + const std::vector& shape) { auto* var = scope->Var(name); auto* tensor = var->GetMutable(); auto dims = framework::make_ddim(shape); tensor->Resize(dims); - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); inference::tensorrt::RandomizeTensor(tensor, place, ctx); } @@ -57,6 +58,8 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block, using inference::analysis::SetAttr; TEST(TensorRTEngineOp, manual) { + FLAGS_tensorrt_engine_batch_size = 2; + FLAGS_tensorrt_max_batch_size = 2; framework::ProgramDesc program; auto* block_ = program.Proto()->add_blocks(); block_->set_idx(0); @@ -98,8 +101,6 @@ TEST(TensorRTEngineOp, manual) { engine_op_desc.SetOutput("Ys", std::vector({"z0"})); SetAttr(engine_op_desc.Proto(), "subgraph", block_->SerializeAsString()); - SetAttr(engine_op_desc.Proto(), "max_batch", 100); - SetAttr(engine_op_desc.Proto(), "max_workspace", 1 << 10); SetAttr(engine_op_desc.Proto(), "engine_uniq_key", "a_engine"); SetAttr>(engine_op_desc.Proto(), "parameters", std::vector({})); @@ -112,15 +113,15 @@ TEST(TensorRTEngineOp, manual) { LOG(INFO) << "engine_op " << engine_op.get(); framework::Scope scope; - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); // Prepare variables. 
- CreateCPUTensor(&scope, "x", std::vector({2, 4})); - CreateCPUTensor(&scope, "y", std::vector({4, 6})); - CreateCPUTensor(&scope, "z", std::vector({2, 6})); + CreateCUDATensor(&scope, "x", std::vector({2, 4})); + CreateCUDATensor(&scope, "y", std::vector({4, 6})); + CreateCUDATensor(&scope, "z", std::vector({2, 6})); - CreateCPUTensor(&scope, "y0", std::vector({6, 8})); - CreateCPUTensor(&scope, "z0", std::vector({2, 8})); + CreateCUDATensor(&scope, "y0", std::vector({6, 8})); + CreateCUDATensor(&scope, "z0", std::vector({2, 8})); // Execute them. LOG(INFO) << "engine_op run"; @@ -128,10 +129,12 @@ TEST(TensorRTEngineOp, manual) { } void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { + FLAGS_tensorrt_engine_batch_size = batch_size; + FLAGS_tensorrt_max_batch_size = batch_size; framework::ProgramDesc program; framework::Scope scope; - platform::CPUPlace place; - platform::CPUDeviceContext ctx(place); + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); auto* block_ = program.Proto()->add_blocks(); block_->set_idx(0); @@ -165,10 +168,10 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) { // Prepare variables. if (!x_created) { - CreateCPUTensor(&scope, x_name, std::vector(x_shape)); + CreateCUDATensor(&scope, x_name, std::vector(x_shape)); } - CreateCPUTensor(&scope, y_name, std::vector(y_shape)); - CreateCPUTensor(&scope, z_name, std::vector(z_shape)); + CreateCUDATensor(&scope, y_name, std::vector(y_shape)); + CreateCUDATensor(&scope, z_name, std::vector(z_shape)); // It is wired, need to copy manually. *block_->add_ops() = *fc->Proto();