From 2de0b58e383b9e9fddef23041ac8470e3191abd6 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Fri, 15 Oct 2021 14:23:54 +0800 Subject: [PATCH] feat: Add TRT support for 3D(batch_norm_op and elementwise_add_op) (#36446) --- paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc | 7 ++++--- paddle/fluid/inference/tensorrt/convert/elementwise_op.cc | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc index 7ea41839cb9..71a2fa68f17 100644 --- a/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc @@ -147,9 +147,10 @@ class BatchNormOpConverter : public OpConverter { X = expand_layer->getOutput(0); } - layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(), - scale_weights.get(), power_weights.get()); + layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *X, + nvinfer1::ScaleMode::kCHANNEL, + shift_weights.get(), scale_weights.get(), + power_weights.get(), dynamic_shape_offset); auto output_name = op_desc.Output("Y").front(); engine_->SetWeights(op_desc.Input("Bias").front(), diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 2f802ea8d18..8569dd63478 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -83,8 +83,8 @@ class ElementwiseWeightOpConverter : public OpConverter { } if (op_type_ == "add") { nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, scale_mode, shift_weights.get(), - scale_weights.get(), power_weights.get()); + engine_, ScaleNd, *X, scale_mode, shift_weights.get(), + scale_weights.get(), power_weights.get(), dynamic_shape_offset); layer = scale_layer; } else if (op_type_ == "mul") { nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( -- GitLab