Unverified · Commit 56ddd7c2 authored by zhoutianzi666, committed by GitHub

remove decrease_axis in op_teller.cc, support them in slice (#43963)

Parent 73f957cf
......@@ -169,7 +169,7 @@ class BatchNormOpConverter : public OpConverter {
engine_->SetWeights(op_desc.Input("Scale").front(),
std::move(combile_scale_tensor));
if (x_dim.nbDims < 3 + dynamic_shape_offset) {
layer->getOutput(0)->setName("batch_norm_out");
layer->getOutput(0)->setName(("BN: ScaleNd: " + output_name).c_str());
layer->setName(("BN: ScaleNd: (Output: " + output_name + ")").c_str());
nvinfer1::Dims squeeze_shape;
squeeze_shape.nbDims = x_dim.nbDims;
......
......@@ -44,6 +44,22 @@ class ElementwiseTensorOpConverter : public OpConverter {
for (int i = 0; i < trt_dims_y.nbDims; i++) {
trt_dims_y.d[i] = dims_y[i];
}
      // This is the special case in which dims_y includes the batch
      // dimension; we need to remove the batch dimension.
if (!engine_->with_dynamic_shape() &&
trt_dims_y.nbDims == (X->getDimensions().nbDims + 1)) {
trt_dims_y.nbDims--;
PADDLE_ENFORCE_EQ(trt_dims_y.d[0],
1,
platform::errors::InvalidArgument(
"Elementwise type(%s) op's Y is a weight "
"including batch dimension. Please "
"check if the 0th dimension equals 1.",
op_type_));
for (int i = 0; i < trt_dims_y.nbDims; i++) {
trt_dims_y.d[i] = trt_dims_y.d[i + 1];
}
}
Y = TRT_ENGINE_ADD_LAYER(engine_, Constant, trt_dims_y, y_weight.get())
->getOutput(0);
} else {
......
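As a minimal sketch of the batch-stripping rule this hunk adds (plain Python; strip_leading_batch_dim is a hypothetical name, not part of the Paddle code): in static-shape mode TensorRT dims exclude the batch, so a weight Y carrying one more dim than the input X must have a leading dim of 1, which is dropped before the Constant layer is built.

def strip_leading_batch_dim(y_dims, x_ndims):
    # Mimic the converter's static-shape fix-up for elementwise weights.
    if len(y_dims) == x_ndims + 1:
        assert y_dims[0] == 1, "0th dimension of Y must equal 1"
        return list(y_dims[1:])
    return list(y_dims)

# Weight of shape (1, 32, 1, 1) against a 3-D (batch-less) TRT input:
print(strip_leading_batch_dim([1, 32, 1, 1], 3))  # -> [32, 1, 1]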
......@@ -166,6 +166,29 @@ class SliceOpConverter : public OpConverter {
}
layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims);
      // Compute the final output dims by dropping every axis listed in
      // decrease_axis from the slice result.
      nvinfer1::Dims real_trt_size_dims;
      real_trt_size_dims.nbDims = 0;
      if (decrease_axises.size() > 0) {
        // In static-shape mode the TensorRT dims exclude the batch
        // dimension, so each decrease axis is shifted down by one.
        for (size_t i = 0; i < decrease_axises.size(); i++) {
          decrease_axises[i]--;
        }
for (int i = 0; i < trt_size_dims.nbDims; i++) {
if (decrease_axises.end() !=
std::find(decrease_axises.begin(), decrease_axises.end(), i))
continue;
real_trt_size_dims.d[real_trt_size_dims.nbDims] = trt_size_dims.d[i];
real_trt_size_dims.nbDims++;
}
        // If every axis was decreased, fall back to a 1-D tensor of size 1.
        if (real_trt_size_dims.nbDims == 0) {
real_trt_size_dims.nbDims = 1;
real_trt_size_dims.d[0] = 1;
}
auto reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
reshape_layer->setReshapeDimensions(real_trt_size_dims);
layer = static_cast<nvinfer1::ILayer*>(reshape_layer);
}
#else
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
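The reshape logic above can be illustrated outside TensorRT (a sketch only; slice_output_dims is a hypothetical name and not part of this commit):

def slice_output_dims(size_dims, decrease_axes, dynamic_shape):
    # Mirror the converter: in static-shape mode the TRT dims exclude the
    # batch, so each decrease axis from the op is shifted down by one.
    if not dynamic_shape:
        decrease_axes = [a - 1 for a in decrease_axes]
    kept = [d for i, d in enumerate(size_dims) if i not in decrease_axes]
    # If every axis was decreased, keep a single dimension of size 1.
    return kept if kept else [1]

# Op axis 2 decreased; TRT dims exclude the batch in static-shape mode:
print(slice_output_dims([16, 1, 32], [2], dynamic_shape=False))  # -> [16, 32]
print(slice_output_dims([4, 16, 1], [2], dynamic_shape=True))    # -> [4, 16]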
......@@ -1217,14 +1217,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
if (desc.HasAttr("decrease_axis")) {
std::vector<int> decrease_axis =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
-      if (with_dynamic_shape) {
-        if (decrease_axis.size() > 1) {
-          return false;
-        }
-      } else {
-        if (decrease_axis.size() > 0) {
-          VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
-                     "is not supported in TensorRT";
+      if (!with_dynamic_shape) {
+        if (decrease_axis.end() !=
+            std::find(decrease_axis.begin(), decrease_axis.end(), 0)) {
          return false;
        }
      }
......
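The relaxed teller rule can be paraphrased as follows (a sketch; slice_is_supported is a hypothetical name): slice with decrease_axis is now generally convertible, and only the static-shape path still rejects decreasing axis 0, since axis 0 is the implicit batch dimension there.

def slice_is_supported(decrease_axis, with_dynamic_shape):
    # Static-shape mode cannot drop the implicit batch axis (axis 0).
    if not with_dynamic_shape and 0 in decrease_axis:
        return False
    return True

print(slice_is_supported([0], False))  # -> False: cannot drop the batch axis
print(slice_is_supported([2], False))  # -> True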
......@@ -21,6 +21,109 @@ from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
# This is a special test case in which the weight includes the batch dimension.
# A dedicated class is used so that the existing test classes stay untouched.
class TrtConvertElementwiseTest_one_input_special_case0(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
def generate_weight():
return np.random.randn(1, 32, 1, 1).astype(np.float32)
for batch in [1, 4]:
for shape in [[batch, 32, 16, 32]]:
for op_type in ["elementwise_add", "elementwise_mul"]:
for axis in [-1]:
self.dims = len(shape)
dics = [{"axis": axis}]
ops_config = [{
"op_type": op_type,
"op_inputs": {
"X": ["input_data"],
"Y": ["weight"]
},
"op_outputs": {
"Out": ["output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"weight":
TensorConfig(data_gen=partial(generate_weight))
},
inputs={
"input_data":
TensorConfig(
data_gen=partial(generate_input, shape)),
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The input.dims[1] must be equal to the weight's length.
if self.dims == 4:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 32, 4, 4]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 32, 32, 32]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [4, 32, 16, 32]
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
class TrtConvertElementwiseTest_one_input(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
......
......@@ -111,13 +111,6 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
inputs = program_config.inputs
-            if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0:
-                return 1, 2
-            if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1:
-                return 0, 3
-            if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0:
-                return 0, 3
if not dynamic_shape:
for x in attrs[0]["axes"]:
if x == 0:
......