未验证 提交 f9a4f007 编写于 作者: W wenbin 提交者: GitHub

squeeze2_op (#51146)

* squeeze2_op

* add ut

* fix ut

* fix static

* modify ut
上级 5dfbb229
......@@ -32,8 +32,22 @@ class Squeeze2OpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0];
// Get Attrs
std::vector<int> axes =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
std::vector<int> axes;
if (op_desc.HasAttr("axes")) {
axes = PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
}
if (axes.size() == 0) {
for (int i = 0; i < input_dims.nbDims; i++) {
if (input_dims.d[i] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The necessary attributes of the squeeze2 operator axes is "
"missing."));
} else if (input_dims.d[i] == 1) {
axes.push_back(engine_->with_dynamic_shape() ? i : i + 1);
}
}
}
PADDLE_ENFORCE_GT(
axes.size(),
0,
......
......@@ -996,9 +996,28 @@ struct SimpleOpTypeSetTeller : public Teller {
axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
}
if (axes.size() == 0) {
VLOG(3) << "The necessary attributes of the squeeze2 operator axes is "
"missing.";
return false;
auto* block = desc.Block();
if (block) {
auto input_var_name = desc.Input("X")[0];
auto* input_var_desc = block->FindVar(input_var_name);
const auto input_shape = input_var_desc->GetShape();
for (int s : input_shape) {
if (s == -1) {
VLOG(3) << "The necessary attributes of the squeeze2 operator "
"axes is "
"missing. ss ==== -1";
return false;
} else if (s == 1) {
axes.push_back(s);
}
}
}
if (axes.size() == 0) {
VLOG(3)
<< "The necessary attributes of the squeeze2 operator axes is "
"missing.";
return false;
}
}
if (!with_dynamic_shape) {
if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
......
......@@ -29,7 +29,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
if len(inputs['in_data'].shape) <= max(attrs[0]['axes']):
if len(inputs['in_data'].shape) <= max(self.axes):
return False
return True
......@@ -37,54 +37,59 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
for dims in [2, 3, 4]:
for batch in [3, 4]:
for axes in [[2], [2, 3], [-1]]:
self.batch = batch
self.dims = dims
self.axes = axes
dics = [{"axes": axes}]
ops_config = [
{
"op_type": "squeeze2",
"op_inputs": {"X": ["in_data"]},
"op_outputs": {
"Out": ["out_data"],
"XShape": ["XShape_data"],
for attr_axis in [True, False]:
self.batch = batch
self.dims = dims
self.axes = axes
dics = [{"axes": []}]
if attr_axis:
dics[0]["axes"] = axes
ops_config = [
{
"op_type": "squeeze2",
"op_inputs": {"X": ["in_data"]},
"op_outputs": {
"Out": ["out_data"],
"XShape": ["XShape_data"],
},
"op_attrs": dics[0],
}
]
# new_axes is the update of axes
new_axes = list(axes)
for i in range(len(new_axes)):
if new_axes[i] < 0:
new_axes[i] += dims
if max(new_axes) >= dims:
continue
# generate input data
self.input_shape = [1] * dims
for i in range(dims):
self.input_shape[i] = np.random.randint(1, 20)
def generate_input1(attrs: List[Dict[str, Any]], batch):
self.input_shape[0] = batch
for i in new_axes:
self.input_shape[i] = 1
return np.random.random(self.input_shape).astype(
np.float32
)
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"in_data": TensorConfig(
data_gen=partial(
generate_input1, dics, batch
)
)
},
"op_attrs": dics[0],
}
]
# new_axes is the update of axes
new_axes = list(axes)
for i in range(len(new_axes)):
if new_axes[i] < 0:
new_axes[i] += dims
if max(new_axes) >= dims:
continue
# generate input data
self.input_shape = [1] * dims
for i in range(dims):
self.input_shape[i] = np.random.randint(1, 20)
def generate_input1(attrs: List[Dict[str, Any]], batch):
self.input_shape[0] = batch
for i in new_axes:
self.input_shape[i] = 1
return np.random.random(self.input_shape).astype(
np.float32
outputs=["out_data"],
)
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"in_data": TensorConfig(
data_gen=partial(generate_input1, dics, batch)
)
},
outputs=["out_data"],
)
yield program_config
yield program_config
def sample_predictor_configs(
self, program_config
......@@ -93,8 +98,6 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
max_shape = list(self.input_shape)
min_shape = list(self.input_shape)
opt_shape = list(self.input_shape)
for i in range(len(self.input_shape)):
max_shape[i] = max_shape[i] + 1
self.dynamic_shape.min_input_shape = {"in_data": min_shape}
self.dynamic_shape.max_input_shape = {"in_data": max_shape}
self.dynamic_shape.opt_input_shape = {"in_data": opt_shape}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册