提交 981bdfaa 编写于 作者: C channingss

fix the bug

上级 3e0a37d1
...@@ -423,7 +423,7 @@ def aten_avg_pool2d(mapper, graph, node): ...@@ -423,7 +423,7 @@ def aten_avg_pool2d(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.assert", "prim.assert",
inputs={}, inputs={},
outputs=[inputs_name[6]], outputs=[inputs_name[6] + "_assert"],
type="eq", type="eq",
key=mapper.attrs[inputs_name[6]], key=mapper.attrs[inputs_name[6]],
value=None) value=None)
...@@ -1473,7 +1473,7 @@ def aten_flatten(mapper, graph, node): ...@@ -1473,7 +1473,7 @@ def aten_flatten(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.assert", "prim.assert",
inputs={}, inputs={},
outputs=[inputs_name[1]], outputs=[inputs_name[1] + "_assert"],
type='eq', type='eq',
key=mapper.attrs[inputs_name[1]], key=mapper.attrs[inputs_name[1]],
value=1) value=1)
...@@ -1481,7 +1481,7 @@ def aten_flatten(mapper, graph, node): ...@@ -1481,7 +1481,7 @@ def aten_flatten(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.assert", "prim.assert",
inputs={}, inputs={},
outputs=[inputs_name[2]], outputs=[inputs_name[2] + "_assert"],
type='eq', type='eq',
key=mapper.attrs[inputs_name[2]], key=mapper.attrs[inputs_name[2]],
value=-1) value=-1)
...@@ -2378,7 +2378,7 @@ def aten_max_pool2d(mapper, graph, node): ...@@ -2378,7 +2378,7 @@ def aten_max_pool2d(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.assert", "prim.assert",
inputs={}, inputs={},
outputs=[inputs_name[4]], outputs=[inputs_name[4] + "_assert"],
type="eq", type="eq",
key=mapper.attrs[inputs_name[4]], key=mapper.attrs[inputs_name[4]],
value=[1, [1, 1]]) value=[1, [1, 1]])
...@@ -3912,7 +3912,7 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -3912,7 +3912,7 @@ def aten_upsample_bilinear2d(mapper, graph, node):
type="eq") type="eq")
layer_inputs["scale_factor"] = inputs_name[3] layer_inputs["scale_factor"] = inputs_name[3]
layer_attrs["align_mode"] = 0 layer_attrs["align_mode"] = 0
layer_attrs["mode"] = "bilinear" layer_attrs["mode"] = string("bilinear")
graph.add_layer( graph.add_layer(
"paddle.nn.functional.interpolate", "paddle.nn.functional.interpolate",
inputs=layer_inputs, inputs=layer_inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册