未验证 提交 6db1dde2 编写于 作者: W WJJ1995 提交者: GitHub

Support YOLOv8 for dynamic shape (#931)

* fixed Gemm bug

* re-lint

* fixed typo error

* support fmod=1

* fixed nonzero bug

* add nonzero test case

* Support yolov8 for onnx

* re-lint
上级 363cc2b2
...@@ -308,6 +308,13 @@ class ONNXGraph(Graph): ...@@ -308,6 +308,13 @@ class ONNXGraph(Graph):
break break
else: else:
first_i = node.inputs.index(nd.name) first_i = node.inputs.index(nd.name)
# deal with Multiple outputs correspond to one node
if self.node_map[nd.name].outputs.count(
layer_name) > 1:
new_child_name = "{}/{}".format(nd.name,
idx)
node.which_child[new_child_name] = idx
else:
node.which_child[nd.name] = idx node.which_child[nd.name] = idx
self.node_map[nd.name].index = 0 self.node_map[nd.name].index = 0
break break
......
...@@ -38,7 +38,8 @@ class OpSet13(OpSet12): ...@@ -38,7 +38,8 @@ class OpSet13(OpSet12):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = self.graph.get_input_node(node, idx=1, copy=True) axes = self.graph.get_input_node(node, idx=1, copy=True)
# deal with scalar(0D) tensor # deal with scalar(0D) tensor
if len(val_x.out_shapes[0]) == 0 and len(axes.out_shapes[0]) == 1: if len(val_x.out_shapes[0]) == 0 and len(axes.out_shapes[
0]) == 1 and len(node.out_shapes[0]) == 1:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
......
...@@ -2309,7 +2309,15 @@ class OpSet(): ...@@ -2309,7 +2309,15 @@ class OpSet():
rename_mapper=self.rename_mapper) rename_mapper=self.rename_mapper)
else: else:
layer_attrs["bias_attr"] = False layer_attrs["bias_attr"] = False
if reduce(lambda x, y: x * y, # deal with dynamic shape
if len(input_shape) == 0:
self.paddle_graph.add_layer(
"paddle.reshape",
inputs=layer_inputs,
outputs=[layer_inputs["x"]],
shape=[0, num_in_channels * num_groups, 0, -1])
if len(input_shape) != 0 and reduce(
lambda x, y: x * y,
input_shape) in [1, -1] and 1 not in input_shape: input_shape) in [1, -1] and 1 not in input_shape:
input_shape[1] = num_in_channels * num_groups input_shape[1] = num_in_channels * num_groups
input_shape[0] = 0 input_shape[0] = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册