未验证 提交 2de0b58e 编写于 作者: F feng_shuai 提交者: GitHub

feat: Add TRT support for 3D(batch_norm_op and elementwise_add_op) (#36446)

上级 277c9a55
...@@ -147,9 +147,10 @@ class BatchNormOpConverter : public OpConverter { ...@@ -147,9 +147,10 @@ class BatchNormOpConverter : public OpConverter {
X = expand_layer->getOutput(0); X = expand_layer->getOutput(0);
} }
layer = TRT_ENGINE_ADD_LAYER( layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *X,
engine_, Scale, *X, nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(), nvinfer1::ScaleMode::kCHANNEL,
scale_weights.get(), power_weights.get()); shift_weights.get(), scale_weights.get(),
power_weights.get(), dynamic_shape_offset);
auto output_name = op_desc.Output("Y").front(); auto output_name = op_desc.Output("Y").front();
engine_->SetWeights(op_desc.Input("Bias").front(), engine_->SetWeights(op_desc.Input("Bias").front(),
......
...@@ -83,8 +83,8 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -83,8 +83,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
} }
if (op_type_ == "add") { if (op_type_ == "add") {
nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *X, scale_mode, shift_weights.get(), engine_, ScaleNd, *X, scale_mode, shift_weights.get(),
scale_weights.get(), power_weights.get()); scale_weights.get(), power_weights.get(), dynamic_shape_offset);
layer = scale_layer; layer = scale_layer;
} else if (op_type_ == "mul") { } else if (op_type_ == "mul") {
nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册