未验证 提交 685c8ce9 编写于 作者: J Jason 提交者: GitHub

Merge pull request #952 from wjj19950828/support_yolov8cls

[Bug]Support YOLOv8-cls
...@@ -796,11 +796,17 @@ class CaffeOpMapper(): ...@@ -796,11 +796,17 @@ class CaffeOpMapper():
inputs=inputs_dict, inputs=inputs_dict,
outputs=[node.layer_name + "_mul"]) outputs=[node.layer_name + "_mul"])
else: else:
new_shape = [1] * len(node.in_shapes[0])
new_shape[axis] = node.in_shapes[0][1]
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.fluid.layers.elementwise_mul", "paddle.reshape",
inputs={"x": node.layer_name + "_cparam1"},
outputs=[node.layer_name + "_cparam1"],
shape=new_shape)
self.paddle_graph.add_layer(
"paddle.multiply",
inputs=inputs_dict, inputs=inputs_dict,
outputs=[node.layer_name + "_mul"], outputs=[node.layer_name + "_mul"])
axis=axis)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
......
此差异已折叠。
...@@ -311,7 +311,9 @@ class PyTorchOpMapper(): ...@@ -311,7 +311,9 @@ class PyTorchOpMapper():
last_name_segments = scope_name_part[index].split(".") last_name_segments = scope_name_part[index].split(".")
name_segments = scope_name_part[index + 1].split(".") name_segments = scope_name_part[index + 1].split(".")
for j, name in enumerate(last_name_segments): for j, name in enumerate(last_name_segments):
name_segments[j] = name if j < len(name_segments) and name_segments[j] == name:
continue
name_segments.insert(j, name)
scope_name_part[index + 1] = ".".join(name_segments) scope_name_part[index + 1] = ".".join(name_segments)
last_name = scope_name_part[-1] last_name = scope_name_part[-1]
name_segments = last_name.split(".") name_segments = last_name.split(".")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册