未验证 提交 37ac0dda 编写于 作者: F feng_shuai 提交者: GitHub

feat: Add TRT support for 3D(batch_norm_op and elementwise_add_op) (#36446) (#36702)

上级 59615fff
......@@ -147,9 +147,10 @@ class BatchNormOpConverter : public OpConverter {
X = expand_layer->getOutput(0);
}
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *X, nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(),
scale_weights.get(), power_weights.get());
layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *X,
nvinfer1::ScaleMode::kCHANNEL,
shift_weights.get(), scale_weights.get(),
power_weights.get(), dynamic_shape_offset);
auto output_name = op_desc.Output("Y").front();
engine_->SetWeights(op_desc.Input("Bias").front(),
......
......@@ -83,8 +83,8 @@ class ElementwiseWeightOpConverter : public OpConverter {
}
if (op_type_ == "add") {
nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *X, scale_mode, shift_weights.get(),
scale_weights.get(), power_weights.get());
engine_, ScaleNd, *X, scale_mode, shift_weights.get(),
scale_weights.get(), power_weights.get(), dynamic_shape_offset);
layer = scale_layer;
} else if (op_type_ == "mul") {
nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册