未验证 提交 f9a4f007 编写于 作者: W wenbin 提交者: GitHub

squeeze2_op (#51146)

* squeeze2_op

* add ut

* fix ut

* fix static

* modity ut
上级 5dfbb229
......@@ -32,8 +32,22 @@ class Squeeze2OpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0];
// Get Attrs
std::vector<int> axes =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
std::vector<int> axes;
if (op_desc.HasAttr("axes")) {
axes = PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
}
if (axes.size() == 0) {
for (int i = 0; i < input_dims.nbDims; i++) {
if (input_dims.d[i] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The necessary attributes of the squeeze2 operator axes is "
"missing."));
} else if (input_dims.d[i] == 1) {
axes.push_back(engine_->with_dynamic_shape() ? i : i + 1);
}
}
}
PADDLE_ENFORCE_GT(
axes.size(),
0,
......
......@@ -996,10 +996,29 @@ struct SimpleOpTypeSetTeller : public Teller {
axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
}
if (axes.size() == 0) {
VLOG(3) << "The necessary attributes of the squeeze2 operator axes is "
auto* block = desc.Block();
if (block) {
auto input_var_name = desc.Input("X")[0];
auto* input_var_desc = block->FindVar(input_var_name);
const auto input_shape = input_var_desc->GetShape();
for (int s : input_shape) {
if (s == -1) {
VLOG(3) << "The necessary attributes of the squeeze2 operator "
"axes is "
"missing. ss ==== -1";
return false;
} else if (s == 1) {
axes.push_back(s);
}
}
}
if (axes.size() == 0) {
VLOG(3)
<< "The necessary attributes of the squeeze2 operator axes is "
"missing.";
return false;
}
}
if (!with_dynamic_shape) {
if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not "
......
......@@ -29,7 +29,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
if len(inputs['in_data'].shape) <= max(attrs[0]['axes']):
if len(inputs['in_data'].shape) <= max(self.axes):
return False
return True
......@@ -37,10 +37,13 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
for dims in [2, 3, 4]:
for batch in [3, 4]:
for axes in [[2], [2, 3], [-1]]:
for attr_axis in [True, False]:
self.batch = batch
self.dims = dims
self.axes = axes
dics = [{"axes": axes}]
dics = [{"axes": []}]
if attr_axis:
dics[0]["axes"] = axes
ops_config = [
{
"op_type": "squeeze2",
......@@ -78,7 +81,9 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
weights={},
inputs={
"in_data": TensorConfig(
data_gen=partial(generate_input1, dics, batch)
data_gen=partial(
generate_input1, dics, batch
)
)
},
outputs=["out_data"],
......@@ -93,8 +98,6 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
max_shape = list(self.input_shape)
min_shape = list(self.input_shape)
opt_shape = list(self.input_shape)
for i in range(len(self.input_shape)):
max_shape[i] = max_shape[i] + 1
self.dynamic_shape.min_input_shape = {"in_data": min_shape}
self.dynamic_shape.max_input_shape = {"in_data": max_shape}
self.dynamic_shape.opt_input_shape = {"in_data": opt_shape}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册