diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index 0de93624f175815500997ea019343a69da41115f..5e881ecbbc4e2cc8e81b9334dc827513bfad02eb 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -72,9 +72,34 @@ class PReluOpConverter : public OpConverter { "your TRT version is no less than 6.0")); #endif } else { +#if IS_TRT_VERSION_GE(7000) + float* alpha_weight_data = engine_->GetWeightCPUData( + op_desc.Input("Alpha")[0], alpha_tensor, false); + TensorRTEngine::Weight alpha_weight{ + nvinfer1::DataType::kFLOAT, static_cast(alpha_weight_data), + static_cast(alpha_tensor->numel())}; + + nvinfer1::Dims dims; + dims.nbDims = 0; + // jump batch dim + for (int i = 1; i < alpha_tensor->dims().size(); i++) { + dims.d[dims.nbDims++] = alpha_tensor->dims()[i]; + } + for (; dims.nbDims < input->getDimensions().nbDims; dims.nbDims++) { + dims.d[dims.nbDims] = 1; + } + + auto alpha_layer = + TRT_ENGINE_ADD_LAYER(engine_, Constant, dims, alpha_weight.get()); + auto alpha_layer_output = alpha_layer->getOutput(0); + + layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input, + *alpha_layer_output); +#else plugin::PReluPlugin* plugin = new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode); layer = engine_->AddPlugin(&input, input_num, plugin); +#endif } // keep alpha tensor to avoid release it's memory engine_->SetWeights(op_desc.Input("Alpha")[0],