From c4e54acd99973d9d470cca58c32ed97958f94248 Mon Sep 17 00:00:00 2001 From: weishengying <63448337+weishengying@users.noreply.github.com> Date: Wed, 6 Apr 2022 18:40:06 +0800 Subject: [PATCH] add rewrite pattern form paddle op tp trt op (#41323) --- paddle/infrt/dialect/tensorrt/convert.h | 185 +++++++++--------- .../infrt/dialect/tensorrt/pd_lower_to_trt.td | 10 +- .../dialect/tensorrt/trt_op_converter_pass.cc | 129 ++++++++++-- paddle/infrt/dialect/tensorrt/trt_ops.td | 15 +- paddle/infrt/kernel/tensorrt/trt_kernels.cc | 6 + paddle/infrt/kernel/tensorrt/trt_layers.h | 71 +++++++ 6 files changed, 296 insertions(+), 120 deletions(-) diff --git a/paddle/infrt/dialect/tensorrt/convert.h b/paddle/infrt/dialect/tensorrt/convert.h index c1f87ecde7..5b9e4a9074 100644 --- a/paddle/infrt/dialect/tensorrt/convert.h +++ b/paddle/infrt/dialect/tensorrt/convert.h @@ -58,58 +58,110 @@ template <> } static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT - mlir::Operation *op) { + mlir::Operation *op, + mlir::Value input, + mlir::Value filter) { auto conv_op = ::llvm::dyn_cast(op); ::mlir::SmallVector<::mlir::Value, 4> operands; - ::mlir::Operation::operand_range Input = conv_op.getODSOperands(0); - ::mlir::Operation::operand_range Filter = conv_op.getODSOperands(1); - operands.push_back((*Input.begin())); - operands.push_back((*Filter.begin())); + operands.push_back(input); + operands.push_back(filter); ::mlir::SmallVector<::mlir::Type, 4> resultTypes; for (auto v : conv_op.getODSResults(0)) { resultTypes.push_back(v.getType()); } + ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; - { - // TODO(weishengying) : get out_channel_num for filter shape - auto tblgen_attr = rewriter.getSI32IntegerAttr(3); - attributes.emplace_back(rewriter.getStringAttr("out_channel_num"), - tblgen_attr); + + auto *filter_producer = filter.getDefiningOp(); + auto create_inited_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>( + filter_producer); + + CHECK_NOTNULL(create_inited_tensor_op); + mlir::ArrayAttr dims = create_inited_tensor_op.dims(); + CHECK_EQ(dims.size(), 4U); + CHECK(dims[0].getType().isIntOrIndex()); + + const int32_t n_output = dims[0].cast().getInt(); + const int32_t filter_h = dims[2].cast().getInt(); + const int32_t filter_w = dims[3].cast().getInt(); + + auto padding_attr = conv_op->getAttrOfType<::mlir::ArrayAttr>("paddings"); + llvm::SmallVector paddings(padding_attr.size()); + for (size_t i = 0; i < padding_attr.size(); i++) { + paddings[i] = padding_attr[i].cast().getInt(); } - { - // TODO(weishengying) : get kernel_size for filter shape - auto tblgen_attr = rewriter.getI32ArrayAttr({3, 3}); - attributes.emplace_back(rewriter.getStringAttr("kernel_size"), tblgen_attr); + + auto dilations_attr = conv_op->getAttrOfType<::mlir::ArrayAttr>("dilations"); + llvm::SmallVector dilations(dilations_attr.size()); + for (size_t i = 0; i < dilations_attr.size(); i++) { + dilations[i] = dilations_attr[i].cast().getInt(); } - { - auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("strides"); - attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr); + + llvm::SmallVector nv_paddings(2); + llvm::SmallVector nv_pre_paddings(2); + llvm::SmallVector nv_post_paddings(2); + llvm::SmallVector nv_dilations({dilations[0], dilations[1]}); + int32_t nv_padding_mode = 0; // nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN + auto padding_algorithm_attr = + conv_op->getAttrOfType<::mlir::StringAttr>("padding_algorithm"); + if (padding_algorithm_attr.strref() == "VALID") { + for (size_t i = 0; i < paddings.size(); i++) { + paddings[i] = 0; + } } - { - auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("paddings"); - attributes.emplace_back(rewriter.getStringAttr("paddings"), tblgen_attr); + if (padding_algorithm_attr.strref() == "SAME") { + nv_padding_mode = 2; // nvinfer1::PaddingMode::kSAME_UPPER + nv_dilations[0] = 1; + nv_dilations[1] = 1; } - { - auto tblgen_attr = - op->getAttrOfType<::mlir::StringAttr>("padding_algorithm"); - attributes.emplace_back(rewriter.getStringAttr("padding_mode"), - tblgen_attr); + + if (paddings.size() == 2) { + nv_paddings[0] = paddings[0]; + nv_paddings[1] = paddings[1]; + } else { + CHECK_EQ(paddings.size(), 4U); + nv_pre_paddings[0] = paddings[0]; + nv_pre_paddings[1] = paddings[2]; + nv_post_paddings[0] = paddings[1]; + nv_post_paddings[1] = paddings[3]; } + + attributes.emplace_back(rewriter.getStringAttr("out_channel_num"), + rewriter.getSI32IntegerAttr(n_output)); + + attributes.emplace_back(rewriter.getStringAttr("kernel_size"), + rewriter.getI32ArrayAttr({filter_h, filter_w})); + + attributes.emplace_back( + rewriter.getStringAttr("dilations"), + rewriter.getI32ArrayAttr({nv_dilations[0], nv_dilations[1]})); + + attributes.emplace_back(rewriter.getStringAttr("padding_mode"), + rewriter.getSI32IntegerAttr(nv_padding_mode)); + + attributes.emplace_back(rewriter.getStringAttr("paddings"), + rewriter.getI32ArrayAttr({paddings[0], paddings[1]})); + + attributes.emplace_back( + rewriter.getStringAttr("pre_paddings"), + rewriter.getI32ArrayAttr({nv_pre_paddings[0], nv_pre_paddings[1]})); + + attributes.emplace_back( + rewriter.getStringAttr("post_paddings"), + rewriter.getI32ArrayAttr({nv_post_paddings[0], nv_post_paddings[1]})); + { - auto tblgen_attr = op->getAttrOfType<::mlir::IntegerAttr>("groups"); + auto tblgen_attr = conv_op->getAttrOfType<::mlir::IntegerAttr>("groups"); attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr); } { - auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("dilations"); - attributes.emplace_back(rewriter.getStringAttr("dilations"), tblgen_attr); - } - { - auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("data_format"); - attributes.emplace_back(rewriter.getStringAttr("data_format"), tblgen_attr); + auto tblgen_attr = conv_op->getAttrOfType<::mlir::ArrayAttr>("strides"); + attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr); } return rewriter.create( - op->getLoc(), resultTypes, operands, attributes); + conv_op->getLoc(), resultTypes, operands, attributes); } static inline mlir::ArrayAttr TransposeWeight( @@ -193,51 +245,6 @@ inline ::llvm::SmallVector<::mlir::Value, 4> createTrtFcOp( return tblgen_repl_values; } -static mlir::Value createTRTShuffledOp( - mlir::PatternRewriter &rewriter, // NOLINT - mlir::Operation *op, - const mlir::Value &input, - const mlir::Attribute &start, - const mlir::Attribute &stop) { - auto flatten_op = ::llvm::dyn_cast(op); - ::mlir::SmallVector<::mlir::Value, 4> operands; - operands.push_back(input); - - ::mlir::SmallVector<::mlir::Type, 4> resultTypes; - for (auto v : flatten_op.getODSResults(0)) { - resultTypes.push_back(v.getType()); - } - - ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; - mlir::IntegerAttr start_attr = start.dyn_cast(); - mlir::IntegerAttr stop_attr = stop.dyn_cast(); - - int start_axis = start_attr.getSInt(); - int stop_axis = stop_attr.getSInt(); - // TODO(weishengying) : get dim form DenseTonsor - int dims = 4; - // TODO(weishengying) : get input_dims form DenseTonsor - int input_dims[4] = {1, 2048, 1, 1}; - int dim_prod = 1; - - std::vector flatten_dim(dims - (stop_axis - start_axis)); - for (int i = 0, j = 0; i < dims; ++i) { - if (start_axis <= i + 1 && i + 1 <= stop_axis) { - int dim_i = input_dims[i]; - dim_prod *= dim_i; - if (i + 1 == stop_axis) { - flatten_dim[j++] = dim_prod; - } - } else { - flatten_dim[j++] = input_dims[i]; - } - } - auto reshape_arrt = rewriter.getI32ArrayAttr(flatten_dim); - attributes.emplace_back(rewriter.getStringAttr("reshape"), reshape_arrt); - return rewriter.create( - op->getLoc(), resultTypes, operands, attributes); -} - inline mlir::IntegerAttr CreatePoolingType( mlir::PatternRewriter &builder, // NOLINT mlir::StringAttr pool_type) { @@ -339,17 +346,17 @@ inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp( PoolingOp pool_op; { auto ods_loc = builder.getFusedLoc({input_producer->getLoc()}); - builder.create(ods_loc, - input.getType(), - input, - pool_type_attr, - ksize, - strides, - paddings_attr, - padding_mode_attr, - exclusive, - adaptive, - padding_algorithm); + pool_op = builder.create(ods_loc, + input.getType(), + input, + pool_type_attr, + ksize, + strides, + paddings_attr, + padding_mode_attr, + exclusive, + adaptive, + padding_algorithm); } for (auto v : diff --git a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td index 227b473c3f..0cd100aa5b 100644 --- a/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td +++ b/paddle/infrt/dialect/tensorrt/pd_lower_to_trt.td @@ -25,11 +25,11 @@ def PD2TRT_Relu6_Lower : Pat< (PD_Relu6Op $X, $threshold), (TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>; -def createTRTConv2dOp : NativeCodeCall<"createTRTConv2dOp($_builder, $0.getDefiningOp())">; +def createTRTConv2dOp : NativeCodeCall<"createTRTConv2dOp($_builder, $0.getDefiningOp(), $1, $2)">; def PD2TRT_Conv2d_Lower : Pat< (PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format), - (createTRTConv2dOp $old_value)>; + (createTRTConv2dOp $old_value, $Input, $Filter)>; def createTrtPoolingOp : NativeCodeCall<"::infrt::trt::CreatePaddleTrtPoolingOp($_builder, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10)">; def PD2TRT_Pooling_Lower : Pat< @@ -50,9 +50,7 @@ def PD2TRT_Fc_Lower : Pat< (PD_Elementwise_addOp:$elt_out (PD_Matmul_v2Op $X, $Y, $trans_x, $trans_y), $elt_y, $axis), (createTrtFcOp $X, $Y, $elt_y, $elt_out)>; -def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">; - def PD2TRT_Flatten_contiguous_range_Lower : Pat< - (PD_Flatten_contiguous_rangeOp:$out $input, $start_axis, $end_axis), - (createTRTShuffledOp $out, $input, $start_axis, $end_axis)>; + (PD_Flatten_contiguous_rangeOp $input, $start_axis, $end_axis), + (TRT_ShuffleOp $input, $start_axis, $end_axis)>; #endif // PD_LOWER_TO_TRT diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc index 95dd31fcd5..5273bcaa6a 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc @@ -92,20 +92,122 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { ::mlir::Operation::operand_range Input = casted_op.getODSOperands(0); ::mlir::Operation::operand_range Scale = casted_op.getODSOperands(1); ::mlir::Operation::operand_range Bias = casted_op.getODSOperands(2); + ::mlir::Operation::operand_range Mean = casted_op.getODSOperands(3); + ::mlir::Operation::operand_range Variance = casted_op.getODSOperands(4); + operands.push_back(Input[0]); + operands.push_back(Bias[0]); + operands.push_back(Scale[0]); // TODO(weishengying) : recompute this via params - operands.push_back((*Input.begin())); - operands.push_back((*Scale.begin())); - operands.push_back((*Bias.begin())); - operands.push_back((*Bias.begin())); + auto *scale_producer = Scale[0].getDefiningOp(); + auto create_scale_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>( + scale_producer); + CHECK_NOTNULL(create_scale_tensor_op); - trt::ScaleNdOp scaleNd_op; - // inputs - ::mlir::SmallVector<::mlir::Value, 4> trt_inputs; - for (auto v : operands) { - trt_inputs.push_back(v); + auto *bias_producer = Bias[0].getDefiningOp(); + auto create_bias_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>( + bias_producer); + CHECK_NOTNULL(create_bias_tensor_op); + + auto *mean_producer = Mean[0].getDefiningOp(); + auto create_mean_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>( + mean_producer); + CHECK_NOTNULL(create_mean_tensor_op); + + auto *variance_producer = Variance[0].getDefiningOp(); + auto create_variance_tensor_op = + llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>( + variance_producer); + CHECK_NOTNULL(create_variance_tensor_op); + + llvm::SmallVector scale_data; + mlir::ArrayAttr scale_array_attr = create_scale_tensor_op.values(); + CHECK_GT(scale_array_attr.size(), 0U); + CHECK(scale_array_attr[0].getType().isF32()); + scale_data.resize(scale_array_attr.size()); + for (size_t i = 0; i < scale_array_attr.size(); i++) { + scale_data[i] = + scale_array_attr[i].cast().getValueAsDouble(); + } + + llvm::SmallVector bias_data; + mlir::ArrayAttr bias_array_attr = create_bias_tensor_op.values(); + CHECK_GT(bias_array_attr.size(), 0U); + CHECK(bias_array_attr[0].getType().isF32()); + bias_data.resize(bias_array_attr.size()); + for (size_t i = 0; i < bias_array_attr.size(); i++) { + bias_data[i] = + bias_array_attr[i].cast().getValueAsDouble(); } + llvm::SmallVector mean_data; + mlir::ArrayAttr mean_array_attr = create_mean_tensor_op.values(); + CHECK_GT(mean_array_attr.size(), 0U); + CHECK(mean_array_attr[0].getType().isF32()); + mean_data.resize(mean_array_attr.size()); + for (size_t i = 0; i < mean_array_attr.size(); i++) { + mean_data[i] = + mean_array_attr[i].cast().getValueAsDouble(); + } + + llvm::SmallVector variance_data; + mlir::ArrayAttr variance_array_attr = create_variance_tensor_op.values(); + CHECK_GT(variance_array_attr.size(), 0U); + CHECK(variance_array_attr[0].getType().isF32()); + variance_data.resize(variance_array_attr.size()); + for (size_t i = 0; i < variance_array_attr.size(); i++) { + variance_data[i] = + variance_array_attr[i].cast().getValueAsDouble(); + } + + double eps = casted_op.epsilonAttr().getValueAsDouble(); + + llvm::SmallVector combile_scale_data; + combile_scale_data.resize(scale_data.size()); + llvm::SmallVector combile_bias_data; + combile_bias_data.resize(bias_data.size()); + + size_t ele_num = combile_scale_data.size(); + for (size_t i = 0; i < ele_num; i++) { + float scale = scale_data[i]; + float bias = bias_data[i]; + float mean = mean_data[i]; + float variance = variance_data[i]; + combile_scale_data[i] = scale / sqrtf(variance + eps); + combile_bias_data[i] = bias - mean * combile_scale_data[i]; + } + + rewriter.setInsertionPoint(create_scale_tensor_op); + auto new_scale_op = + rewriter.create<::infrt::phi::CreateHostInitedDenseTensorOp>( + create_scale_tensor_op->getLoc(), + create_scale_tensor_op.output().getType(), + create_scale_tensor_op.context(), + create_bias_tensor_op.dims(), + ::infrt::LayoutAttr::get(rewriter.getContext(), + ::infrt::LayoutType::NCHW), + create_scale_tensor_op.lod(), + rewriter.getF32ArrayAttr(combile_scale_data)); + rewriter.replaceOp(create_scale_tensor_op, new_scale_op->getResults()); + + rewriter.setInsertionPoint(create_bias_tensor_op); + auto new_bias_op = + rewriter.create<::infrt::phi::CreateHostInitedDenseTensorOp>( + create_bias_tensor_op->getLoc(), + create_bias_tensor_op.output().getType(), + create_bias_tensor_op.context(), + create_bias_tensor_op.dims(), + ::infrt::LayoutAttr::get(rewriter.getContext(), + ::infrt::LayoutType::NCHW), + create_bias_tensor_op.lod(), + rewriter.getF32ArrayAttr(combile_bias_data)); + rewriter.replaceOp(create_bias_tensor_op, new_bias_op->getResults()); + + rewriter.setInsertionPoint(op); + trt::ScaleNdOp scaleNd_op; // resultTypes ::mlir::SmallVector<::mlir::Type, 4> resultTypes; for (auto v : casted_op.getODSResults(0)) { @@ -114,15 +216,6 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { // attributes ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; - { - auto mode_attr = rewriter.getI32IntegerAttr(1); - attributes.emplace_back(rewriter.getStringAttr("mode"), mode_attr); - } - - { - auto axis_attr = rewriter.getI32IntegerAttr(-1); - attributes.emplace_back(rewriter.getStringAttr("axis"), axis_attr); - } auto result = rewriter .create( op->getLoc(), resultTypes, operands, attributes) diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.td b/paddle/infrt/dialect/tensorrt/trt_ops.td index 68a593e440..b112cc748e 100755 --- a/paddle/infrt/dialect/tensorrt/trt_ops.td +++ b/paddle/infrt/dialect/tensorrt/trt_ops.td @@ -81,7 +81,9 @@ def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> { I32ArrayAttr:$kernel_size, I32ArrayAttr:$strides, I32ArrayAttr:$paddings, - StrAttr:$padding_mode, + I32ArrayAttr:$pre_paddings, + I32ArrayAttr:$post_paddings, + DefaultValuedAttr:$padding_mode, //kEXPLICIT_ROUND_DOWN SI32Attr:$groups, I32ArrayAttr:$dilations ); @@ -97,11 +99,11 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> { }]; let arguments = (ins DenseTensor:$input_tensor, - I32Attr:$pool_type, + SI32Attr:$pool_type, I32ArrayAttr:$window_size, I32ArrayAttr:$strides, I32ArrayAttr:$paddings, - I32Attr:$padding_mode, + SI32Attr:$padding_mode, BoolAttr:$exclusive, BoolAttr:$adaptive, StrAttr:$padding_algorithm @@ -195,11 +197,9 @@ def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> { }]; let arguments = (ins DenseTensor:$input_tensor, - I32Attr:$mode, DenseTensor:$shift, DenseTensor:$scale, - DenseTensor:$power, - I32Attr:$axis + Optional:$power ); let results = (outs DenseTensor:$Out); @@ -214,7 +214,8 @@ def TRT_ShuffleOp : TRT_Op<"Shuffle", [NoSideEffect]> { }]; let arguments = (ins DenseTensor:$input_tensor, - I32ArrayAttr:$reshape + DefaultValuedAttr:$start_axis, + DefaultValuedAttr:$stop_axis ); let results = (outs DenseTensor:$Out); diff --git a/paddle/infrt/kernel/tensorrt/trt_kernels.cc b/paddle/infrt/kernel/tensorrt/trt_kernels.cc index 92e3a624bb..9b7fb20009 100644 --- a/paddle/infrt/kernel/tensorrt/trt_kernels.cc +++ b/paddle/infrt/kernel/tensorrt/trt_kernels.cc @@ -141,6 +141,12 @@ namespace tensorrt { ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); } else if (trt::PoolingOp op = llvm::dyn_cast(operation)) { PoolFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); + } else if (trt::ShuffleOp op = llvm::dyn_cast(operation)) { + ShuffleFunc( + op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); + } else if (trt::ScaleNdOp op = llvm::dyn_cast(operation)) { + ScaleNdFunc( + op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); } else { CHECK(false) << "not supported operation."; } diff --git a/paddle/infrt/kernel/tensorrt/trt_layers.h b/paddle/infrt/kernel/tensorrt/trt_layers.h index 3a300ad0c1..8c7dd4d813 100644 --- a/paddle/infrt/kernel/tensorrt/trt_layers.h +++ b/paddle/infrt/kernel/tensorrt/trt_layers.h @@ -151,6 +151,77 @@ inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT nvinfer1::ITensor* out_tensor = layer->getOutput(0); value_to_trt_tensor_map[out_repr] = out_tensor; } + +inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT + nvinfer1::INetworkDefinition* network, + ValueToITensorMap& value_to_trt_tensor_map, // NOLINT + ValueToTensorMap& value_to_tensor_map) { // NOLINT + mlir::Value input_tensor_repr = op.input_tensor(); + nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr]; + int dims = input->getDimensions().nbDims; + + int start_axis = op.start_axisAttr().getInt(); + int stop_axis = op.start_axisAttr().getInt(); + + nvinfer1::IShuffleLayer* layer = nullptr; + if (start_axis < 0) start_axis += dims + 1; + if (stop_axis < 0) stop_axis += dims + 1; + + int dim_prod = 1; + nvinfer1::Dims flatten_dim; + flatten_dim.nbDims = dims - (stop_axis - start_axis); + for (int i = 0, j = 0; i < dims; ++i) { + if (start_axis <= i + 1 && i + 1 <= stop_axis) { + int dim_i = input->getDimensions().d[i]; + CHECK_GT(dim_i, 0); + dim_prod *= dim_i; + if (i + 1 == stop_axis) { + flatten_dim.d[j++] = dim_prod; + } + } else { + flatten_dim.d[j++] = input->getDimensions().d[i]; + } + } + layer = network->addShuffle(*value_to_trt_tensor_map[input_tensor_repr]); + CHECK_NOTNULL(layer); + layer->setReshapeDimensions(flatten_dim); + + for (size_t i = 0; i < op->getNumResults(); ++i) { + nvinfer1::ITensor* out_tensor = layer->getOutput(i); + mlir::Value out_value = op->getResult(i); + value_to_trt_tensor_map[out_value] = out_tensor; + } +} + +inline void ScaleNdFunc(trt::ScaleNdOp& op, // NOLINT + nvinfer1::INetworkDefinition* network, + ValueToITensorMap& value_to_trt_tensor_map, // NOLINT + ValueToTensorMap& value_to_tensor_map) { // NOLINT + mlir::Value input_tensor_repr = op.input_tensor(); + nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr]; + + mlir::Value shift_tensor_repr = op.shift(); + nvinfer1::Weights shift = + TensorToWeights(value_to_tensor_map[shift_tensor_repr]); + + mlir::Value scale_tensor_repr = op.scale(); + + nvinfer1::Weights scale = + TensorToWeights(value_to_tensor_map[scale_tensor_repr]); + + nvinfer1::Weights power_weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + nvinfer1::IScaleLayer* layer = nullptr; + layer = network->addScaleNd( + *input, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power_weights, 0); + CHECK_NOTNULL(layer); + + for (size_t i = 0; i < op->getNumResults(); ++i) { + nvinfer1::ITensor* out_tensor = layer->getOutput(i); + mlir::Value out_value = op->getResult(i); + value_to_trt_tensor_map[out_value] = out_tensor; + } +} } // namespace tensorrt } // namespace kernel } // namespace infrt -- GitLab