未验证 提交 c4e54acd 编写于 作者: W weishengying 提交者: GitHub

add rewrite pattern form paddle op tp trt op (#41323)

上级 5c6e4bff
...@@ -58,58 +58,110 @@ template <> ...@@ -58,58 +58,110 @@ template <>
} }
static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT static mlir::Value createTRTConv2dOp(mlir::PatternRewriter &rewriter, // NOLINT
mlir::Operation *op) { mlir::Operation *op,
mlir::Value input,
mlir::Value filter) {
auto conv_op = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op); auto conv_op = ::llvm::dyn_cast<infrt::pd::Conv2dOp>(op);
::mlir::SmallVector<::mlir::Value, 4> operands; ::mlir::SmallVector<::mlir::Value, 4> operands;
::mlir::Operation::operand_range Input = conv_op.getODSOperands(0); operands.push_back(input);
::mlir::Operation::operand_range Filter = conv_op.getODSOperands(1); operands.push_back(filter);
operands.push_back((*Input.begin()));
operands.push_back((*Filter.begin()));
::mlir::SmallVector<::mlir::Type, 4> resultTypes; ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
for (auto v : conv_op.getODSResults(0)) { for (auto v : conv_op.getODSResults(0)) {
resultTypes.push_back(v.getType()); resultTypes.push_back(v.getType());
} }
::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
{
// TODO(weishengying) : get out_channel_num for filter shape auto *filter_producer = filter.getDefiningOp();
auto tblgen_attr = rewriter.getSI32IntegerAttr(3); auto create_inited_tensor_op =
attributes.emplace_back(rewriter.getStringAttr("out_channel_num"), llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(
tblgen_attr); filter_producer);
CHECK_NOTNULL(create_inited_tensor_op);
mlir::ArrayAttr dims = create_inited_tensor_op.dims();
CHECK_EQ(dims.size(), 4U);
CHECK(dims[0].getType().isIntOrIndex());
const int32_t n_output = dims[0].cast<mlir::IntegerAttr>().getInt();
const int32_t filter_h = dims[2].cast<mlir::IntegerAttr>().getInt();
const int32_t filter_w = dims[3].cast<mlir::IntegerAttr>().getInt();
auto padding_attr = conv_op->getAttrOfType<::mlir::ArrayAttr>("paddings");
llvm::SmallVector<int32_t, 4> paddings(padding_attr.size());
for (size_t i = 0; i < padding_attr.size(); i++) {
paddings[i] = padding_attr[i].cast<mlir::IntegerAttr>().getInt();
} }
{
// TODO(weishengying) : get kernel_size for filter shape auto dilations_attr = conv_op->getAttrOfType<::mlir::ArrayAttr>("dilations");
auto tblgen_attr = rewriter.getI32ArrayAttr({3, 3}); llvm::SmallVector<int32_t> dilations(dilations_attr.size());
attributes.emplace_back(rewriter.getStringAttr("kernel_size"), tblgen_attr); for (size_t i = 0; i < dilations_attr.size(); i++) {
dilations[i] = dilations_attr[i].cast<mlir::IntegerAttr>().getInt();
} }
{
auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("strides"); llvm::SmallVector<int32_t, 2> nv_paddings(2);
attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr); llvm::SmallVector<int32_t, 4> nv_pre_paddings(2);
llvm::SmallVector<int32_t, 4> nv_post_paddings(2);
llvm::SmallVector<int32_t, 2> nv_dilations({dilations[0], dilations[1]});
int32_t nv_padding_mode = 0; // nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN
auto padding_algorithm_attr =
conv_op->getAttrOfType<::mlir::StringAttr>("padding_algorithm");
if (padding_algorithm_attr.strref() == "VALID") {
for (size_t i = 0; i < paddings.size(); i++) {
paddings[i] = 0;
} }
{
auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("paddings");
attributes.emplace_back(rewriter.getStringAttr("paddings"), tblgen_attr);
} }
{ if (padding_algorithm_attr.strref() == "SAME") {
auto tblgen_attr = nv_padding_mode = 2; // nvinfer1::PaddingMode::kSAME_UPPER
op->getAttrOfType<::mlir::StringAttr>("padding_algorithm"); nv_dilations[0] = 1;
attributes.emplace_back(rewriter.getStringAttr("padding_mode"), nv_dilations[1] = 1;
tblgen_attr);
} }
{
auto tblgen_attr = op->getAttrOfType<::mlir::IntegerAttr>("groups"); if (paddings.size() == 2) {
attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr); nv_paddings[0] = paddings[0];
nv_paddings[1] = paddings[1];
} else {
CHECK_EQ(paddings.size(), 4U);
nv_pre_paddings[0] = paddings[0];
nv_pre_paddings[1] = paddings[2];
nv_post_paddings[0] = paddings[1];
nv_post_paddings[1] = paddings[3];
} }
attributes.emplace_back(rewriter.getStringAttr("out_channel_num"),
rewriter.getSI32IntegerAttr(n_output));
attributes.emplace_back(rewriter.getStringAttr("kernel_size"),
rewriter.getI32ArrayAttr({filter_h, filter_w}));
attributes.emplace_back(
rewriter.getStringAttr("dilations"),
rewriter.getI32ArrayAttr({nv_dilations[0], nv_dilations[1]}));
attributes.emplace_back(rewriter.getStringAttr("padding_mode"),
rewriter.getSI32IntegerAttr(nv_padding_mode));
attributes.emplace_back(rewriter.getStringAttr("paddings"),
rewriter.getI32ArrayAttr({paddings[0], paddings[1]}));
attributes.emplace_back(
rewriter.getStringAttr("pre_paddings"),
rewriter.getI32ArrayAttr({nv_pre_paddings[0], nv_pre_paddings[1]}));
attributes.emplace_back(
rewriter.getStringAttr("post_paddings"),
rewriter.getI32ArrayAttr({nv_post_paddings[0], nv_post_paddings[1]}));
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::ArrayAttr>("dilations"); auto tblgen_attr = conv_op->getAttrOfType<::mlir::IntegerAttr>("groups");
attributes.emplace_back(rewriter.getStringAttr("dilations"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("groups"), tblgen_attr);
} }
{ {
auto tblgen_attr = op->getAttrOfType<::mlir::StringAttr>("data_format"); auto tblgen_attr = conv_op->getAttrOfType<::mlir::ArrayAttr>("strides");
attributes.emplace_back(rewriter.getStringAttr("data_format"), tblgen_attr); attributes.emplace_back(rewriter.getStringAttr("strides"), tblgen_attr);
} }
return rewriter.create<trt::ConvolutionOp>( return rewriter.create<trt::ConvolutionOp>(
op->getLoc(), resultTypes, operands, attributes); conv_op->getLoc(), resultTypes, operands, attributes);
} }
static inline mlir::ArrayAttr TransposeWeight( static inline mlir::ArrayAttr TransposeWeight(
...@@ -193,51 +245,6 @@ inline ::llvm::SmallVector<::mlir::Value, 4> createTrtFcOp( ...@@ -193,51 +245,6 @@ inline ::llvm::SmallVector<::mlir::Value, 4> createTrtFcOp(
return tblgen_repl_values; return tblgen_repl_values;
} }
static mlir::Value createTRTShuffledOp(
mlir::PatternRewriter &rewriter, // NOLINT
mlir::Operation *op,
const mlir::Value &input,
const mlir::Attribute &start,
const mlir::Attribute &stop) {
auto flatten_op = ::llvm::dyn_cast<infrt::pd::Flatten_contiguous_rangeOp>(op);
::mlir::SmallVector<::mlir::Value, 4> operands;
operands.push_back(input);
::mlir::SmallVector<::mlir::Type, 4> resultTypes;
for (auto v : flatten_op.getODSResults(0)) {
resultTypes.push_back(v.getType());
}
::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
mlir::IntegerAttr start_attr = start.dyn_cast<mlir::IntegerAttr>();
mlir::IntegerAttr stop_attr = stop.dyn_cast<mlir::IntegerAttr>();
int start_axis = start_attr.getSInt();
int stop_axis = stop_attr.getSInt();
// TODO(weishengying) : get dim form DenseTonsor
int dims = 4;
// TODO(weishengying) : get input_dims form DenseTonsor
int input_dims[4] = {1, 2048, 1, 1};
int dim_prod = 1;
std::vector<int> flatten_dim(dims - (stop_axis - start_axis));
for (int i = 0, j = 0; i < dims; ++i) {
if (start_axis <= i + 1 && i + 1 <= stop_axis) {
int dim_i = input_dims[i];
dim_prod *= dim_i;
if (i + 1 == stop_axis) {
flatten_dim[j++] = dim_prod;
}
} else {
flatten_dim[j++] = input_dims[i];
}
}
auto reshape_arrt = rewriter.getI32ArrayAttr(flatten_dim);
attributes.emplace_back(rewriter.getStringAttr("reshape"), reshape_arrt);
return rewriter.create<trt::ShuffleOp>(
op->getLoc(), resultTypes, operands, attributes);
}
inline mlir::IntegerAttr CreatePoolingType( inline mlir::IntegerAttr CreatePoolingType(
mlir::PatternRewriter &builder, // NOLINT mlir::PatternRewriter &builder, // NOLINT
mlir::StringAttr pool_type) { mlir::StringAttr pool_type) {
...@@ -339,7 +346,7 @@ inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp( ...@@ -339,7 +346,7 @@ inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp(
PoolingOp pool_op; PoolingOp pool_op;
{ {
auto ods_loc = builder.getFusedLoc({input_producer->getLoc()}); auto ods_loc = builder.getFusedLoc({input_producer->getLoc()});
builder.create<PoolingOp>(ods_loc, pool_op = builder.create<PoolingOp>(ods_loc,
input.getType(), input.getType(),
input, input,
pool_type_attr, pool_type_attr,
......
...@@ -25,11 +25,11 @@ def PD2TRT_Relu6_Lower : Pat< ...@@ -25,11 +25,11 @@ def PD2TRT_Relu6_Lower : Pat<
(PD_Relu6Op $X, $threshold), (PD_Relu6Op $X, $threshold),
(TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>; (TRT_ActivationOp $X, (TRT_createNvinferEnumAttr<"nvinfer1::ActivationType", "kCLIP">), (INFRT_createF32Attr<"0.0">), $threshold)>;
def createTRTConv2dOp : NativeCodeCall<"createTRTConv2dOp($_builder, $0.getDefiningOp())">; def createTRTConv2dOp : NativeCodeCall<"createTRTConv2dOp($_builder, $0.getDefiningOp(), $1, $2)">;
def PD2TRT_Conv2d_Lower : Pat< def PD2TRT_Conv2d_Lower : Pat<
(PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format), (PD_Conv2dOp:$old_value $Input, $Filter, $strides, $paddings, $padding_algorithm, $groups, $dilations, $data_format),
(createTRTConv2dOp $old_value)>; (createTRTConv2dOp $old_value, $Input, $Filter)>;
def createTrtPoolingOp : NativeCodeCall<"::infrt::trt::CreatePaddleTrtPoolingOp($_builder, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10)">; def createTrtPoolingOp : NativeCodeCall<"::infrt::trt::CreatePaddleTrtPoolingOp($_builder, $0, $1, $2, $3, $4, $5, $6, $7, $8, $9, $10)">;
def PD2TRT_Pooling_Lower : Pat< def PD2TRT_Pooling_Lower : Pat<
...@@ -50,9 +50,7 @@ def PD2TRT_Fc_Lower : Pat< ...@@ -50,9 +50,7 @@ def PD2TRT_Fc_Lower : Pat<
(PD_Elementwise_addOp:$elt_out (PD_Matmul_v2Op $X, $Y, $trans_x, $trans_y), $elt_y, $axis), (PD_Elementwise_addOp:$elt_out (PD_Matmul_v2Op $X, $Y, $trans_x, $trans_y), $elt_y, $axis),
(createTrtFcOp $X, $Y, $elt_y, $elt_out)>; (createTrtFcOp $X, $Y, $elt_y, $elt_out)>;
def createTRTShuffledOp : NativeCodeCall<"createTRTShuffledOp($_builder, $0.getDefiningOp(), $1, $2, $3)">;
def PD2TRT_Flatten_contiguous_range_Lower : Pat< def PD2TRT_Flatten_contiguous_range_Lower : Pat<
(PD_Flatten_contiguous_rangeOp:$out $input, $start_axis, $end_axis), (PD_Flatten_contiguous_rangeOp $input, $start_axis, $end_axis),
(createTRTShuffledOp $out, $input, $start_axis, $end_axis)>; (TRT_ShuffleOp $input, $start_axis, $end_axis)>;
#endif // PD_LOWER_TO_TRT #endif // PD_LOWER_TO_TRT
...@@ -92,20 +92,122 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { ...@@ -92,20 +92,122 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern {
::mlir::Operation::operand_range Input = casted_op.getODSOperands(0); ::mlir::Operation::operand_range Input = casted_op.getODSOperands(0);
::mlir::Operation::operand_range Scale = casted_op.getODSOperands(1); ::mlir::Operation::operand_range Scale = casted_op.getODSOperands(1);
::mlir::Operation::operand_range Bias = casted_op.getODSOperands(2); ::mlir::Operation::operand_range Bias = casted_op.getODSOperands(2);
::mlir::Operation::operand_range Mean = casted_op.getODSOperands(3);
::mlir::Operation::operand_range Variance = casted_op.getODSOperands(4);
operands.push_back(Input[0]);
operands.push_back(Bias[0]);
operands.push_back(Scale[0]);
// TODO(weishengying) : recompute this via params // TODO(weishengying) : recompute this via params
operands.push_back((*Input.begin())); auto *scale_producer = Scale[0].getDefiningOp();
operands.push_back((*Scale.begin())); auto create_scale_tensor_op =
operands.push_back((*Bias.begin())); llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(
operands.push_back((*Bias.begin())); scale_producer);
CHECK_NOTNULL(create_scale_tensor_op);
trt::ScaleNdOp scaleNd_op; auto *bias_producer = Bias[0].getDefiningOp();
// inputs auto create_bias_tensor_op =
::mlir::SmallVector<::mlir::Value, 4> trt_inputs; llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(
for (auto v : operands) { bias_producer);
trt_inputs.push_back(v); CHECK_NOTNULL(create_bias_tensor_op);
auto *mean_producer = Mean[0].getDefiningOp();
auto create_mean_tensor_op =
llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(
mean_producer);
CHECK_NOTNULL(create_mean_tensor_op);
auto *variance_producer = Variance[0].getDefiningOp();
auto create_variance_tensor_op =
llvm::dyn_cast<::infrt::phi::CreateHostInitedDenseTensorOp>(
variance_producer);
CHECK_NOTNULL(create_variance_tensor_op);
llvm::SmallVector<double> scale_data;
mlir::ArrayAttr scale_array_attr = create_scale_tensor_op.values();
CHECK_GT(scale_array_attr.size(), 0U);
CHECK(scale_array_attr[0].getType().isF32());
scale_data.resize(scale_array_attr.size());
for (size_t i = 0; i < scale_array_attr.size(); i++) {
scale_data[i] =
scale_array_attr[i].cast<mlir::FloatAttr>().getValueAsDouble();
}
llvm::SmallVector<double> bias_data;
mlir::ArrayAttr bias_array_attr = create_bias_tensor_op.values();
CHECK_GT(bias_array_attr.size(), 0U);
CHECK(bias_array_attr[0].getType().isF32());
bias_data.resize(bias_array_attr.size());
for (size_t i = 0; i < bias_array_attr.size(); i++) {
bias_data[i] =
bias_array_attr[i].cast<mlir::FloatAttr>().getValueAsDouble();
} }
llvm::SmallVector<double> mean_data;
mlir::ArrayAttr mean_array_attr = create_mean_tensor_op.values();
CHECK_GT(mean_array_attr.size(), 0U);
CHECK(mean_array_attr[0].getType().isF32());
mean_data.resize(mean_array_attr.size());
for (size_t i = 0; i < mean_array_attr.size(); i++) {
mean_data[i] =
mean_array_attr[i].cast<mlir::FloatAttr>().getValueAsDouble();
}
llvm::SmallVector<double> variance_data;
mlir::ArrayAttr variance_array_attr = create_variance_tensor_op.values();
CHECK_GT(variance_array_attr.size(), 0U);
CHECK(variance_array_attr[0].getType().isF32());
variance_data.resize(variance_array_attr.size());
for (size_t i = 0; i < variance_array_attr.size(); i++) {
variance_data[i] =
variance_array_attr[i].cast<mlir::FloatAttr>().getValueAsDouble();
}
double eps = casted_op.epsilonAttr().getValueAsDouble();
llvm::SmallVector<float> combile_scale_data;
combile_scale_data.resize(scale_data.size());
llvm::SmallVector<float> combile_bias_data;
combile_bias_data.resize(bias_data.size());
size_t ele_num = combile_scale_data.size();
for (size_t i = 0; i < ele_num; i++) {
float scale = scale_data[i];
float bias = bias_data[i];
float mean = mean_data[i];
float variance = variance_data[i];
combile_scale_data[i] = scale / sqrtf(variance + eps);
combile_bias_data[i] = bias - mean * combile_scale_data[i];
}
rewriter.setInsertionPoint(create_scale_tensor_op);
auto new_scale_op =
rewriter.create<::infrt::phi::CreateHostInitedDenseTensorOp>(
create_scale_tensor_op->getLoc(),
create_scale_tensor_op.output().getType(),
create_scale_tensor_op.context(),
create_bias_tensor_op.dims(),
::infrt::LayoutAttr::get(rewriter.getContext(),
::infrt::LayoutType::NCHW),
create_scale_tensor_op.lod(),
rewriter.getF32ArrayAttr(combile_scale_data));
rewriter.replaceOp(create_scale_tensor_op, new_scale_op->getResults());
rewriter.setInsertionPoint(create_bias_tensor_op);
auto new_bias_op =
rewriter.create<::infrt::phi::CreateHostInitedDenseTensorOp>(
create_bias_tensor_op->getLoc(),
create_bias_tensor_op.output().getType(),
create_bias_tensor_op.context(),
create_bias_tensor_op.dims(),
::infrt::LayoutAttr::get(rewriter.getContext(),
::infrt::LayoutType::NCHW),
create_bias_tensor_op.lod(),
rewriter.getF32ArrayAttr(combile_bias_data));
rewriter.replaceOp(create_bias_tensor_op, new_bias_op->getResults());
rewriter.setInsertionPoint(op);
trt::ScaleNdOp scaleNd_op;
// resultTypes // resultTypes
::mlir::SmallVector<::mlir::Type, 4> resultTypes; ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
for (auto v : casted_op.getODSResults(0)) { for (auto v : casted_op.getODSResults(0)) {
...@@ -114,15 +216,6 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { ...@@ -114,15 +216,6 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern {
// attributes // attributes
::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
{
auto mode_attr = rewriter.getI32IntegerAttr(1);
attributes.emplace_back(rewriter.getStringAttr("mode"), mode_attr);
}
{
auto axis_attr = rewriter.getI32IntegerAttr(-1);
attributes.emplace_back(rewriter.getStringAttr("axis"), axis_attr);
}
auto result = rewriter auto result = rewriter
.create<trt::ScaleNdOp>( .create<trt::ScaleNdOp>(
op->getLoc(), resultTypes, operands, attributes) op->getLoc(), resultTypes, operands, attributes)
......
...@@ -81,7 +81,9 @@ def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> { ...@@ -81,7 +81,9 @@ def TRT_ConvolutionOp : TRT_Op<"Convolution", [NoSideEffect]> {
I32ArrayAttr:$kernel_size, I32ArrayAttr:$kernel_size,
I32ArrayAttr:$strides, I32ArrayAttr:$strides,
I32ArrayAttr:$paddings, I32ArrayAttr:$paddings,
StrAttr:$padding_mode, I32ArrayAttr:$pre_paddings,
I32ArrayAttr:$post_paddings,
DefaultValuedAttr<SI32Attr, "0">:$padding_mode, //kEXPLICIT_ROUND_DOWN
SI32Attr:$groups, SI32Attr:$groups,
I32ArrayAttr:$dilations I32ArrayAttr:$dilations
); );
...@@ -97,11 +99,11 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> { ...@@ -97,11 +99,11 @@ def TRT_PoolingOp : TRT_Op<"Pooling", [NoSideEffect]> {
}]; }];
let arguments = (ins let arguments = (ins
DenseTensor:$input_tensor, DenseTensor:$input_tensor,
I32Attr:$pool_type, SI32Attr:$pool_type,
I32ArrayAttr:$window_size, I32ArrayAttr:$window_size,
I32ArrayAttr:$strides, I32ArrayAttr:$strides,
I32ArrayAttr:$paddings, I32ArrayAttr:$paddings,
I32Attr:$padding_mode, SI32Attr:$padding_mode,
BoolAttr:$exclusive, BoolAttr:$exclusive,
BoolAttr:$adaptive, BoolAttr:$adaptive,
StrAttr:$padding_algorithm StrAttr:$padding_algorithm
...@@ -195,11 +197,9 @@ def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> { ...@@ -195,11 +197,9 @@ def TRT_ScaleNdOp : TRT_Op<"ScaleNd", [NoSideEffect]> {
}]; }];
let arguments = (ins let arguments = (ins
DenseTensor:$input_tensor, DenseTensor:$input_tensor,
I32Attr:$mode,
DenseTensor:$shift, DenseTensor:$shift,
DenseTensor:$scale, DenseTensor:$scale,
DenseTensor:$power, Optional<DenseTensor>:$power
I32Attr:$axis
); );
let results = (outs DenseTensor:$Out); let results = (outs DenseTensor:$Out);
...@@ -214,7 +214,8 @@ def TRT_ShuffleOp : TRT_Op<"Shuffle", [NoSideEffect]> { ...@@ -214,7 +214,8 @@ def TRT_ShuffleOp : TRT_Op<"Shuffle", [NoSideEffect]> {
}]; }];
let arguments = (ins let arguments = (ins
DenseTensor:$input_tensor, DenseTensor:$input_tensor,
I32ArrayAttr:$reshape DefaultValuedAttr<SI32Attr, "1">:$start_axis,
DefaultValuedAttr<SI32Attr, "1">:$stop_axis
); );
let results = (outs DenseTensor:$Out); let results = (outs DenseTensor:$Out);
......
...@@ -141,6 +141,12 @@ namespace tensorrt { ...@@ -141,6 +141,12 @@ namespace tensorrt {
ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); ConvFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
} else if (trt::PoolingOp op = llvm::dyn_cast<trt::PoolingOp>(operation)) { } else if (trt::PoolingOp op = llvm::dyn_cast<trt::PoolingOp>(operation)) {
PoolFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); PoolFunc(op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
} else if (trt::ShuffleOp op = llvm::dyn_cast<trt::ShuffleOp>(operation)) {
ShuffleFunc(
op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
} else if (trt::ScaleNdOp op = llvm::dyn_cast<trt::ScaleNdOp>(operation)) {
ScaleNdFunc(
op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
} else { } else {
CHECK(false) << "not supported operation."; CHECK(false) << "not supported operation.";
} }
......
...@@ -151,6 +151,77 @@ inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT ...@@ -151,6 +151,77 @@ inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT
nvinfer1::ITensor* out_tensor = layer->getOutput(0); nvinfer1::ITensor* out_tensor = layer->getOutput(0);
value_to_trt_tensor_map[out_repr] = out_tensor; value_to_trt_tensor_map[out_repr] = out_tensor;
} }
inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT
nvinfer1::INetworkDefinition* network,
ValueToITensorMap& value_to_trt_tensor_map, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT
mlir::Value input_tensor_repr = op.input_tensor();
nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr];
int dims = input->getDimensions().nbDims;
int start_axis = op.start_axisAttr().getInt();
int stop_axis = op.start_axisAttr().getInt();
nvinfer1::IShuffleLayer* layer = nullptr;
if (start_axis < 0) start_axis += dims + 1;
if (stop_axis < 0) stop_axis += dims + 1;
int dim_prod = 1;
nvinfer1::Dims flatten_dim;
flatten_dim.nbDims = dims - (stop_axis - start_axis);
for (int i = 0, j = 0; i < dims; ++i) {
if (start_axis <= i + 1 && i + 1 <= stop_axis) {
int dim_i = input->getDimensions().d[i];
CHECK_GT(dim_i, 0);
dim_prod *= dim_i;
if (i + 1 == stop_axis) {
flatten_dim.d[j++] = dim_prod;
}
} else {
flatten_dim.d[j++] = input->getDimensions().d[i];
}
}
layer = network->addShuffle(*value_to_trt_tensor_map[input_tensor_repr]);
CHECK_NOTNULL(layer);
layer->setReshapeDimensions(flatten_dim);
for (size_t i = 0; i < op->getNumResults(); ++i) {
nvinfer1::ITensor* out_tensor = layer->getOutput(i);
mlir::Value out_value = op->getResult(i);
value_to_trt_tensor_map[out_value] = out_tensor;
}
}
inline void ScaleNdFunc(trt::ScaleNdOp& op, // NOLINT
nvinfer1::INetworkDefinition* network,
ValueToITensorMap& value_to_trt_tensor_map, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT
mlir::Value input_tensor_repr = op.input_tensor();
nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr];
mlir::Value shift_tensor_repr = op.shift();
nvinfer1::Weights shift =
TensorToWeights(value_to_tensor_map[shift_tensor_repr]);
mlir::Value scale_tensor_repr = op.scale();
nvinfer1::Weights scale =
TensorToWeights(value_to_tensor_map[scale_tensor_repr]);
nvinfer1::Weights power_weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
nvinfer1::IScaleLayer* layer = nullptr;
layer = network->addScaleNd(
*input, nvinfer1::ScaleMode::kCHANNEL, shift, scale, power_weights, 0);
CHECK_NOTNULL(layer);
for (size_t i = 0; i < op->getNumResults(); ++i) {
nvinfer1::ITensor* out_tensor = layer->getOutput(i);
mlir::Value out_value = op->getResult(i);
value_to_trt_tensor_map[out_value] = out_tensor;
}
}
} // namespace tensorrt } // namespace tensorrt
} // namespace kernel } // namespace kernel
} // namespace infrt } // namespace infrt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册