Unverified commit 0942f77e, authored by bukejiyu, committed by GitHub

[inference][trt] update roll op to gather layer (#53984)

* update roll convert
Parent f276f5d5
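For context: the patch below rewrites the roll converter so that, instead of emitting a Slice layer in wrap mode, it builds for each rolled axis an index tensor [start, ..., n-1, 0, ..., start-1] with start = (-shift) mod n and feeds it to a Gather layer. A minimal NumPy sketch of that index construction (my own illustration, not code from this commit; roll_as_gather is a hypothetical helper name):

    import numpy as np

    def roll_as_gather(x, shift, axis):
        # Length of the rolled dimension.
        n = x.shape[axis]
        # Non-negative start offset, i.e. (-shift) mod n.
        start = (-shift) % n
        # Indices [start, ..., n-1] followed by [0, ..., start-1].
        indices = np.concatenate([np.arange(start, n), np.arange(0, start)])
        # Gather along the axis with the reordered indices.
        return np.take(x, indices, axis=axis)

    x = np.random.random([1, 56, 56, 192]).astype(np.float32)
    assert np.array_equal(roll_as_gather(x, -3, axis=1), np.roll(x, -3, axis=1))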
@@ -26,50 +26,67 @@ class RollOpConverter : public OpConverter {
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope,
                   bool test_mode) override {
-    VLOG(4) << "convert roll op to tensorrt Slice layer";
+    VLOG(4) << "convert roll op to tensorrt Gather layer";
     framework::OpDesc op_desc(op, nullptr);
     auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
-    nvinfer1::Dims input_dims = input->getDimensions();
     std::vector<int64_t> axis =
         PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("axis"));
     std::vector<int64_t> shifts =
         PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shifts"));
-    nvinfer1::Dims start;
-    start.nbDims = input_dims.nbDims;
-    for (int i = 0; i < start.nbDims; i++) {
-      start.d[i] = 0;
-    }
     int axis_size = axis.size();
+    nvinfer1::ITensor* input_shape_tensor = Shape(input);
+    nvinfer1::ILayer* layer = nullptr;
     for (int i = 0; i < axis_size; i++) {
-      start.d[axis[i]] = (-shifts[i]) % input_dims.d[axis[i]];
+      auto axi = static_cast<int32_t>(axis[i]);
+      auto shift = static_cast<int32_t>(shifts[i]);
+      nvinfer1::ITensor* input_axis =
+          GetEleTensorOfShape(input_shape_tensor, axi);
+      nvinfer1::ITensor* input_shift = Add1DConstantLayer(shift);
+      // 1.sub_value mod input_axis
+      auto input1 = Sub(input_axis, input_shift);
+      auto tmp_div_res = FloorDiv(input1, input_axis);
+      auto tmp_prod_res = Prod(tmp_div_res, input_axis);
+      auto start = Sub(input1, tmp_prod_res);
+      // 2.avoid start less than 0,start mod input_axis
+      start = Sum(start, input_axis);
+      auto tmp_div_res1 = FloorDiv(start, input_axis);
+      auto tmp_prod_res1 = Prod(tmp_div_res1, input_axis);
+      start = Sub(start, tmp_prod_res1);
+      auto zero_tensor = Add1DConstantLayer(0);
+      auto step = Add1DConstantLayer(1);
+      // 3.make index_tensor0
+      auto quotient_tensor = FloorDiv(Sub(input_axis, start), step);
+      auto* start1 = GetEleTensorOfShape(start, 0, true);
+      auto fill_layer0 = TRT_ENGINE_ADD_LAYER(
+          engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE);
+      fill_layer0->setInput(0, *quotient_tensor);
+      fill_layer0->setInput(1, *start1);
+      fill_layer0->setInput(2, *step);
+      auto* index_tensor0 = fill_layer0->getOutput(0);
+      // 4.make index_tensor1
+      quotient_tensor = FloorDiv(Sub(start, zero_tensor), step);
+      auto* start2 = Add1DConstantLayer(0, "", true);
+      auto fill_layer1 = TRT_ENGINE_ADD_LAYER(
+          engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE);
+      fill_layer1->setInput(0, *quotient_tensor);
+      fill_layer1->setInput(1, *start2);
+      fill_layer1->setInput(2, *step);
+      auto* index_tensor1 = fill_layer1->getOutput(0);
+      std::vector<nvinfer1::ITensor*> itensors;
+      itensors.push_back(index_tensor0);
+      itensors.push_back(index_tensor1);
+      nvinfer1::ITensor* concat_input_tensor = Concat(itensors);
+      if (layer == nullptr) {
+        layer = TRT_ENGINE_ADD_LAYER(
+            engine_, Gather, *input, *concat_input_tensor, axi);
+      } else {
+        layer = TRT_ENGINE_ADD_LAYER(
+            engine_, Gather, *layer->getOutput(0), *concat_input_tensor, axi);
+      }
     }
-    nvinfer1::Dims stride;
-    stride.nbDims = input_dims.nbDims;
-    for (int i = 0; i < stride.nbDims; i++) {
-      stride.d[i] = 1;
-    }
-    nvinfer1::Dims size;
-    size.nbDims = input_dims.nbDims;
-    for (int i = 0; i < size.nbDims; i++) {
-      size.d[i] = 1;
-    }
     auto output_name = op_desc.Output("Out")[0];
-    auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
-    auto* layer =
-        TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
-    layer->setInput(2, *shape_layer->getOutput(0));
-#if IS_TRT_VERSION_GE(7000)
-    layer->setMode(nvinfer1::SliceMode::kWRAP);
-#endif
     RreplenishLayerAndOutput(layer, "roll", {output_name}, test_mode);
   }
 };
......
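A note on steps 1 and 2 in the converter above: the patch computes a mod n as a - floor(a / n) * n using the Sub/FloorDiv/Prod helpers, and repeats that reduction once more after adding n so that start is never negative, even for negative shifts; the two Fill layers in kLINSPACE mode then enumerate [start, n) and [0, start), which are concatenated into the Gather indices. A small Python sketch of that start computation (illustrative only; floor_mod and gather_start are hypothetical names):

    def floor_mod(a, b):
        # a mod b expressed with only subtract, floor-divide and multiply,
        # mirroring the Sub/FloorDiv/Prod chain in the converter.
        return a - (a // b) * b

    def gather_start(n, shift):
        # Step 1: (n - shift) mod n; step 2: add n and reduce again so the
        # result is guaranteed non-negative.
        start = floor_mod(n - shift, n)
        return floor_mod(start + n, n)

    assert gather_start(56, -3) == 3    # negative shift wraps to index 3
    assert gather_start(56, 3) == 53    # positive shift starts at n - shift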
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 import unittest
 from functools import partial
 from typing import Any, Dict, List
@@ -33,8 +34,10 @@ class TrtConvertRollTest(TrtLayerAutoScanTest):
         return True

     def sample_program_configs(self):
+        self.trt_param.workspace_size = random.randint(1024, 1 << 30)
         def generate_input1(attrs: List[Dict[str, Any]]):
-            return np.ones([1, 56, 56, 192]).astype(np.float32)
+            return np.random.random([1, 56, 56, 192]).astype(np.float32)
         for axis in [[1, 2]]:
             for shifts in [[-1, -1], [-3, -3]]:
......