diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index 0081a7d8069f28d907342010d23094775c2f3728..07cd71e6d3679a2e361d4ac59d02fb854aa6a49d 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -68,6 +68,12 @@ class SliceOpConverter : public OpConverter { starts_tensor[axes[i]] = GetEleTensorOfShape( engine_->GetITensor(op_desc.Input("StartsTensor")[0]), i); } + } else if (slice_inputs.find("StartsTensorList") != slice_inputs.end() && + op_desc.Input("StartsTensorList").size()) { + for (size_t i = 0; i < axes.size(); ++i) { + starts_tensor[axes[i]] = + engine_->GetITensor(op_desc.Input("StartsTensorList")[i]); + } } else { PADDLE_ENFORCE_EQ(starts.size(), axes.size(), @@ -97,6 +103,12 @@ class SliceOpConverter : public OpConverter { ends_tensor[axes[i]] = GetEleTensorOfShape( engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i); } + } else if (slice_inputs.find("EndsTensorList") != slice_inputs.end() && + op_desc.Input("EndsTensorList").size()) { + for (size_t i = 0; i < axes.size(); ++i) { + ends_tensor[axes[i]] = + engine_->GetITensor(op_desc.Input("EndsTensorList")[i]); + } } else { PADDLE_ENFORCE_EQ(ends.size(), axes.size(), diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 4417a95cca0ef5b1be91c384990cb01689be091b..b0de399382c0dc09a44d917ca601652230a60eb9 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1381,14 +1381,10 @@ struct SimpleOpTypeSetTeller : public Teller { } } if (slice_inputs.find("StartsTensorList") != slice_inputs.end()) { - if (desc.Input("StartsTensorList").size()) { - return false; - } + VLOG(3) << "The Slice has StartsTensorList input."; } if (slice_inputs.find("EndsTensorList") != slice_inputs.end()) { - if (desc.Input("EndsTensorList").size()) { - return false; - } + VLOG(3) << "The Slice has EndsTensorList input."; } }