diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py index 28c2e6ae45e4b777b92cafc05b35d3f5c7086e58..9e372892147fbd66fed75a42ec768b0b2d6be5a1 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py @@ -1310,7 +1310,6 @@ def aten_dim(mapper, graph, node): """ scope_name = mapper.normalize_scope_name(node) output_name = mapper._get_outputs_name(node)[0] - layer_outputs = [output_name] layer_inputs = {} inputs_name, inputs_node = mapper._get_inputs_name(node) # 获取当前节点输出的list @@ -1322,9 +1321,9 @@ def aten_dim(mapper, graph, node): current_inputs = list(layer_inputs.values()) graph.add_layer( - "prim.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + "prim.shape", inputs=layer_inputs, outputs=[output_name], scope_name=scope_name) graph.add_layer( - "prim.len", inputs={"input": output_name}, outputs=layer_outputs, scope_name=scope_name) + "prim.len", inputs={"input": output_name}, outputs=[output_name], scope_name=scope_name) return current_inputs, current_outputs @@ -4512,21 +4511,10 @@ def aten_upsample_bilinear2d(mapper, graph, node): current_outputs, scope_name) layer_inputs["align_corners"] = inputs_name[2] current_inputs.append(inputs_name[2]) -# # 处理输入3和4,构造assert -# list_layer_inputs = {} -# mapper._check_input(graph, inputs_node[3], inputs_name[3], current_outputs, scope_name) -# list_layer_inputs["key"] = inputs_name[3] -# current_inputs.append(inputs_name[3]) -# mapper._check_input(graph, inputs_node[4], inputs_name[4], current_outputs, scope_name) -# list_layer_inputs["value"] = inputs_name[4] -# current_inputs.append(inputs_name[4]) -# graph.add_layer( -# "prim.assert", -# inputs=list_layer_inputs, -# outputs=[output_name + "_assert"], -# scope_name=scope_name, -# type="eq") - layer_inputs["scale_factor"] = inputs_name[3] + if "size" in layer_attrs and layer_attrs["size"] is None: + mapper._check_input(graph, inputs_node[3], inputs_name[3], + current_outputs, scope_name) + layer_inputs["scale_factor"] = inputs_name[3] layer_attrs["align_mode"] = 0 layer_attrs["mode"] = string("bilinear") graph.add_layer( @@ -4592,7 +4580,10 @@ def aten_upsample_nearest2d(mapper, graph, node): block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") if_layer.add_block(block) if_layer.inputs["input-0"] = inputs_name[1] - layer_inputs["scale_factor"] = inputs_name[3] + if "size" in layer_attrs and layer_attrs["size"] is None: + mapper._check_input(graph, inputs_node[3], inputs_name[3], + current_outputs, scope_name) + layer_inputs["scale_factor"] = inputs_name[3] layer_attrs["align_mode"] = 0 layer_attrs["mode"] = string("nearest") graph.add_layer( diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py index 9940d3e1568bb956e6accdffe96156549d69f3ac..0ca02fc87e156e2bb55ab10516fea27d1164abe0 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py @@ -182,7 +182,7 @@ def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "raise RaiseException({})".format(get_value(layer, "input", different_attrs)) + line = "raise Exception({})".format(get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) @@ -458,10 +458,12 @@ def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di forward_func.extend(gen_codes([line], indent=indent)) -def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {}.startswith({})".format(layer.outputs[0], get_value(layer, "input", different_attrs), get_value(layer, "start_str", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -471,7 +473,7 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - if int(get_value(layer, "alpha", different_attrs)) == 1: + if int(float(get_value(layer, "alpha", different_attrs))) == 1: line = "{} = {} - {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) diff --git a/x2paddle/optimizer/fusion/dygraph/if_fuser.py b/x2paddle/optimizer/fusion/dygraph/if_fuser.py index 70cffa7f0fe0e7b4184407b7aeaf3b224a6a1615..877dcc1219db1ccd2e1a418add496dd1de03f0ae 100644 --- a/x2paddle/optimizer/fusion/dygraph/if_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/if_fuser.py @@ -41,13 +41,21 @@ class DygraphIfFuser(FuseBase): return for id in graph.edges_in[layer_id]: input_layer = graph.layers[id] + input_layer_id = id if input_layer.outputs == [layer.inputs["input"]]: if input_layer.kernel == "prim.if": matches.pop(layer_id) return input_id = id break + if list(layer.inputs.values()).count(input_layer.outputs[0]) > 1 or \ + (input_layer_id in graph.edges_out and len(graph.edges_out[input_layer_id]) > 1): + matches.pop(layer_id) + return func_name = input_layer.kernel.replace(".", "_") + if func_name in ["prim_if", "prim_loop"]: + matches.pop(layer_id) + return from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code func = getattr(prim2code, func_name) line = func(input_layer, is_return_line=True) diff --git a/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py b/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py index 9e7cff3f8af1de23178d53aff3b5eaa24bc4f277..84ed97211fe16042dfa2bbce7193e19c2d6e2561 100644 --- a/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/interpolate_bilinear_fuser.py @@ -186,7 +186,6 @@ class DygraphInterpolateBilinearFuser(FuseBase): inputs={ "input": "interpolate-input-0", "size": "interpolate-input-3", - "scale_factor": gen_name(21) }, outputs=[gen_name(23)]) pattern_block_block.add_layer( @@ -295,6 +294,5 @@ class DygraphInterpolateBilinearFuser(FuseBase): layer = matches[layers_id[9]] new_layer.outputs[0] = layer.outputs[0] new_layer.layer_id = layers_id[7] - new_layer.inputs.pop("scale_factor") new_layer.inputs["size"] = size return new_layer