未验证 提交 d38dde68 编写于 作者: C chen 提交者: GitHub

add pad genetic plugin (#56037)

上级 e63297f5
......@@ -772,6 +772,30 @@ nvinfer1::DimsExprs Conv2dTransposeInferMeta(
return VecExprWrapper2DimsExprs(output_dims_wrap);
}
nvinfer1::DimsExprs PadInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder, // NOLINT
const framework::OpDesc& op_desc) {
const auto x_dims = inputs[0];
auto paddings =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
nvinfer1::DimsExprs output;
output.nbDims = x_dims.nbDims;
for (int i = 0; i < x_dims.nbDims; ++i) {
output.d[i] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*x_dims.d[i],
*expr_builder.constant(paddings[2 * i])),
*expr_builder.constant(paddings[2 * i + 1]));
}
return output;
}
PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta);
......@@ -785,7 +809,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(conv2d_fusion, Conv2dFusionInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(conv2d, Conv2dFusionInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(conv2d_transpose, Conv2dTransposeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(p_norm, PNormInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad, PadInferMeta);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -32,6 +32,7 @@ USE_TRT_DYNAMIC_INFER_META_FN(conv2d_fusion);
USE_TRT_DYNAMIC_INFER_META_FN(conv2d);
USE_TRT_DYNAMIC_INFER_META_FN(conv2d_transpose);
USE_TRT_DYNAMIC_INFER_META_FN(p_norm);
USE_TRT_DYNAMIC_INFER_META_FN(pad);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册