Patch: let the TensorRT squeeze2 converter and op teller accept an empty or
missing "axes" attribute by inferring the axes from the size-1 dimensions of
the input, and extend the Python auto-scan test to cover the empty-axes case.

Files touched:
  * paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc
      Reads "axes" only when the attribute exists. When the list is empty,
      every input dimension equal to 1 becomes an axis; a -1 dimension raises
      InvalidArgument. Without dynamic shape the pushed axis is i + 1.
  * paddle/fluid/inference/tensorrt/op_teller.cc
      When "axes" is empty, looks up the shape of input "X" in the op's block
      and infers the axes from it; returns false if any dimension is -1 or if
      no axis can be inferred.
  * python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py
      Generates each program both with and without the "axes" attribute, checks
      validity against self.axes, and no longer enlarges max_shape by one.

Review amendments relative to the originally submitted patch (op_teller.cc
hunk only; its "+" line count in the hunk header was updated accordingly):
  1. The inferred axis is now the dimension index. The original pushed the
     dimension extent (always 1), so the following static-shape check
     "std::find(axes.begin(), axes.end(), 0)" could never see axis 0.
  2. The result of block->FindVar() is checked before it is dereferenced.
  3. The log message for a -1 dimension states the actual reason instead of
     ending in the leftover text "ss ==== -1".

Notes:
  * Template arguments that were missing in the collapsed text
    ("std::vector axes") are written as std::vector<int>.
  * Line breaks and leading whitespace were reconstructed from a collapsed
    copy; apply with whitespace-tolerant options if context does not match.
  * The "index" lines still carry the blob ids of the original submission.

diff --git a/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc b/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc
index 7ab370128bd62d65eb679522b0b7fe21d9a05436..8dd7bfd203798cc2a75d0786caf7be6631d61804 100644
--- a/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc
@@ -32,8 +32,22 @@ class Squeeze2OpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];
 
     // Get Attrs
-    std::vector<int> axes =
-        PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
+    std::vector<int> axes;
+    if (op_desc.HasAttr("axes")) {
+      axes = PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
+    }
+    if (axes.size() == 0) {
+      for (int i = 0; i < input_dims.nbDims; i++) {
+        if (input_dims.d[i] == -1) {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The necessary attributes of the squeeze2 operator axes is "
+              "missing."));
+        } else if (input_dims.d[i] == 1) {
+          axes.push_back(engine_->with_dynamic_shape() ? i : i + 1);
+        }
+      }
+    }
+
     PADDLE_ENFORCE_GT(
         axes.size(),
         0,
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 725e4fd75e9eaa57e6af18601f1325ec1cb67ca6..744324423b66cc21e6b6001e0057a1f0c971b814 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -996,9 +996,36 @@ struct SimpleOpTypeSetTeller : public Teller {
         axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
       }
       if (axes.size() == 0) {
-        VLOG(3) << "The necessary attributes of the squeeze2 operator axes is "
-                   "missing.";
-        return false;
+        auto* block = desc.Block();
+        if (block) {
+          auto input_var_name = desc.Input("X")[0];
+          auto* input_var_desc = block->FindVar(input_var_name);
+          if (input_var_desc == nullptr) {
+            VLOG(3) << "The input variable of the squeeze2 operator is not "
+                       "found in the block.";
+            return false;
+          }
+          const auto input_shape = input_var_desc->GetShape();
+          int axis = 0;
+          for (int s : input_shape) {
+            if (s == -1) {
+              VLOG(3) << "The axes of the squeeze2 operator is empty and "
+                         "cannot be inferred because the input shape has a "
+                         "dynamic (-1) dimension.";
+              return false;
+            } else if (s == 1) {
+              // Record the axis index, not the extent (always 1 here).
+              axes.push_back(axis);
+            }
+            ++axis;
+          }
+        }
+        if (axes.size() == 0) {
+          VLOG(3)
+              << "The necessary attributes of the squeeze2 operator axes is "
+                 "missing.";
+          return false;
+        }
       }
       if (!with_dynamic_shape) {
         if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py
index d5722c742724821e8746f1163ed448e059e6d976..a24465428f632e8d2f8a95699870c32fe2902938 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py
@@ -29,7 +29,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
         attrs = [
             program_config.ops[i].attrs for i in range(len(program_config.ops))
         ]
-        if len(inputs['in_data'].shape) <= max(attrs[0]['axes']):
+        if len(inputs['in_data'].shape) <= max(self.axes):
             return False
         return True
 
@@ -37,54 +37,59 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
         for dims in [2, 3, 4]:
             for batch in [3, 4]:
                 for axes in [[2], [2, 3], [-1]]:
-                    self.batch = batch
-                    self.dims = dims
-                    self.axes = axes
-                    dics = [{"axes": axes}]
-                    ops_config = [
-                        {
-                            "op_type": "squeeze2",
-                            "op_inputs": {"X": ["in_data"]},
-                            "op_outputs": {
-                                "Out": ["out_data"],
-                                "XShape": ["XShape_data"],
+                    for attr_axis in [True, False]:
+                        self.batch = batch
+                        self.dims = dims
+                        self.axes = axes
+                        dics = [{"axes": []}]
+                        if attr_axis:
+                            dics[0]["axes"] = axes
+                        ops_config = [
+                            {
+                                "op_type": "squeeze2",
+                                "op_inputs": {"X": ["in_data"]},
+                                "op_outputs": {
+                                    "Out": ["out_data"],
+                                    "XShape": ["XShape_data"],
+                                },
+                                "op_attrs": dics[0],
+                            }
+                        ]
+                        # new_axes is the update of axes
+                        new_axes = list(axes)
+                        for i in range(len(new_axes)):
+                            if new_axes[i] < 0:
+                                new_axes[i] += dims
+                        if max(new_axes) >= dims:
+                            continue
+                        # generate input data
+                        self.input_shape = [1] * dims
+                        for i in range(dims):
+                            self.input_shape[i] = np.random.randint(1, 20)
+
+                        def generate_input1(attrs: List[Dict[str, Any]], batch):
+                            self.input_shape[0] = batch
+                            for i in new_axes:
+                                self.input_shape[i] = 1
+                            return np.random.random(self.input_shape).astype(
+                                np.float32
+                            )
+
+                        ops = self.generate_op_config(ops_config)
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={},
+                            inputs={
+                                "in_data": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input1, dics, batch
+                                    )
+                                )
                             },
-                            "op_attrs": dics[0],
-                        }
-                    ]
-                    # new_axes is the update of axes
-                    new_axes = list(axes)
-                    for i in range(len(new_axes)):
-                        if new_axes[i] < 0:
-                            new_axes[i] += dims
-                    if max(new_axes) >= dims:
-                        continue
-                    # generate input data
-                    self.input_shape = [1] * dims
-                    for i in range(dims):
-                        self.input_shape[i] = np.random.randint(1, 20)
-
-                    def generate_input1(attrs: List[Dict[str, Any]], batch):
-                        self.input_shape[0] = batch
-                        for i in new_axes:
-                            self.input_shape[i] = 1
-                        return np.random.random(self.input_shape).astype(
-                            np.float32
+                            outputs=["out_data"],
                         )
 
-                    ops = self.generate_op_config(ops_config)
-                    program_config = ProgramConfig(
-                        ops=ops,
-                        weights={},
-                        inputs={
-                            "in_data": TensorConfig(
-                                data_gen=partial(generate_input1, dics, batch)
-                            )
-                        },
-                        outputs=["out_data"],
-                    )
-
-                    yield program_config
+                        yield program_config
 
     def sample_predictor_configs(
         self, program_config
@@ -93,8 +98,6 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
             max_shape = list(self.input_shape)
             min_shape = list(self.input_shape)
             opt_shape = list(self.input_shape)
-            for i in range(len(self.input_shape)):
-                max_shape[i] = max_shape[i] + 1
             self.dynamic_shape.min_input_shape = {"in_data": min_shape}
             self.dynamic_shape.max_input_shape = {"in_data": max_shape}
             self.dynamic_shape.opt_input_shape = {"in_data": opt_shape}