Unverified commit a9f3719b, authored by zhoutianzi666, committed by GitHub

remove stack plugin (#44756)

Parent 8e9eea7f
@@ -41,11 +41,10 @@ class StackOpConverter : public OpConverter {
     framework::OpDesc op_desc(op, nullptr);
     auto input = op_desc.Input("X");
     int input_num = input.size();
-    nvinfer1::ITensor** inputs =
-        (nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
+    std::vector<nvinfer1::ITensor*> inputs;
     for (int i = 0; i < input_num; ++i) {
-      inputs[i] = engine_->GetITensor(input[i]);
+      inputs.push_back(engine_->GetITensor(input[i]));
       if (op_desc.HasAttr("out_threshold")) {
         float out_scale =
             PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
@@ -54,28 +53,37 @@ class StackOpConverter : public OpConverter {
     }
     int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
+    int output_rank = inputs[0]->getDimensions().nbDims + 1;
     if (axis < 0) {
-      axis = axis + inputs[0]->getDimensions().nbDims + 1;
+      axis = axis + output_rank;
     }
-    nvinfer1::ILayer* layer = nullptr;
-#if IS_TRT_VERSION_GE(6000)
-    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-    plugin::StackPluginDynamic* plugin =
-        new plugin::StackPluginDynamic(axis, input_num, with_fp16);
-    layer = engine_->AddDynamicPlugin(inputs, input_num, plugin);
-    PADDLE_ENFORCE_NOT_NULL(
-        layer,
-        platform::errors::InvalidArgument(
-            "trt stack layer in converter could not be created."));
-#else
-    PADDLE_THROW(platform::errors::Fatal(
-        "You are running the TRT Dynamic Shape mode, need to confirm that "
-        "your TRT version is no less than 6.0"));
-#endif
+    // Now, axis is relative to output_rank.
+    auto* shape_tensor = Shape(inputs[0]);
+    std::vector<nvinfer1::ITensor*> shape_tensor_vec;
+    for (int i = 0; i < output_rank; i++) {
+      if (i < axis) {
+        shape_tensor_vec.push_back(GetEleTensorOfShape(shape_tensor, i));
+      } else if (i > axis) {
+        shape_tensor_vec.push_back(GetEleTensorOfShape(shape_tensor, i - 1));
+      } else {
+        shape_tensor_vec.push_back(Add1DConstantLayer(1));
+      }
+    }
+    auto* after_shape_tensor = Concat(shape_tensor_vec);
+    for (int i = 0; i < input_num; ++i) {
+      auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[i]);
+      reshape_layer->setInput(1, *after_shape_tensor);
+      inputs[i] = reshape_layer->getOutput(0);
+    }
+    auto* layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Concatenation, inputs.data(), inputs.size());
+    layer->setAxis(axis);
     auto output_name = op_desc.Output("Y").front();
     RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
-    free(inputs);
   }
 };
......
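The converter change above drops the StackPluginDynamic plugin and expresses stack with native TensorRT layers: each input is reshaped (a Shuffle layer driven by a shape tensor with a 1 inserted at `axis`), and the reshaped tensors are joined by a single Concatenation layer. As a side effect, the `std::vector` also removes the old manual malloc/free pairing. A minimal NumPy sketch of the same decomposition, for illustration only (`stack_via_concat` is not part of the patch):

```python
import numpy as np

def stack_via_concat(inputs, axis):
    """Stack expressed the way the rewritten converter lowers it:
    reshape every input to rank+1 with a 1 at `axis`, then concat."""
    output_rank = inputs[0].ndim + 1
    if axis < 0:
        axis += output_rank  # now axis is relative to the output rank
    # Target shape for every input: input shape with a 1 inserted at `axis`
    # (the role of after_shape_tensor / the Shuffle layers in the patch).
    target_shape = list(inputs[0].shape)
    target_shape.insert(axis, 1)
    reshaped = [x.reshape(target_shape) for x in inputs]
    return np.concatenate(reshaped, axis=axis)

# The decomposition matches np.stack for every valid axis.
xs = [np.random.random([2, 3]).astype(np.float32) for _ in range(4)]
for axis in (-3, 0, 1, 2):
    assert np.array_equal(stack_via_concat(xs, axis), np.stack(xs, axis=axis))
```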
@@ -41,33 +41,33 @@ class TrtConvertStackTest(TrtLayerAutoScanTest):
         def generate_input1(attrs: List[Dict[str, Any]], batch):
             if self.dims == 4:
-                return np.ones([batch, 3, 24, 24]).astype(np.float32)
+                return np.random.random([batch, 3, 24, 24]).astype(np.float32)
             elif self.dims == 3:
-                return np.ones([batch, 3, 24]).astype(np.float32)
+                return np.random.random([batch, 3, 24]).astype(np.float32)
             elif self.dims == 2:
-                return np.ones([batch, 24]).astype(np.float32)
+                return np.random.random([batch, 24]).astype(np.float32)
             elif self.dims == 1:
-                return np.ones([24]).astype(np.float32)
+                return np.random.random([24]).astype(np.float32)

         def generate_input2(attrs: List[Dict[str, Any]], batch):
             if self.dims == 4:
-                return np.ones([batch, 3, 24, 24]).astype(np.float32)
+                return np.random.random([batch, 3, 24, 24]).astype(np.float32)
             elif self.dims == 3:
-                return np.ones([batch, 3, 24]).astype(np.float32)
+                return np.random.random([batch, 3, 24]).astype(np.float32)
             elif self.dims == 2:
-                return np.ones([batch, 24]).astype(np.float32)
+                return np.random.random([batch, 24]).astype(np.float32)
             elif self.dims == 1:
-                return np.ones([24]).astype(np.float32)
+                return np.random.random([24]).astype(np.float32)

         def generate_input3(attrs: List[Dict[str, Any]], batch):
             if self.dims == 4:
-                return np.ones([batch, 3, 24, 24]).astype(np.float32)
+                return np.random.random([batch, 3, 24, 24]).astype(np.float32)
             elif self.dims == 3:
-                return np.ones([batch, 3, 24]).astype(np.float32)
+                return np.random.random([batch, 3, 24]).astype(np.float32)
             elif self.dims == 2:
-                return np.ones([batch, 24]).astype(np.float32)
+                return np.random.random([batch, 24]).astype(np.float32)
             elif self.dims == 1:
-                return np.ones([24]).astype(np.float32)
+                return np.random.random([24]).astype(np.float32)

         for dims in [1, 2, 3, 4]:
             for batch in [1, 4]:
......
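The test now feeds random data instead of all-ones tensors. A plausible rationale (the commit message does not spell it out): with constant inputs, any two stack results of the same shape are element-wise identical, so a converter that stacked along the wrong axis or scrambled element order could still pass the accuracy check. A small illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Three 3x3 inputs: stacking along axis 0 or axis 1 both yield shape
# (3, 3, 3). With all-ones data the two results compare equal, so an
# axis mix-up in the converter would go unnoticed.
ones = [np.ones([3, 3], dtype=np.float32) for _ in range(3)]
print(np.array_equal(np.stack(ones, axis=0), np.stack(ones, axis=1)))  # True

# Random data distinguishes the two immediately.
rand = [rng.random([3, 3]).astype(np.float32) for _ in range(3)]
print(np.array_equal(np.stack(rand, axis=0), np.stack(rand, axis=1)))  # False
```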