Unverified commit 77c44e2f, authored by Shang Zhizhou, committed by GitHub

change prelu plugin to tensorRT layer (#30210)

Parent 353dd0cd
@@ -72,9 +72,34 @@ class PReluOpConverter : public OpConverter {
"your TRT version is no less than 6.0"));
#endif
} else {
#if IS_TRT_VERSION_GE(7000)
float* alpha_weight_data = engine_->GetWeightCPUData(
op_desc.Input("Alpha")[0], alpha_tensor, false);
TensorRTEngine::Weight alpha_weight{
nvinfer1::DataType::kFLOAT, static_cast<void*>(alpha_weight_data),
static_cast<size_t>(alpha_tensor->numel())};
nvinfer1::Dims dims;
dims.nbDims = 0;
// jump batch dim
for (int i = 1; i < alpha_tensor->dims().size(); i++) {
dims.d[dims.nbDims++] = alpha_tensor->dims()[i];
}
for (; dims.nbDims < input->getDimensions().nbDims; dims.nbDims++) {
dims.d[dims.nbDims] = 1;
}
auto alpha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Constant, dims, alpha_weight.get());
auto alpha_layer_output = alpha_layer->getOutput(0);
layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input,
*alpha_layer_output);
#else
plugin::PReluPlugin* plugin =
new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode);
layer = engine_->AddPlugin(&input, input_num, plugin);
#endif
}
// keep alpha tensor to avoid release it's memory
engine_->SetWeights(op_desc.Input("Alpha")[0],
......
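For TensorRT 7.0 and newer, this change maps PReLU onto native layers instead of the custom plugin: the alpha weights become a Constant layer whose shape is padded with trailing 1s to the input rank, and its output feeds a ParametricReLU layer; the PReluPlugin path stays behind the #else for older TensorRT builds. Below is a minimal sketch of the same construction against the raw TensorRT C++ API, assuming a per-channel alpha; the names AddNativePRelu, network, input, alpha_data, and channels are illustrative, not from the patch.

#include <NvInfer.h>

// Sketch only: what the TRT >= 7.0 branch above does, expressed with the
// raw TensorRT API instead of Paddle's TRT_ENGINE_ADD_LAYER macro.
nvinfer1::ILayer* AddNativePRelu(nvinfer1::INetworkDefinition* network,
                                 nvinfer1::ITensor* input,
                                 const float* alpha_data, int channels) {
  // Host-side alpha weights; TensorRT reads them at engine-build time, so
  // the buffer must outlive building (hence SetWeights in the converter).
  nvinfer1::Weights alpha{nvinfer1::DataType::kFLOAT, alpha_data, channels};

  // Shape [channels, 1, 1, ...] padded to the input rank so the slope
  // tensor broadcasts per channel, mirroring the dims loops in the diff.
  nvinfer1::Dims dims;
  dims.nbDims = input->getDimensions().nbDims;
  for (int i = 0; i < dims.nbDims; ++i) dims.d[i] = 1;
  dims.d[0] = channels;  // assumes channel-wise alpha in the leading dim

  nvinfer1::ITensor* slope = network->addConstant(dims, alpha)->getOutput(0);
  // y = x for x > 0, y = slope * x otherwise; replaces PReluPlugin.
  return network->addParametricReLU(*input, *slope);
}

Using the built-in ParametricReLU layer lets TensorRT fuse and tune the activation like any other layer during engine building, which is the usual motivation for replacing a custom plugin with a native layer.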