未验证 提交 0942f77e 编写于 作者: B bukejiyu 提交者: GitHub

[inference][trt] update roll op 2 gather layer (#53984)

* update roll convert
上级 f276f5d5
......@@ -26,50 +26,67 @@ class RollOpConverter : public OpConverter {
/**
 * Converts a Paddle `roll` op into a chain of TensorRT Gather layers.
 *
 * For every rolled axis an index vector
 *   [start, start + 1, ..., dim - 1, 0, 1, ..., start - 1]
 * with `start = (dim - shift) mod dim` is built at runtime from the input's
 * shape tensor, and the data is gathered along that axis with it. The axes
 * are processed one after another, each Gather consuming the previous
 * Gather's output.
 *
 * @param op        Proto description of the roll op; reads input "X",
 *                  output "Out" and the attributes "axis" and "shifts".
 * @param scope     Unused by this converter.
 * @param test_mode Forwarded unchanged to RreplenishLayerAndOutput.
 */
void operator()(const framework::proto::OpDesc& op,
                const framework::Scope& scope,
                bool test_mode) override {
  VLOG(4) << "convert roll op to tensorrt Gather layer";

  framework::OpDesc op_desc(op, nullptr);
  auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
  std::vector<int64_t> axis =
      PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("axis"));
  std::vector<int64_t> shifts =
      PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shifts"));

  int axis_size = axis.size();
  nvinfer1::ITensor* input_shape_tensor = Shape(input);

  // Tensor fed into the next Gather: the op input first, then the output of
  // the Gather built for the previous axis.
  nvinfer1::ITensor* current = input;
  // NOTE(review): stays nullptr when `axis` is empty; presumably such ops are
  // rejected before reaching this converter -- confirm against the op teller.
  nvinfer1::ILayer* layer = nullptr;

  // NOTE(review): assumes `shifts` has at least as many entries as `axis`.
  for (int i = 0; i < axis_size; i++) {
    auto axi = static_cast<int32_t>(axis[i]);
    auto shift = static_cast<int32_t>(shifts[i]);
    // Runtime extent of the rolled axis, taken from the shape tensor.
    nvinfer1::ITensor* input_axis =
        GetEleTensorOfShape(input_shape_tensor, axi);
    nvinfer1::ITensor* input_shift = Add1DConstantLayer(shift);

    // 1. start = (dim - shift) mod dim, written as x - (x / dim) * dim.
    auto input1 = Sub(input_axis, input_shift);
    auto tmp_div_res = FloorDiv(input1, input_axis);
    auto tmp_prod_res = Prod(tmp_div_res, input_axis);
    auto start = Sub(input1, tmp_prod_res);

    // 2. Add dim and reduce once more so that start cannot be negative.
    start = Sum(start, input_axis);
    auto tmp_div_res1 = FloorDiv(start, input_axis);
    auto tmp_prod_res1 = Prod(tmp_div_res1, input_axis);
    start = Sub(start, tmp_prod_res1);

    auto zero_tensor = Add1DConstantLayer(0);
    auto step = Add1DConstantLayer(1);

    // 3. First index run: (dim - start) values beginning at `start`, step 1.
    //    Fill inputs are 0 = output shape, 1 = first value, 2 = step.
    auto quotient_tensor = FloorDiv(Sub(input_axis, start), step);
    auto* start1 = GetEleTensorOfShape(start, 0, true);
    auto fill_layer0 = TRT_ENGINE_ADD_LAYER(
        engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE);
    fill_layer0->setInput(0, *quotient_tensor);
    fill_layer0->setInput(1, *start1);
    fill_layer0->setInput(2, *step);
    auto* index_tensor0 = fill_layer0->getOutput(0);

    // 4. Second index run: `start` values beginning at 0, step 1.
    quotient_tensor = FloorDiv(Sub(start, zero_tensor), step);
    auto* start2 = Add1DConstantLayer(0, "", true);
    auto fill_layer1 = TRT_ENGINE_ADD_LAYER(
        engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE);
    fill_layer1->setInput(0, *quotient_tensor);
    fill_layer1->setInput(1, *start2);
    fill_layer1->setInput(2, *step);
    auto* index_tensor1 = fill_layer1->getOutput(0);

    // 5. Join both runs and gather along the current axis.
    std::vector<nvinfer1::ITensor*> itensors;
    itensors.push_back(index_tensor0);
    itensors.push_back(index_tensor1);
    nvinfer1::ITensor* concat_input_tensor = Concat(itensors);
    layer = TRT_ENGINE_ADD_LAYER(
        engine_, Gather, *current, *concat_input_tensor, axi);
    current = layer->getOutput(0);
  }

  auto output_name = op_desc.Output("Out")[0];
  RreplenishLayerAndOutput(layer, "roll", {output_name}, test_mode);
}
};
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
from functools import partial
from typing import Any, Dict, List
......@@ -33,8 +34,10 @@ class TrtConvertRollTest(TrtLayerAutoScanTest):
return True
def sample_program_configs(self):
self.trt_param.workspace_size = random.randint(1024, 1 << 30)
def generate_input1(attrs: List[Dict[str, Any]]):
    """Build the float32 input tensor of shape [1, 56, 56, 192] for the roll test.

    The values are random rather than constant: with an all-ones tensor a
    rolled result is indistinguishable from the input, so the comparison
    could not detect a wrong shift.

    Args:
        attrs: Op attribute dictionaries; not read by this generator.

    Returns:
        A numpy float32 array of shape (1, 56, 56, 192).
    """
    return np.random.random([1, 56, 56, 192]).astype(np.float32)
for axis in [[1, 2]]:
for shifts in [[-1, -1], [-3, -3]]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册