Unverified commit f5fdb41b, authored by Jason, committed by GitHub

Merge pull request #488 from SunAhong1993/pytorch

fix the pytorch interpolate
@@ -1310,7 +1310,6 @@ def aten_dim(mapper, graph, node):
     """
     scope_name = mapper.normalize_scope_name(node)
     output_name = mapper._get_outputs_name(node)[0]
-    layer_outputs = [output_name]
     layer_inputs = {}
     inputs_name, inputs_node = mapper._get_inputs_name(node)
     # Get the list of the current node's outputs
@@ -1322,9 +1321,9 @@ def aten_dim(mapper, graph, node):
     current_inputs = list(layer_inputs.values())
     graph.add_layer(
-        "prim.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+        "prim.shape", inputs=layer_inputs, outputs=[output_name], scope_name=scope_name)
     graph.add_layer(
-        "prim.len", inputs={"input": output_name}, outputs=layer_outputs, scope_name=scope_name)
+        "prim.len", inputs={"input": output_name}, outputs=[output_name], scope_name=scope_name)
     return current_inputs, current_outputs
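The two `aten_dim` hunks above drop the shared `layer_outputs` list and build a fresh `[output_name]` list at each `add_layer` call site. A minimal standalone sketch (not X2Paddle code) of the aliasing hazard that per-call lists avoid:

```python
class Layer:
    def __init__(self, kernel, outputs):
        self.kernel = kernel
        self.outputs = outputs  # stores the reference, not a copy

shared = ["x2paddle_7"]
shape_layer = Layer("prim.shape", shared)
len_layer = Layer("prim.len", shared)

shared[0] = "renamed"       # a later in-place edit, e.g. by a fuser pass...
print(shape_layer.outputs)  # ['renamed']  <- leaks into both layers
print(len_layer.outputs)    # ['renamed']

# A fresh list per call keeps the two layers independent:
shape_layer = Layer("prim.shape", ["x2paddle_7"])
len_layer = Layer("prim.len", ["x2paddle_7"])
```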
@@ -4512,21 +4511,10 @@ def aten_upsample_bilinear2d(mapper, graph, node):
                         current_outputs, scope_name)
     layer_inputs["align_corners"] = inputs_name[2]
     current_inputs.append(inputs_name[2])
-    # # Process inputs 3 and 4, build an assert
-    # list_layer_inputs = {}
-    # mapper._check_input(graph, inputs_node[3], inputs_name[3], current_outputs, scope_name)
-    # list_layer_inputs["key"] = inputs_name[3]
-    # current_inputs.append(inputs_name[3])
-    # mapper._check_input(graph, inputs_node[4], inputs_name[4], current_outputs, scope_name)
-    # list_layer_inputs["value"] = inputs_name[4]
-    # current_inputs.append(inputs_name[4])
-    # graph.add_layer(
-    #     "prim.assert",
-    #     inputs=list_layer_inputs,
-    #     outputs=[output_name + "_assert"],
-    #     scope_name=scope_name,
-    #     type="eq")
-    layer_inputs["scale_factor"] = inputs_name[3]
+    if "size" in layer_attrs and layer_attrs["size"] is None:
+        mapper._check_input(graph, inputs_node[3], inputs_name[3],
+                            current_outputs, scope_name)
+        layer_inputs["scale_factor"] = inputs_name[3]
     layer_attrs["align_mode"] = 0
     layer_attrs["mode"] = string("bilinear")
     graph.add_layer(
@@ -4592,7 +4580,10 @@ def aten_upsample_nearest2d(mapper, graph, node):
     block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
     if_layer.add_block(block)
     if_layer.inputs["input-0"] = inputs_name[1]
-    layer_inputs["scale_factor"] = inputs_name[3]
+    if "size" in layer_attrs and layer_attrs["size"] is None:
+        mapper._check_input(graph, inputs_node[3], inputs_name[3],
+                            current_outputs, scope_name)
+        layer_inputs["scale_factor"] = inputs_name[3]
     layer_attrs["align_mode"] = 0
     layer_attrs["mode"] = string("nearest")
     graph.add_layer(
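Both upsample mappers now forward input 3 as `scale_factor` only when no explicit `size` was captured. This mirrors the PyTorch API being converted, where `size` and `scale_factor` are mutually exclusive; a short sketch of that behavior (assuming a local `torch` install):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
a = F.interpolate(x, size=(16, 16), mode="bilinear", align_corners=False)
b = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
print(a.shape, b.shape)  # both torch.Size([1, 3, 16, 16])

# Supplying both arguments at once raises a ValueError, which is why the
# mapper must pick exactly one of them when emitting the Paddle layer.
```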
@@ -182,7 +182,7 @@ def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
 def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
-    line = "raise RaiseException({})".format(get_value(layer, "input", different_attrs))
+    line = "raise Exception({})".format(get_value(layer, "input", different_attrs))
     forward_func.extend(gen_codes([line], indent=indent))
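`RaiseException` comes from TorchScript's `prim::RaiseException` op name; it is not a Python builtin, so the previously generated line died with a `NameError` before anything could be raised. A quick standalone illustration:

```python
def old_generated():
    raise RaiseException("bad input")  # NameError: name 'RaiseException' is not defined

def new_generated():
    raise Exception("bad input")       # Exception is a builtin, so this raises as intended

try:
    old_generated()
except NameError as e:
    print("old form fails early:", e)

try:
    new_generated()
except Exception as e:
    print("new form raises properly:", e)
```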
@@ -458,10 +458,12 @@ def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
     forward_func.extend(gen_codes([line], indent=indent))
-def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
+def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
     line = "{} = {}.startswith({})".format(layer.outputs[0],
                                            get_value(layer, "input", different_attrs),
                                            get_value(layer, "start_str", different_attrs))
+    if is_return_line:
+        return line.split(" = ")[1]
     forward_func.extend(gen_codes([line], indent=indent))
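The new `is_return_line` flag lets a caller (the if-fuser below) retrieve just the right-hand side of the generated assignment, so the expression can be inlined into an `if` header instead of being emitted as a separate statement. A sketch with hypothetical names:

```python
# Hypothetical output/input names, mimicking what prim_startswith formats.
line = "{} = {}.startswith({})".format("x2paddle_12", "name_str", "'conv'")
print(line)                # x2paddle_12 = name_str.startswith('conv')
rhs = line.split(" = ")[1]
print(rhs)                 # name_str.startswith('conv')  -> inlinable condition
```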
@@ -471,7 +473,7 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
 def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
-    if int(get_value(layer, "alpha", different_attrs)) == 1:
+    if int(float(get_value(layer, "alpha", different_attrs))) == 1:
         line = "{} = {} - {}".format(layer.outputs[0],
                                      get_value(layer, "x", different_attrs),
                                      get_value(layer, "y", different_attrs))
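Wrapping the value in `float()` first makes `prim_sub` tolerant of `alpha` arriving as a float or a float-formatted string such as `"1.0"`, which plain `int()` rejects:

```python
print(int(float("1.0")))  # 1
print(int(float(1.0)))    # 1

try:
    int("1.0")            # the old code path
except ValueError as e:
    print(e)              # invalid literal for int() with base 10: '1.0'
```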
@@ -41,13 +41,21 @@ class DygraphIfFuser(FuseBase):
             return
         for id in graph.edges_in[layer_id]:
             input_layer = graph.layers[id]
+            input_layer_id = id
             if input_layer.outputs == [layer.inputs["input"]]:
                 if input_layer.kernel == "prim.if":
                     matches.pop(layer_id)
                     return
                 input_id = id
                 break
+        if list(layer.inputs.values()).count(input_layer.outputs[0]) > 1 or \
+           (input_layer_id in graph.edges_out and len(graph.edges_out[input_layer_id]) > 1):
+            matches.pop(layer_id)
+            return
         func_name = input_layer.kernel.replace(".", "_")
+        if func_name in ["prim_if", "prim_loop"]:
+            matches.pop(layer_id)
+            return
         from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code
         func = getattr(prim2code, func_name)
         line = func(input_layer, is_return_line=True)
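The added guards bail out of the fusion whenever inlining would be unsafe: the producer's output must be consumed exactly once (only by this condition), and multi-line constructs such as `prim_if`/`prim_loop` cannot be rendered as a single inline expression. A minimal sketch of the single-consumer check, with hypothetical graph data:

```python
# Hypothetical graph fragments standing in for PaddleGraph bookkeeping.
edges_out = {"L1": ["L2", "L3"]}        # producer L1 fans out to two layers
if_inputs = {"input": "t0", "x": "t0"}  # the same tensor is used twice here

producer_id, producer_out = "L1", "t0"
multi_use = list(if_inputs.values()).count(producer_out) > 1
fans_out = producer_id in edges_out and len(edges_out[producer_id]) > 1
if multi_use or fans_out:
    print("skip fusion: deleting the producer would orphan other consumers")
```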
@@ -186,7 +186,6 @@ class DygraphInterpolateBilinearFuser(FuseBase):
             inputs={
                 "input": "interpolate-input-0",
                 "size": "interpolate-input-3",
-                "scale_factor": gen_name(21)
             },
             outputs=[gen_name(23)])
         pattern_block_block.add_layer(
@@ -295,6 +294,5 @@ class DygraphInterpolateBilinearFuser(FuseBase):
         layer = matches[layers_id[9]]
         new_layer.outputs[0] = layer.outputs[0]
         new_layer.layer_id = layers_id[7]
-        new_layer.inputs.pop("scale_factor")
         new_layer.inputs["size"] = size
         return new_layer
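With `aten_upsample_bilinear2d` no longer unconditionally emitting a `scale_factor` input, the fuser's match pattern and the fused replacement layer drop it as well, leaving a purely size-driven interpolate. The resulting Paddle call has roughly this shape (a sketch, assuming a local `paddle` install):

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([1, 3, 8, 8])
y = F.interpolate(x, size=[16, 16], mode="bilinear", align_mode=0)
print(y.shape)  # [1, 3, 16, 16]
```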