diff --git a/Dockerfile b/Dockerfile
index b4f8c9dcebb7040f12c0122a38e59d519f443a04..42a103240e882b2732f14619308cc00f010d20af 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -137,8 +137,8 @@ RUN curl -s -q https://glide.sh/get | sh
 RUN wget -q https://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz --no-check-certificate && \
     tar -zxf TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz -C /usr/local && \
-    cp -rf /usr/local/TensorRT/include /usr && \
-    cp -rf /usr/local/TensorRT/lib /usr
+    cp -rf /usr/local/TensorRT/include/* /usr/include/ && \
+    cp -rf /usr/local/TensorRT/lib/* /usr/lib/
 
 # git credential to skip password typing
 RUN git config --global credential.helper store
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 1cf72f7001b2a56eb613340f8d3d71c1bbec03a6..7b8b8655e7775a49e7568260646d3f820cdc40da 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -59,6 +59,7 @@ struct Argument {
   using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
   using fusion_statis_t = std::unordered_map<std::string, int>;
+  using input_shape_t = std::map<std::string, std::vector<int>>;
 
   bool Has(const std::string& key) const { return valid_fields_.count(key); }
   // If we set the model using config.SetModelBuffer,
@@ -174,6 +175,12 @@ struct Argument {
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
   DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
+
+  // Usually used for TRT dynamic shape.
+  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
+  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
+  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
+
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
   DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
   DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index a4a2fdb2b687ff6c354a6fdae2221e5947766e6d..ad12736fac86dfcb238f0180d8a7f127f344e467 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -123,6 +123,13 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
       pass->Set("use_static_engine", new bool(use_static_engine));
       pass->Set("model_from_memory", new bool(argument->model_from_memory()));
+      pass->Set("max_input_shape", new std::map<std::string, std::vector<int>>(
+                                       argument->max_input_shape()));
+      pass->Set("min_input_shape", new std::map<std::string, std::vector<int>>(
+                                       argument->min_input_shape()));
+      pass->Set("optim_input_shape",
+                new std::map<std::string, std::vector<int>>(
+                    argument->optim_input_shape()));
     }
     if (pass_name == "ngraph_subgraph_pass") {
       pass->Set("program",
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 2b6418bbf8ab43e0c9b429e2a40bb8f53157e683..66231641a6cb452a85b90c8eeb302f1c3bd033de 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -166,6 +166,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   auto enable_int8 = Get<bool>("enable_int8");
   auto use_calib_mode = Get<bool>("use_calib_mode");
   auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
+  auto min_input_shape =
+      Get<std::map<std::string, std::vector<int>>>("min_input_shape");
+  auto max_input_shape =
+      Get<std::map<std::string, std::vector<int>>>("max_input_shape");
+  auto opt_input_shape =
+      Get<std::map<std::string, std::vector<int>>>("optim_input_shape");
 
   // The following procedure is used to rename all the intermediate
   // variables and the output variables of the subgraph.
@@ -263,11 +269,33 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   std::copy(params_not_shared.begin(), params_not_shared.end(),
             std::back_inserter(*repetitive_params));
 
+  // Check trt version for dynamic shape input.
+
+  if (min_input_shape.size() > 0 && TRT_VERSION < 6000) {
+    LOG_FIRST_N(WARNING, 1) << "You are using the dynamic size input mode of "
+                               "Paddle-TRT, but we found that the version of "
+                               "the TensorRT is less than 6.0, so we use the "
+                               "static shape mode instead.";
+    min_input_shape = {};
+    max_input_shape = {};
+    opt_input_shape = {};
+  }
+
+  if (min_input_shape.size() > 0 && TRT_VERSION > 6000) {
+    LOG_FIRST_N(WARNING, 1)
+        << "The Paddle lib links the " << TRT_VERSION / 1000.
+        << " version TensorRT, "
+        << "make sure the runtime TensorRT you are using is no less than this "
+           "version, otherwise, there might be Segfault!";
+  }
+
   tensorrt::TensorRTEngine *trt_engine =
       inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
           .Create(engine_key + std::to_string(predictor_id),
                   Get<int>("max_batch_size"), Get<int>("workspace_size"),
-                  precision_mode, calibrator.get(), Get<int>("gpu_device_id"));
+                  precision_mode, calibrator.get(), Get<int>("gpu_device_id"),
+                  min_input_shape, max_input_shape, opt_input_shape);
 
   bool need_serialize = (use_static_engine && !load_from_memory);
   if (need_serialize) {
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 75a05fa309fcb29cbf0d89294100366471a724c7..b6da6310b4c25ff3bc0c5551a5c5984c9ee92604 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -125,6 +125,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   // Quantization related.
CP_MEMBER(use_mkldnn_quantizer_); CP_MEMBER(mkldnn_quantizer_config_); + CP_MEMBER(min_input_shape_); + CP_MEMBER(max_input_shape_); + CP_MEMBER(optim_input_shape_); CP_MEMBER(use_lite_); CP_MEMBER(lite_precision_mode_); @@ -223,7 +226,10 @@ MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const { void AnalysisConfig::EnableTensorRtEngine( int workspace_size, int max_batch_size, int min_subgraph_size, AnalysisConfig::Precision precision_mode, bool use_static, - bool use_calib_mode) { + bool use_calib_mode, + std::map> min_input_shape, + std::map> max_input_shape, + std::map> optim_input_shape) { #ifdef PADDLE_WITH_CUDA if (!use_gpu()) { LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first"; @@ -237,6 +243,9 @@ void AnalysisConfig::EnableTensorRtEngine( tensorrt_precision_mode_ = precision_mode; trt_use_static_engine_ = use_static; trt_use_calib_mode_ = use_calib_mode; + min_input_shape_ = min_input_shape; + max_input_shape_ = max_input_shape; + optim_input_shape_ = optim_input_shape; Update(); #else diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 5aa3d7a0527bc2736c47abf9c2bd47d52b26ce9d..78b2467ff09b137528dfb8b18d7b2b8c6657ba87 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -425,6 +425,9 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_); argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_); + argument_.SetMinInputShape(config_.min_input_shape_); + argument_.SetMaxInputShape(config_.max_input_shape_); + argument_.SetOptimInputShape(config_.optim_input_shape_); } if (config_.lite_engine_enabled()) { diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 260ec6562aa3521a751782c065c6c15ada774fa7..7a5ff0318b6676c27fda78ea3c32d0baf9f17d3b 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -160,11 +160,13 @@ struct AnalysisConfig { * @param min_subgrpah_size the minimum TensorRT subgraph size needed, if a * subgraph is less than this, it will not transfer to TensorRT engine. */ - void EnableTensorRtEngine(int workspace_size = 1 << 20, - int max_batch_size = 1, int min_subgraph_size = 3, - Precision precision = Precision::kFloat32, - bool use_static = false, - bool use_calib_mode = true); + void EnableTensorRtEngine( + int workspace_size = 1 << 20, int max_batch_size = 1, + int min_subgraph_size = 3, Precision precision = Precision::kFloat32, + bool use_static = false, bool use_calib_mode = true, + std::map> min_input_shape = {}, + std::map> max_input_shape = {}, + std::map> optim_input_shape = {}); /** A boolean state telling whether the TensorRT engine is used. 
*/ bool tensorrt_engine_enabled() const { return use_tensorrt_; } @@ -348,6 +350,9 @@ struct AnalysisConfig { std::string serialized_info_cache_; mutable std::unique_ptr pass_builder_; + std::map> min_input_shape_; + std::map> max_input_shape_; + std::map> optim_input_shape_; bool use_lite_{false}; std::vector lite_passes_filter_; diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 840369976d7b731824fc390abdad915572022a51..4ae2f91d1a673d41b256113475982c1060518ad8 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -49,8 +49,12 @@ class ElementwiseWeightOpConverter : public OpConverter { auto* X = engine_->GetITensor(op_desc.Input("X").front()); nvinfer1::Dims dims_x = X->getDimensions(); - PADDLE_ENFORCE(dims_x.nbDims >= 3, "x dims experts 3, but %d is given.", - dims_x.nbDims); + std::vector no_batch_dims; + int start_index = 0; + + if (engine_->with_dynamic_shape()) start_index = 1; + for (; start_index < dims_x.nbDims; start_index++) + no_batch_dims.push_back(dims_x.d[start_index]); auto* Y_v = scope.FindVar(op_desc.Input("Y").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); @@ -62,23 +66,23 @@ class ElementwiseWeightOpConverter : public OpConverter { auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; std::vector dims_y = framework::vectorize(Y_t->dims()); - if (static_cast(dims_y.size()) == dims_x.nbDims + 1) { + if (dims_y.size() == no_batch_dims.size() + 1) { if (dims_y[0] == 1) dims_y.erase(dims_y.begin()); } - if (static_cast(dims_y.size()) == 1 && dims_y[0] == dims_x.d[0]) { + if (dims_y.size() == 1 && dims_y[0] == no_batch_dims[0]) { scale_mode = nvinfer1::ScaleMode::kCHANNEL; - } else if (static_cast(dims_y.size()) == dims_x.nbDims && - dims_y[0] == dims_x.d[0]) { + } else if (dims_y.size() == no_batch_dims.size() && + dims_y[0] == no_batch_dims[0]) { scale_mode = nvinfer1::ScaleMode::kELEMENTWISE; - for (int i = 1; i < dims_x.nbDims; i++) { - if (dims_y[i] != dims_x.d[i]) { + for (size_t i = 1; i < no_batch_dims.size(); i++) { + if (dims_y[i] != no_batch_dims[i]) { scale_mode = nvinfer1::ScaleMode::kCHANNEL; break; } } if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) { - for (int i = 1; i < dims_x.nbDims; i++) { + for (size_t i = 1; i < no_batch_dims.size(); i++) { if (dims_y[i] != 1) PADDLE_THROW( "TensorRT unsupported weight shape for Elementwise op!"); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index ca5e1b8a7405769300c31d0f048b27758d474cf7..a299d845662c10a4ee29f119b806b367f4a6cd83 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -23,52 +23,13 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/utils/singleton.h" namespace paddle { namespace inference { namespace tensorrt { -using FluidDT = framework::proto::VarType_Type; -using TRT_DT = nvinfer1::DataType; - -namespace { // NOLINT - -TRT_DT FluidDataType2TRT(FluidDT type) { - switch (type) { - case FluidDT::VarType_Type_FP32: - return TRT_DT::kFLOAT; - case FluidDT::VarType_Type_INT32: - return TRT_DT::kINT32; - default: - return TRT_DT::kINT32; - } - PADDLE_THROW(platform::errors::InvalidArgument( - "unknown fluid datatype in TRT op converter")); - return TRT_DT::kINT32; -} - -nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, - std::string input) { - PADDLE_ENFORCE_GT(shape.size(), 1UL, - platform::errors::InvalidArgument( - "TensorRT's tensor input requires at least 2 " - "dimensions, but input %s has %d dims.", - input, shape.size())); - PADDLE_ENFORCE_LE(shape.size(), 4UL, - platform::errors::InvalidArgument( - "TensorRT's tensor input requires at most 4 " - "dimensions, but input %s has %d dims.", - input, shape.size())); - if (shape.size() == 4UL) - return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); - else if (shape.size() == 3UL) - return nvinfer1::Dims2(shape[1], shape[2]); - return nvinfer1::DimsCHW(shape[1], 1, 1); -} - -} // namespace // NOLINT - /* * Convert Op from Fluid to TensorRT Engine. */ @@ -167,11 +128,37 @@ class OpConverter { PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, "TensorRT engine only takes LoDTensor as input"); auto var_shape = var->GetShape(); - - engine->DeclareInput( - input, FluidDataType2TRT( - var->Proto()->type().lod_tensor().tensor().data_type()), - Vec2TRT_Dims(var_shape, input)); + if (engine->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + auto min_input_shape = engine->min_input_shape()[input]; + auto max_input_shape = engine->max_input_shape()[input]; + auto optim_input_shape = engine->optim_input_shape()[input]; + size_t ranks = min_input_shape.size(); + std::vector input_shape; + input_shape.push_back(-1); + for (size_t i = 1; i < ranks; i++) { + if (min_input_shape[i] != max_input_shape[i]) { + input_shape.push_back(-1); + } else { + input_shape.push_back(min_input_shape[i]); + // the i dimension should be same. 
+          PADDLE_ENFORCE_EQ(min_input_shape[i], optim_input_shape[i],
+                            platform::errors::InvalidArgument(
+                                "The dim (%d) of the min_input_shape and "
+                                "optim_input_shape should be same."));
+        }
+      }
+      engine->DeclareInput(
+          input, FluidDataType2TRT(
+                     var->Proto()->type().lod_tensor().tensor().data_type()),
+          Vec2TRT_Dims(input_shape, input, true));
+#endif
+    } else {
+      engine->DeclareInput(
+          input, FluidDataType2TRT(
+                     var->Proto()->type().lod_tensor().tensor().data_type()),
+          Vec2TRT_Dims(var_shape, input));
+    }
     }
     framework::proto::BlockDesc* block_proto = block_desc->Proto();
     ConvertBlock(*block_proto, parameters, scope, engine);
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index e7f7a842cf5725d4c83f1c4b8205ba32515a79fd..af9a580c75787e59e13bab3eb2af63bd99f3a339 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -28,23 +28,35 @@ namespace tensorrt {
 
 int TensorRTEngine::runtime_batch_ = 1;
 
-void TensorRTEngine::Build(const DescType &paddle_model) {
-  PADDLE_ENFORCE(false, "not implemented");
+void TensorRTEngine::InitNetwork() {
+  freshDeviceId();
+  infer_builder_.reset(createInferBuilder(&logger_));
+
+  if (with_dynamic_shape_) {
+#if IS_TRT_VERSION_GE(6000)
+    infer_networkv2_.reset(infer_builder_->createNetworkV2(
+        1U << static_cast<int>(
+            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
+    infer_builder_config_.reset(infer_builder_->createBuilderConfig());
+    optim_profile_.reset(infer_builder_->createOptimizationProfile());
+#endif
+  } else {
+    infer_network_.reset(infer_builder_->createNetwork());
+  }
 }
 
 void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
                              cudaStream_t stream) {
   freshDeviceId();
-  const std::thread::id tid = std::this_thread::get_id();
-  batch_size_ = batch_size;
-  if (infer_context_.find(tid) == infer_context_.end()) {
-    std::unique_lock<std::mutex> lock(mutex_);
-    PADDLE_ENFORCE_NOT_NULL(
-        infer_engine_,
-        "You should build engine first and then set the context.");
-    infer_context_[tid].reset(infer_engine_->createExecutionContext());
+  auto infer_context = context();
+  if (!with_dynamic_shape()) {
+    infer_context->enqueue(batch_size, buffers->data(), stream, nullptr);
+  } else {
+#if IS_TRT_VERSION_GE(6000)
+    infer_context->enqueueV2(buffers->data(), stream, nullptr);
+#endif
   }
-  infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
   SetRuntimeBatch(batch_size);
 }
 
@@ -53,8 +65,9 @@ void TensorRTEngine::FreezeNetwork() {
   VLOG(3) << "TRT to freeze network";
   PADDLE_ENFORCE(infer_builder_ != nullptr,
                  "Call InitNetwork first to initialize network.");
-  PADDLE_ENFORCE(infer_network_ != nullptr,
-                 "Call InitNetwork first to initialize network.");
+  PADDLE_ENFORCE_EQ(network() != nullptr, true,
+                    platform::errors::InvalidArgument(
+                        "Call InitNetwork first to initialize network."));
   // build engine.
infer_builder_->setMaxBatchSize(max_batch_); infer_builder_->setMaxWorkspaceSize(max_workspace_); @@ -66,6 +79,8 @@ void TensorRTEngine::FreezeNetwork() { if (!support_fp16) { LOG(INFO) << "You specify FP16 mode, but the hardware do not support " "FP16 speed up, use FP32 instead."; + } else { + LOG(INFO) << "Run Paddle-TRT FP16 mode"; } } #else @@ -92,14 +107,14 @@ void TensorRTEngine::FreezeNetwork() { } std::unordered_set all_t; - for (int i = 0; i < infer_network_->getNbLayers(); i++) { - auto layer = infer_network_->getLayer(i); + for (int i = 0; i < network()->getNbLayers(); i++) { + auto layer = network()->getLayer(i); for (int j = 0; j < layer->getNbOutputs(); j++) { all_t.insert(layer->getOutput(j)); } } - for (int i = 0; i < infer_network_->getNbInputs(); i++) { - all_t.insert(infer_network_->getInput(i)); + for (int i = 0; i < network()->getNbInputs(); i++) { + all_t.insert(network()->getInput(i)); } for (auto &t : all_t) { @@ -110,14 +125,14 @@ void TensorRTEngine::FreezeNetwork() { } } std::unordered_set all_out_t_name; - for (int i = 0; i < infer_network_->getNbOutputs(); i++) { - auto *temp = infer_network_->getOutput(i); + for (int i = 0; i < network()->getNbOutputs(); i++) { + auto *temp = network()->getOutput(i); temp->setDynamicRange(-1, 1); all_out_t_name.insert(temp->getName()); } - for (int i = 0; i < infer_network_->getNbLayers(); i++) { - auto layer = infer_network_->getLayer(i); + for (int i = 0; i < network()->getNbLayers(); i++) { + auto layer = network()->getLayer(i); for (int j = 0; j < layer->getNbOutputs(); j++) { auto *temp_out = layer->getOutput(j); if (std::find(all_out_t_name.begin(), all_out_t_name.end(), @@ -127,26 +142,41 @@ void TensorRTEngine::FreezeNetwork() { } } } - #endif } } - infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_)); + if (with_dynamic_shape_) { +#if IS_TRT_VERSION_GE(6000) + for (auto &input : min_input_shape_) { + optim_profile_->setDimensions( + input.first.c_str(), nvinfer1::OptProfileSelector::kMIN, + Vec2TRT_Dims(input.second, input.first, true)); + optim_profile_->setDimensions( + input.first.c_str(), nvinfer1::OptProfileSelector::kMAX, + Vec2TRT_Dims(max_input_shape_[input.first], input.first, true)); + optim_profile_->setDimensions( + input.first.c_str(), nvinfer1::OptProfileSelector::kOPT, + Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true)); + } + infer_builder_config_->addOptimizationProfile(optim_profile_.get()); + infer_engine_.reset(infer_builder_->buildEngineWithConfig( + *network(), *infer_builder_config_)); +#endif + } else { + infer_engine_.reset(infer_builder_->buildCudaEngine(*network())); + } PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); } nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, nvinfer1::DataType dtype, const nvinfer1::Dims &dims) { - PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s", - name); - - PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first"); - auto *input = infer_network_->addInput(name.c_str(), dtype, dims); + PADDLE_ENFORCE_EQ(network() != nullptr, true, + platform::errors::InvalidArgument( + "The TRT network should be initialized first.")); + auto *input = network()->addInput(name.c_str(), dtype, dims); PADDLE_ENFORCE(input, "infer network add input %s failed", name); - buffer_sizes_[name] = kDataTypeSize[static_cast(dtype)] * - analysis::AccuDims(dims.d, dims.nbDims) * max_batch_; PADDLE_ENFORCE(input->isNetworkInput()); TensorRTEngine::SetITensor(name, input); return 
input; @@ -154,37 +184,21 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset, const std::string &name) { - PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", - name); - auto *output = layer->getOutput(offset); SetITensor(name, output); PADDLE_ENFORCE(output != nullptr); output->setName(name.c_str()); PADDLE_ENFORCE(!output->isNetworkInput()); - infer_network_->markOutput(*output); + network()->markOutput(*output); PADDLE_ENFORCE(output->isNetworkOutput()); - // output buffers' size can only be decided later, set zero here to mark this - // and will reset later. - buffer_sizes_[name] = 0; -} - -bool TensorRTEngine::HasDeclared(const std::string &name) { - return buffer_sizes_.count(name) > 0; } void TensorRTEngine::DeclareOutput(const std::string &name) { - PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", - name); - auto *output = TensorRTEngine::GetITensor(name); PADDLE_ENFORCE(output != nullptr); output->setName(name.c_str()); PADDLE_ENFORCE(!output->isNetworkInput()); - infer_network_->markOutput(*output); - // output buffers' size can only be decided later, set zero here to mark this - // and will reset later. - buffer_sizes_[name] = 0; + network()->markOutput(*output); } void TensorRTEngine::SetITensor(const std::string &name, @@ -253,7 +267,7 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin( nvinfer1::ITensor *const *inputs, int num_inputs, plugin::PluginTensorRT *plugin) { owned_plugin_.emplace_back(plugin); - return infer_network_.get()->addPluginExt(inputs, num_inputs, *plugin); + return network()->addPluginExt(inputs, num_inputs, *plugin); } void TensorRTEngine::freshDeviceId() { diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index d847ce4b5df3a047fb9366d5232aefbf7814f2fd..c209bbd04c9c4ec4f96fbf9d919f17f4d73567e8 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include // NOLINT #include #include #include @@ -36,6 +37,57 @@ namespace paddle { namespace inference { namespace tensorrt { +using FluidDT = framework::proto::VarType_Type; +using TRT_DT = nvinfer1::DataType; + +namespace { // NOLINT + +TRT_DT FluidDataType2TRT(FluidDT type) { + switch (type) { + case FluidDT::VarType_Type_FP32: + return TRT_DT::kFLOAT; + case FluidDT::VarType_Type_INT32: + return TRT_DT::kINT32; + default: + return TRT_DT::kINT32; + } + PADDLE_THROW(platform::errors::InvalidArgument( + "unknown fluid datatype in TRT op converter")); + return TRT_DT::kINT32; +} + +// The T can be int32 or int64 type. 
+template +nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, + bool with_dynamic_shape = false) { + PADDLE_ENFORCE_GT(shape.size(), 1UL, + platform::errors::InvalidArgument( + "TensorRT's tensor input requires at least 2 " + "dimensions, but input %s has %d dims.", + input, shape.size())); + PADDLE_ENFORCE_LE(shape.size(), 4UL, + platform::errors::InvalidArgument( + "TensorRT's tensor input requires at most 4 " + "dimensions, but input %s has %d dims.", + input, shape.size())); + if (!with_dynamic_shape) { + if (shape.size() == 4UL) { + return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); + } else if (shape.size() == 3UL) { + return nvinfer1::Dims2(shape[1], shape[2]); + } + return nvinfer1::DimsCHW(shape[1], 1, 1); + } else { + if (shape.size() == 4UL) { + return nvinfer1::DimsNCHW(shape[0], shape[1], shape[2], shape[3]); + } else if (shape.size() == 3UL) { + return nvinfer1::Dims3(shape[0], shape[1], shape[2]); + } + return nvinfer1::Dims4(shape[0], shape[1], 1, 1); + } +} +} // NOLINT + class TRTInt8Calibrator; /* * TensorRT Engine. @@ -45,6 +97,7 @@ class TRTInt8Calibrator; */ class TensorRTEngine { using DescType = ::paddle::framework::proto::BlockDesc; + using ShapeMapType = std::map>; public: // Weight is model parameter. @@ -68,33 +121,44 @@ class TensorRTEngine { int max_batch, int max_workspace, AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32, TRTInt8Calibrator* calibrator = nullptr, int device_id = 0, + const ShapeMapType min_input_shape = {}, + const ShapeMapType max_input_shape = {}, + const ShapeMapType optim_input_shape = {}, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), precision_(precision), calibrator_(calibrator), device_id_(device_id), - logger_(logger) {} + min_input_shape_(min_input_shape), + max_input_shape_(max_input_shape), + optim_input_shape_(optim_input_shape), + logger_(logger) { + if (min_input_shape_.size() != 0 && max_input_shape_.size() != 0 && + optim_input_shape_.size() != 0) { + PADDLE_ENFORCE_EQ( + min_input_shape_.size(), max_input_shape_.size(), + platform::errors::InvalidArgument( + "The min_input_shape_'s size(%d) should be equal to the " + "size(%d) of max_input_shape_", + min_input_shape_.size(), max_input_shape_.size())); + PADDLE_ENFORCE_EQ( + min_input_shape_.size(), optim_input_shape_.size(), + platform::errors::InvalidArgument( + "The min_input_shape_'s size(%d) should be equal to the " + "size(%d) of optim_input_shape_", + min_input_shape_.size(), optim_input_shape_.size())); +#if IS_TRT_VERSION_GE(6000) + with_dynamic_shape_ = true; +#else + LOG(WARNING) << "Using dynamic shape of TRT need ensure that the TRT " + "version should be at least 6."; +#endif + } + } ~TensorRTEngine() {} - // TODO(Superjomn) implement it later when graph segmentation is supported. - void Build(const DescType& paddle_model); - - void Execute(int batch_size, std::vector* buffers, - cudaStream_t stream = nullptr); - - // Initialize the inference network, so that TensorRT layers can add to this - // network. - void InitNetwork() { - freshDeviceId(); - infer_builder_.reset(createInferBuilder(&logger_)); - infer_network_.reset(infer_builder_->createNetwork()); - } - // After finishing adding ops, freeze this network and creates the execution - // environment. - void FreezeNetwork(); - // Add an input and set its name, data type and dimension. 
nvinfer1::ITensor* DeclareInput(const std::string& name, nvinfer1::DataType dtype, @@ -105,15 +169,24 @@ class TensorRTEngine { const std::string& name); // Set the itensor_map_[name] as the network's output, and set its name. void DeclareOutput(const std::string& name); - // Check if the ITensor has been declared - bool HasDeclared(const std::string& name); void SetITensor(const std::string& name, nvinfer1::ITensor* tensor); // Get an ITensor called name. nvinfer1::ITensor* GetITensor(const std::string& name); nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } - nvinfer1::INetworkDefinition* network() { return infer_network_.get(); } + nvinfer1::IExecutionContext* context() { + std::unique_lock lock(mutex_); + const std::thread::id tid = std::this_thread::get_id(); + if (infer_context_.find(tid) == infer_context_.end()) { + PADDLE_ENFORCE_NOT_NULL( + infer_engine_, + platform::errors::InvalidArgument( + "You should build engine first and then set the context.")); + infer_context_[tid].reset(infer_engine_->createExecutionContext()); + } + return infer_context_[tid].get(); + } nvinfer1::IHostMemory* Serialize() { PADDLE_ENFORCE(infer_engine_ != nullptr, @@ -170,6 +243,30 @@ class TensorRTEngine { } } + // NOTE: The func bellow was modified to adapt the dynamic shape. + // Initialize the inference network, so that TensorRT layers can add to this + // network. + void InitNetwork(); + // After finishing adding ops, freeze this network and creates the execution + // environment. + void FreezeNetwork(); + void Execute(int batch_size, std::vector* buffers, + cudaStream_t stream = nullptr); + + nvinfer1::INetworkDefinition* network() { + if (with_dynamic_shape_) { + return infer_networkv2_.get(); + } else { + return infer_network_.get(); + } + } + + ShapeMapType min_input_shape() { return min_input_shape_; } + ShapeMapType max_input_shape() { return max_input_shape_; } + ShapeMapType optim_input_shape() { return optim_input_shape_; } + + bool with_dynamic_shape() { return with_dynamic_shape_; } + private: // Each ICudaEngine object is bound to a specific GPU when it is instantiated, // ensure that the thread is associated with the correct device by calling @@ -189,10 +286,12 @@ class TensorRTEngine { int batch_size_{-1}; int device_id_; + ShapeMapType min_input_shape_; + ShapeMapType max_input_shape_; + ShapeMapType optim_input_shape_; nvinfer1::ILogger& logger_; // max data size for the buffers. - std::unordered_map buffer_sizes_; std::unordered_map itensor_map_; @@ -216,13 +315,17 @@ class TensorRTEngine { infer_context_; infer_ptr ihost_memory_; std::unordered_map quant_dynamic_range_; + + // For dynamic shape + bool with_dynamic_shape_{false}; + infer_ptr infer_networkv2_; +#if IS_TRT_VERSION_GE(6000) + infer_ptr infer_builder_config_; + std::unique_ptr optim_profile_; +#endif std::mutex mutex_; }; // class TensorRTEngine -#define IS_TRT_VERSION_GE(version) \ - ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \ - NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version) - // Add a layer__ into engine__ with args ARGS. 
// For example: // @@ -252,9 +355,13 @@ class TRTEngineManager { std::string name, int max_batch, int max_workspace, AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32, TRTInt8Calibrator* calibrator = nullptr, int device_id = 0, + const std::map> min_input_shape = {}, + const std::map> max_input_shape = {}, + const std::map> optim_input_shape = {}, nvinfer1::ILogger& logger = NaiveLogger::Global()) { auto* p = new TensorRTEngine(max_batch, max_workspace, precision, - calibrator, device_id, logger); + calibrator, device_id, min_input_shape, + max_input_shape, optim_input_shape, logger); engines_[name].reset(p); return p; } diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h index 010942a0678fe9a592d1a95ba9cdc6adc42cc2ec..037dabf5d5888aecdaf781de474c35098be144c1 100644 --- a/paddle/fluid/inference/tensorrt/helper.h +++ b/paddle/fluid/inference/tensorrt/helper.h @@ -27,6 +27,14 @@ namespace paddle { namespace inference { namespace tensorrt { +#define IS_TRT_VERSION_GE(version) \ + ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \ + NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version) + +#define TRT_VERSION \ + NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \ + NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD + namespace dy = paddle::platform::dynload; // TensorRT data type to size @@ -103,6 +111,14 @@ class NaiveProfiler : public nvinfer1::IProfiler { } }; +inline size_t ProductDim(const nvinfer1::Dims& dims) { + size_t v = 1; + for (int i = 0; i < dims.nbDims; i++) { + v *= dims.d[i]; + } + return v; +} + } // namespace tensorrt } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 6e2b8a50bf19c6573762610ad870f02e66d07c00..c281cd8bccd86a39993e04cde49fe8f453c8c636 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -374,6 +374,14 @@ if(WITH_GPU AND TENSORRT_FOUND) inference_analysis_test(trt_quant_int8_test SRCS trt_quant_int8_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR}) + + set(TEST_TRT_DYNAMIC_MODEL "${TRT_MODEL_INSTALL_DIR}/test_trt_dy_conv") + if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL}) + inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL} ${INFERENCE_URL}/tensorrt_test "test_trt_dy_conv.tar.gz") + endif() + inference_analysis_test(trt_dynamic_shape_test SRCS trt_dynamic_shape_test.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${TEST_TRT_DYNAMIC_MODEL}) endif() set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite") diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c186c5973520f90c9dab9d9dea97901fee151fa --- /dev/null +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_test.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/inference/tests/api/trt_test_helper.h" + +namespace paddle { +namespace inference { + +TEST(AnalysisPredictor, use_gpu) { + std::string model_dir = FLAGS_infer_model + "/test_trt_dy_conv"; + AnalysisConfig config; + config.EnableUseGpu(100, 0); + config.SetModel(model_dir); + config.SwitchUseFeedFetchOps(false); + // Set the input's min, max, opt shape + std::map> min_input_shape = { + {"image", {1, 1, 3, 3}}}; + std::map> max_input_shape = { + {"image", {1, 1, 10, 10}}}; + std::map> opt_input_shape = { + {"image", {1, 1, 3, 3}}}; + config.EnableTensorRtEngine( + 1 << 30, 1, 1, AnalysisConfig::Precision::kFloat32, false, true, + min_input_shape, max_input_shape, opt_input_shape); + auto predictor = CreatePaddlePredictor(config); + auto input_names = predictor->GetInputNames(); + int channels = 1; + int height = 3; + int width = 3; + int input_num = channels * height * width * 1; + + float *input = new float[input_num]; + memset(input, 0, input_num * sizeof(float)); + auto input_t = predictor->GetInputTensor(input_names[0]); + input_t->Reshape({1, channels, height, width}); + input_t->copy_from_cpu(input); + + ASSERT_TRUE(predictor->ZeroCopyRun()); + + std::vector out_data; + auto output_names = predictor->GetOutputNames(); + auto output_t = predictor->GetOutputTensor(output_names[0]); + std::vector output_shape = output_t->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + out_data.resize(out_num); + output_t->copy_to_cpu(out_data.data()); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 3a33c8be101ecc4323e34d7aa2ed514beac90e92..dc8aa57e6d7d148e70b279e3c92c3c188316ecf5 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -29,6 +29,7 @@ #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/helper.h" namespace paddle { @@ -40,6 +41,29 @@ using inference::tensorrt::TRTInt8Calibrator; using inference::tensorrt::TRTCalibratorEngine; using inference::tensorrt::TRTCalibratorEngineManager; +static void RuntimeStaticShapeCheck(std::vector runtime_input_shape, + std::vector model_input_shape) { + auto comma_fold = [](std::string a, int b) { + return std::move(a) + ", " + std::to_string(b); + }; + std::string model_input_shape_str = std::accumulate( + std::next(model_input_shape.begin()), model_input_shape.end(), + std::to_string(model_input_shape[0]), comma_fold); + std::string runtime_input_shape_str = std::accumulate( + std::next(runtime_input_shape.begin()), runtime_input_shape.end(), + std::to_string(runtime_input_shape[0]), comma_fold); + PADDLE_ENFORCE_EQ( + model_input_shape == runtime_input_shape, true, + platform::errors::InvalidArgument( + "Input shapes are inconsistent with the model. Expect [%s] in " + "model description, but got [%s] in runtime. TRT 5 " + "or lower version " + "does not support dynamic input shapes. 
Please check and " + "modify " + "your input shapes.", + model_input_shape_str, runtime_input_shape_str)); +} + class TensorRTEngineOp : public framework::OperatorBase { private: std::vector input_names_; @@ -206,39 +230,28 @@ class TensorRTEngineOp : public framework::OperatorBase { auto &t = inference::analysis::GetFromScope(scope, x); auto t_shape = framework::vectorize(t.dims()); - // check if the input shapes are consistent with model. - if (HasAttr(x + "_shape")) { - std::vector i_shape = Attr>(x + "_shape"); - std::vector model_input_shape(i_shape.begin() + 1, - i_shape.end()); - std::vector runtime_input_shape(t_shape.begin() + 1, - t_shape.end()); - auto comma_fold = [](std::string a, int b) { - return std::move(a) + ", " + std::to_string(b); - }; - std::string model_input_shape_str = std::accumulate( - std::next(model_input_shape.begin()), model_input_shape.end(), - std::to_string(model_input_shape[0]), comma_fold); - std::string runtime_input_shape_str = std::accumulate( - std::next(runtime_input_shape.begin()), runtime_input_shape.end(), - std::to_string(runtime_input_shape[0]), comma_fold); - PADDLE_ENFORCE_EQ( - model_input_shape == runtime_input_shape, true, - platform::errors::InvalidArgument( - "Input shapes are inconsistent with the model. Expect [%s] in " - "model description, but got [%s] in runtime. TRT 5 " - "or lower version " - "does not support dynamic input shapes. Please check and " - "modify " - "your input shapes.", - model_input_shape_str, runtime_input_shape_str)); - } - runtime_batch = t_shape[0]; - const int bind_index = engine->engine()->getBindingIndex(x.c_str()); PADDLE_ENFORCE(bind_index < num_bindings, "The bind index should be less than num_bindings"); + if (!engine->with_dynamic_shape()) { + // check if the input shapes are consistent with model. + if (HasAttr(x + "_shape")) { + std::vector i_shape = + Attr>(x + "_shape"); + std::vector model_input_shape(i_shape.begin() + 1, + i_shape.end()); + std::vector runtime_input_shape(t_shape.begin() + 1, + t_shape.end()); + RuntimeStaticShapeCheck(runtime_input_shape, model_input_shape); + } + } else { +#if IS_TRT_VERSION_GE(6000) + auto *trt_context = engine->context(); + trt_context->setBindingDimensions( + bind_index, inference::tensorrt::Vec2TRT_Dims(t_shape, x, true)); +#endif + } buffers[bind_index] = static_cast(t.data()); } @@ -248,13 +261,20 @@ class TensorRTEngineOp : public framework::OperatorBase { for (const auto &y : Outputs("Ys")) { const int bind_index = engine->engine()->getBindingIndex(output_maps[output_index].c_str()); - auto dims = engine->engine()->getBindingDimensions(bind_index); - // Use the output ITensor's dims to reshape the Fluid Tensor. - // The ITensor doesn't contain the batch size dim. std::vector ddim; - ddim.push_back(runtime_batch); - for (int i = 0; i < dims.nbDims; i++) { - ddim.push_back(dims.d[i]); + + if (!engine->with_dynamic_shape()) { + auto dims = engine->engine()->getBindingDimensions(bind_index); + ddim.push_back(runtime_batch); + for (int i = 0; i < dims.nbDims; i++) { + ddim.push_back(dims.d[i]); + } + } else { +#if IS_TRT_VERSION_GE(6000) + auto *trt_context = engine->context(); + auto dims = trt_context->getBindingDimensions(bind_index); + for (int i = 0; i < dims.nbDims; i++) ddim.push_back(dims.d[i]); +#endif } auto *fluid_v = scope.FindVar(y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); @@ -289,7 +309,6 @@ class TensorRTEngineOp : public framework::OperatorBase { runtime_batch, max_batch_size_)); // Execute the engine. 
     engine->Execute(runtime_batch, &buffers, stream);
-    cudaStreamSynchronize(stream);
   }
 
   TensorRTEngine *GetEngine(const framework::Scope &scope,
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 2d5aae960acd10005031fb10ede49927e1f2268a..3d4a8963a5af1fa910e70443945731bb56addc1d 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -412,7 +412,13 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
            py::arg("min_subgraph_size") = 3,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
-           py::arg("use_static") = false, py::arg("use_calib_mode") = true)
+           py::arg("use_static") = false, py::arg("use_calib_mode") = true,
+           py::arg("min_input_shape") =
+               std::map<std::string, std::vector<int>>({}),
+           py::arg("max_input_shape") =
+               std::map<std::string, std::vector<int>>({}),
+           py::arg("optim_input_shape") =
+               std::map<std::string, std::vector<int>>({}))
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
       .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
            py::arg("x") = true)
diff --git a/paddle/scripts/build_docker_images.sh b/paddle/scripts/build_docker_images.sh
index c60f42da7aa6351985acaede756d139329ef520c..a90f0885294a9cfb9f65c3cc993cd77025a9dc4a 100644
--- a/paddle/scripts/build_docker_images.sh
+++ b/paddle/scripts/build_docker_images.sh
@@ -6,8 +6,8 @@ REPO="${REPO:-paddlepaddle}"
 cp -f ../../python/requirements.txt .
 
 sed 's#FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04#FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04#g' ../../Dockerfile |
-sed 's#TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz#TensorRT_5.1_ga_cuda9_cudnnv7.5.tar.gz#g' |
-sed 's#/usr/local/TensorRT#/usr/local/TensorRT_5.1_ga_cuda9_cudnnv7.5#g' |
+sed 's#TensorRT-4.0.1.6-ubuntu14.04.x86_64-gnu.cuda.8.0.cudnn7.0.tar.gz#TensorRT-6.0.1.5.Ubuntu-16.04.x86_64-gnu.cuda-9.0.cudnn7.6.tar.gz#g' |
+sed 's#/usr/local/TensorRT#/usr/local/TensorRT-6.0.1.5#g' |
 sed 's#libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0#libnccl2=2.4.7-1+cuda9.0 libnccl-dev=2.4.7-1+cuda9.0#g' |
 sed 's#COPY ./paddle/scripts/docker/root/#COPY ./docker/root/#g' |
 sed 's#COPY ./python/requirements.txt#COPY ./requirements.txt#' > Dockerfile.cuda9.0-cudnn7
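For reference, a minimal Python-side sketch of how the dynamic-shape options exposed by the pybind changes above could be exercised. This is not part of the patch; it assumes the usual snake_case bindings (`enable_tensorrt_engine`, `create_paddle_predictor`) exposed by inference_api.cc, and the model directory and shape values are illustrative, mirroring the C++ test trt_dynamic_shape_test.cc added above.

# Hypothetical usage sketch of the new dynamic-shape arguments.
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

config = AnalysisConfig("./test_trt_dy_conv")  # illustrative model directory
config.enable_use_gpu(100, 0)                  # 100 MB initial GPU memory pool, device 0
config.switch_use_feed_fetch_ops(False)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=1,
    # One entry per dynamic input; each maps a tensor name to a full shape.
    min_input_shape={"image": [1, 1, 3, 3]},
    max_input_shape={"image": [1, 1, 10, 10]},
    optim_input_shape={"image": [1, 1, 3, 3]})
predictor = create_paddle_predictor(config)

The three shape maps correspond directly to the kMIN, kMAX and kOPT selectors of the TensorRT optimization profile that FreezeNetwork() now configures when the engine is built with an explicit-batch network.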