未验证 提交 f9a4f007 编写于 作者: W wenbin 提交者: GitHub

squeeze2_op (#51146)

* squeeze2_op

* add ut

* fix ut

* fix static

* modity ut
上级 5dfbb229
...@@ -32,8 +32,22 @@ class Squeeze2OpConverter : public OpConverter { ...@@ -32,8 +32,22 @@ class Squeeze2OpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
// Get Attrs // Get Attrs
std::vector<int> axes = std::vector<int> axes;
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes")); if (op_desc.HasAttr("axes")) {
axes = PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
}
if (axes.size() == 0) {
for (int i = 0; i < input_dims.nbDims; i++) {
if (input_dims.d[i] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The necessary attributes of the squeeze2 operator axes is "
"missing."));
} else if (input_dims.d[i] == 1) {
axes.push_back(engine_->with_dynamic_shape() ? i : i + 1);
}
}
}
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
axes.size(), axes.size(),
0, 0,
......
...@@ -996,10 +996,29 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -996,10 +996,29 @@ struct SimpleOpTypeSetTeller : public Teller {
axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes")); axes = PADDLE_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
} }
if (axes.size() == 0) { if (axes.size() == 0) {
VLOG(3) << "The necessary attributes of the squeeze2 operator axes is " auto* block = desc.Block();
if (block) {
auto input_var_name = desc.Input("X")[0];
auto* input_var_desc = block->FindVar(input_var_name);
const auto input_shape = input_var_desc->GetShape();
for (int s : input_shape) {
if (s == -1) {
VLOG(3) << "The necessary attributes of the squeeze2 operator "
"axes is "
"missing. ss ==== -1";
return false;
} else if (s == 1) {
axes.push_back(s);
}
}
}
if (axes.size() == 0) {
VLOG(3)
<< "The necessary attributes of the squeeze2 operator axes is "
"missing."; "missing.";
return false; return false;
} }
}
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
if (std::find(axes.begin(), axes.end(), 0) != axes.end()) { if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not " VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not "
......
...@@ -29,7 +29,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest): ...@@ -29,7 +29,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
attrs = [ attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops)) program_config.ops[i].attrs for i in range(len(program_config.ops))
] ]
if len(inputs['in_data'].shape) <= max(attrs[0]['axes']): if len(inputs['in_data'].shape) <= max(self.axes):
return False return False
return True return True
...@@ -37,10 +37,13 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest): ...@@ -37,10 +37,13 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
for dims in [2, 3, 4]: for dims in [2, 3, 4]:
for batch in [3, 4]: for batch in [3, 4]:
for axes in [[2], [2, 3], [-1]]: for axes in [[2], [2, 3], [-1]]:
for attr_axis in [True, False]:
self.batch = batch self.batch = batch
self.dims = dims self.dims = dims
self.axes = axes self.axes = axes
dics = [{"axes": axes}] dics = [{"axes": []}]
if attr_axis:
dics[0]["axes"] = axes
ops_config = [ ops_config = [
{ {
"op_type": "squeeze2", "op_type": "squeeze2",
...@@ -78,7 +81,9 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest): ...@@ -78,7 +81,9 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
weights={}, weights={},
inputs={ inputs={
"in_data": TensorConfig( "in_data": TensorConfig(
data_gen=partial(generate_input1, dics, batch) data_gen=partial(
generate_input1, dics, batch
)
) )
}, },
outputs=["out_data"], outputs=["out_data"],
...@@ -93,8 +98,6 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest): ...@@ -93,8 +98,6 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
max_shape = list(self.input_shape) max_shape = list(self.input_shape)
min_shape = list(self.input_shape) min_shape = list(self.input_shape)
opt_shape = list(self.input_shape) opt_shape = list(self.input_shape)
for i in range(len(self.input_shape)):
max_shape[i] = max_shape[i] + 1
self.dynamic_shape.min_input_shape = {"in_data": min_shape} self.dynamic_shape.min_input_shape = {"in_data": min_shape}
self.dynamic_shape.max_input_shape = {"in_data": max_shape} self.dynamic_shape.max_input_shape = {"in_data": max_shape}
self.dynamic_shape.opt_input_shape = {"in_data": opt_shape} self.dynamic_shape.opt_input_shape = {"in_data": opt_shape}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册