diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 6dc2fc96ffa1b0d4bc0ab3a7467df45ba24bb9bf..27a26eb5ec82a78c261f3f77329ed16412822f68 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -279,8 +279,14 @@ class PaddleGraph(object): hierarchical_tree.save_source_files(save_dir) self.dump_dygraph_parameter(save_dir) else: - self.gen_dygraph_code(save_dir) - self.dump_dygraph_parameter(save_dir) + if self.source_type == "pytorch": + from x2paddle.optimizer.code_optimizer import ModuleGraph + module_graph = ModuleGraph(self) + module_graph.save_source_files(save_dir) + self.dump_dygraph_parameter(save_dir) + else: + self.gen_dygraph_code(save_dir) + self.dump_dygraph_parameter(save_dir) # 动转静 code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py") print("Exporting inference model from python code ('{}')... \n".format(code_path)) @@ -621,7 +627,7 @@ class PaddleGraph(object): layer.outputs[0])], indent=indent)) else: self.forward_func.extend(gen_codes([line], indent=indent)) - if indent == 2: + if indent == 2 and code_dir is not None: gen_main_code(code_dir) write_code(code_dir) else: diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py index 297a0b0b6fb2e1fa48372a9ce0b1549fb05e898a..8f443f166f84687b883c96b565534fda2bdd4d92 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/aten.py @@ -663,6 +663,37 @@ def aten_batch_norm(mapper, graph, node): return current_inputs, current_outputs +def aten_bmm(mapper, graph, node): + """ 构造矩阵相乘的PaddleLayer。 + + TorchScript示例: + %x.222 : Tensor = aten::bmm(%32, %7) + 参数含义: + %x.222 (Tensor): 输出,矩阵相乘后的结果。 + %i.12 (list): 输入1。 + %7 (int): 输入2。 + """ + scope_name = mapper.normalize_scope_name(node) + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 
获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%i.12 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + layer_inputs["x"] = inputs_name[0] + # 处理输入1,即%288 + mapper._check_input( + graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True) + layer_inputs["y"] = inputs_name[1] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("paddle.bmm", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + return current_inputs, current_outputs + + def aten_cat(mapper, graph, node): """ 构造连接Tensor的PaddleLayer。 @@ -885,7 +916,7 @@ def aten_constant_pad_nd(mapper, graph, node): outputs=[inputs_name[0] + "_if", output_name], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "prim.sub", inputs={"y": inputs_name[0] + "_len"}, @@ -916,7 +947,7 @@ def aten_constant_pad_nd(mapper, graph, node): outputs=[output_name], scope_name=scope_name) if_layer.add_block(block) - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") layer_inputs["input"] = inputs_name[0] block.add_layer( kernel, inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) @@ -1525,6 +1556,32 @@ def aten_eq(mapper, graph, node): return current_inputs, current_outputs +def aten_erf(mapper, graph, node): + """ 构造逐元素计算 Erf 激活函数的PaddleLayer。 + + TorchScript示例: + %94 : Tensor = aten::erf(%sinusoid_inp.1) + 参数含义: + %94 (Tensor): 输出,erf之后的结果。 + %sinusoid_inp.1 (Tensor): 需要进行erf的Tensor。 + """ + scope_name = mapper.normalize_scope_name(node) + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = 
mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%sinusoid_inp.1 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) + layer_inputs["x"] = inputs_name[0] + # 获取当前节点输入、输出的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("paddle.erf", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + return current_inputs, current_outputs + + def aten_exp(mapper, graph, node): """ 构造以自然数e为底指数运算的PaddleLayer。 @@ -1639,7 +1696,7 @@ def aten_expand_as(mapper, graph, node): outputs=[inputs_name[0] + "_if1"], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "prim.type", inputs={"input": inputs_name[1]}, @@ -1652,7 +1709,7 @@ def aten_expand_as(mapper, graph, node): scope_name=scope_name, dtype=inputs_name[1] + "_type") if_layer.add_block(block) - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") if_layer.add_block(block) if_layer.inputs["input-0"] = inputs_name[0] if_layer.inputs["input-1"] = inputs_name[1] @@ -1663,7 +1720,7 @@ def aten_expand_as(mapper, graph, node): outputs=[inputs_name[0] + "_if2"], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "fluid.layers.cast", inputs={"x": layer_outputs[0]}, @@ -1671,7 +1728,7 @@ def aten_expand_as(mapper, graph, node): scope_name=scope_name, dtype=string("bool")) if_layer.add_block(block) - block = PaddleGraph(if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, 
graph_type="dygraph") if_layer.add_block(block) if_layer.inputs["input-0"] = layer_outputs[0] # TODO(syf): check expand_as @@ -1868,10 +1925,10 @@ def aten_floor(mapper, graph, node): outputs=[inputs_name[0] + "_if"], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer("paddle.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name) if_layer.add_block(block) - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer("prim.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name) if_layer.add_block(block) if_layer.inputs["input-0"] = inputs_name[0] @@ -2569,14 +2626,14 @@ def aten_masked_fill_(mapper, graph, node): outputs=[inputs_name[2] + "_if"], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "prim.equal", inputs={"input": inputs_name[1] + "_mask"}, outputs=[inputs_name[2] + "_1"], scope_name=scope_name) if_layer.add_block(block) - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "prim.mul", inputs={"x": inputs_name[1] + "_mask", @@ -2677,14 +2734,14 @@ def aten_masked_fill(mapper, graph, node): outputs=[inputs_name[2] + "_if"], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, 
graph_type="dygraph") block.add_layer( "prim.equal", inputs={"input": inputs_name[1] + "_mask"}, outputs=[inputs_name[2] + "_1"], scope_name=scope_name) if_layer.add_block(block) - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "prim.mul", inputs={"x": inputs_name[1] + "_mask", @@ -3986,16 +4043,18 @@ def aten_sub(mapper, graph, node): """ 构造数值相减的PaddleLayer。 TorchScript示例: - %840 : int = aten::sub(%839, %836) + %840 : int = aten::sub(%839, %836, %3) 参数含义: %840 (-): 相减结果。 %839 (-): 输入数值 x。 %836 (-): 输入数值 y。 + %3 (-): alpha。 """ scope_name = mapper.normalize_scope_name(node) output_name = mapper._get_outputs_name(node)[0] layer_outputs = [output_name] layer_inputs = {} + layer_attrs = {} inputs_name, inputs_node = mapper._get_inputs_name(node) # 获取当前节点输出的list current_outputs = [output_name] @@ -4006,13 +4065,37 @@ def aten_sub(mapper, graph, node): mapper._check_input( graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True) layer_inputs["y"] = inputs_name[1] + # 处理输入2,即%3 + if len(inputs_node) > 2: + if inputs_name[2] in mapper.attrs: + layer_attrs["alpha"] = mapper.attrs[inputs_name[2]] + else: + mapper._check_input(graph, inputs_node[2], inputs_name[2], + current_outputs, scope_name) + layer_inputs["alpha"] = inputs_name[2] + current_inputs.append(inputs_name[2]) + else: + layer_attrs["alpha"] = 1.0 # 获取当前节点输入的list current_inputs = list(layer_inputs.values()) - graph.add_layer("prim.sub", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) + graph.add_layer("prim.sub", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) return current_inputs, current_outputs +def aten_sub_(mapper, graph, node): + """ 构造数值相减的PaddleLayer。 + + TorchScript示例: + %840 : int = aten::sub_(%839, %836, %3) + 参数含义: + %840 (-): 相减结果。 + %839 (-): 输入数值 x。 + %836 (-): 输入数值 y。 + %3 (-): alpha。 + """ + 
return aten_sub(mapper, graph, node) + def aten_t(mapper, graph, node): """ 构造矩阵转置的PaddleLayer。 @@ -4366,14 +4449,14 @@ def aten_upsample_bilinear2d(mapper, graph, node): outputs=[inputs_name[0] + "_if1"], scope_name=scope_name) if_layer = graph.layers[list(graph.layers.keys())[-1]] - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block.add_layer( "prim.var2list", inputs={"input": inputs_name[1]}, outputs=[inputs_name[1]], scope_name=scope_name) if_layer.add_block(block) - block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph") + block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") if_layer.add_block(block) if_layer.inputs["input-0"] = inputs_name[1] # 处理输入2,即%5421 diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py index 78033619dfbd5f7507eb17a2e565a94dc13b9250..9940d3e1568bb956e6accdffe96156549d69f3ac 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py @@ -67,9 +67,11 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif forward_func.extend(gen_codes([line], indent=indent)) -def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} and {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -91,16 +93,23 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d s += "{} == {} or ".format(get_value(layer, "key"), v) if len(s) > 0: s = s[:-4] + lc=locals() + exec("assert_result = 
{}".format(s)) + assert_result = lc['assert_result'] line = "assert {}, \'The {} must be {}!\'".format( s, get_value(layer, "key"), get_value(layer, "value")) else: - line = "assert {} == {}, \'The {} must be {}!\'".format( - get_value(layer, "key"), - get_value(layer, "value"), - get_value(layer, "key"), get_value(layer, "value")) + s = "{} == {}".format(get_value(layer, "key"), + get_value(layer, "value")) + lc=locals() + exec("assert_result = {}".format(s)) + assert_result = lc['assert_result'] + line = "assert {}, \'The {} must be {}!\'".format( + s, get_value(layer, "key"), get_value(layer, "value")) else: raise Exception("Not implement yet!") - forward_func.extend(gen_codes([line], indent=indent)) + if not assert_result: + forward_func.extend(gen_codes([line], indent=indent)) def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): @@ -119,10 +128,12 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None, forward_func.extend(gen_codes([line], indent=indent)) -def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} in {}".format(layer.outputs[0], get_value(layer, "element", different_attrs), get_value(layer, "input", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -156,10 +167,12 @@ def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff forward_func.extend(gen_codes([line], indent=indent)) -def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None,is_return_line=False): line = "{} = {} == {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), 
get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -198,33 +211,56 @@ def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, forward_func.extend(gen_codes([line], indent=indent)) -def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} > {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "if {} :".format(get_value(layer, "input", different_attrs)) - forward_func.extend(gen_codes([line], indent=indent)) - block = layer.blocks[0] - if len(block.layers) == 0: - line = "pass" - forward_func.extend(gen_codes([line], indent=indent + 1)) - else: - b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1) - init_func.extend(b_init_lines) - forward_func.extend(b_forward_lines) - block = layer.blocks[1] - if len(block.layers) > 0: - b_init_lines, b_forward_lines = block.gen_dygraph_code( - indent=indent + 1) - if len(b_forward_lines) != 0: - line = "else:" - forward_func.extend(gen_codes([line], indent=indent)) - init_func.extend(b_init_lines) - forward_func.extend(b_forward_lines) + try: + exec_s = None + for line in forward_func: + s = line.replace(" ", "") + if s.startswith("{} = ".format(get_value(layer, "input", different_attrs))): + exec_s = s.split(" = ")[1] + lc=locals() + if exec_s is not None: + exec("if_result = {}".format(exec_s)) + else: + exec("if_result = {}".format(get_value(layer, "input", different_attrs))) + if_result = lc['if_result'] + if if_result: + block = layer.blocks[0] + 
else: + block = layer.blocks[1] + if len(block.layers) > 0: + b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent) + init_func.extend(b_init_lines) + forward_func.extend(b_forward_lines) + except: + line = "if {} :".format(get_value(layer, "input", different_attrs)) + forward_func.extend(gen_codes([line], indent=indent)) + block = layer.blocks[0] + if len(block.layers) == 0: + line = "pass" + forward_func.extend(gen_codes([line], indent=indent + 1)) + else: + b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1) + init_func.extend(b_init_lines) + forward_func.extend(b_forward_lines) + block = layer.blocks[1] + if len(block.layers) > 0: + b_init_lines, b_forward_lines = block.gen_dygraph_code( + indent=indent + 1) + if len(b_forward_lines) != 0: + line = "else:" + forward_func.extend(gen_codes([line], indent=indent)) + init_func.extend(b_init_lines) + forward_func.extend(b_forward_lines) def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): @@ -232,31 +268,39 @@ def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff forward_func.extend(gen_codes([line], indent=indent)) -def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} is {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = isinstance({}, {})".format(layer.outputs[0], get_value(layer, "input", different_attrs), 
layer.attrs["cls"]) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} is not {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} <= {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -273,10 +317,12 @@ def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, forward_func.extend(gen_codes(lines, indent=indent)) -def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} < {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -317,10 +363,12 @@ def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff forward_func.extend(gen_codes([line], indent=indent)) -def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, 
different_attrs=None, is_return_line=False): line = "{} = {} != {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -329,15 +377,19 @@ def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff forward_func.extend(gen_codes([line], indent=indent)) -def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) -def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {} or {}".format(layer.outputs[0], get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) @@ -419,9 +471,15 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "{} = {} - {}".format(layer.outputs[0], - get_value(layer, "x", different_attrs), - get_value(layer, "y", different_attrs)) + if int(get_value(layer, "alpha", different_attrs)) == 1: + line = "{} = {} - {}".format(layer.outputs[0], + get_value(layer, "x", different_attrs), + get_value(layer, "y", different_attrs)) + else: + line = "{} = {} - {} * {}".format(layer.outputs[0], + get_value(layer, "x", different_attrs), + get_value(layer, "alpha", different_attrs), + get_value(layer, 
"y", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py index 118790772565bc565115adeffd65af22cafb6adf..0df7cb7c05c74831f405cc4de323cd5162ec5d6c 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/pytorch_op_mapper.py @@ -84,7 +84,7 @@ class PyTorchOpMapper(OpMapper): current_node_outputs.extend(outputs) # 初始化 - graph = PaddleGraph(parent_layer=parent_layer, graph_type="dygraph") + graph = PaddleGraph(source_type="pytorch", parent_layer=parent_layer, graph_type="dygraph") if "TopLevelTracedModule" in str(type(self.script)): graph.set_script(self.script) current_node_outputs = [] diff --git a/x2paddle/optimizer/code_optimizer/__init__.py b/x2paddle/optimizer/code_optimizer/__init__.py index 7655c8fccee445137c74543278ffc089a1751df8..6aba8a196de57797d27af44e916c349a38002b4a 100644 --- a/x2paddle/optimizer/code_optimizer/__init__.py +++ b/x2paddle/optimizer/code_optimizer/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. 
-from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree \ No newline at end of file +from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree +from x2paddle.optimizer.code_optimizer.module_graph import ModuleGraph \ No newline at end of file diff --git a/x2paddle/optimizer/code_optimizer/layer_code_generator.py b/x2paddle/optimizer/code_optimizer/layer_code_generator.py index 3861eb453d6c20c4cc1775716f856f9ccf5a4935..3c5589a2e90177ca3223149dceba311c0e0d0d72 100644 --- a/x2paddle/optimizer/code_optimizer/layer_code_generator.py +++ b/x2paddle/optimizer/code_optimizer/layer_code_generator.py @@ -38,7 +38,7 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", "paddle.nn.LeakyReLU": "leakly_relu"} NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:6] -def rename_layers(layers, param_tree=None): +def rename_layers(layers, param_tree=None, is_rename_module=False): """ 对子模块的输入输出等进行重命名。 """ layers_cp = copy.deepcopy(layers) @@ -84,17 +84,23 @@ def rename_layers(layers, param_tree=None): layer.outputs[0] = new_name nn_count_dict[layer.kernel] += 1 elif i == 0 and layer.kernel == "module": - old_name = layer.outputs[0].split("/")[0] - if old_name not in nn_count_dict: - nn_count_dict[old_name] = 0 + if is_rename_module: + if param_tree is not None: + param_node = param_tree.get_node(layer.outputs[0]) + nn_param_nodes.append(param_node) + param_node.new_name = layer.outputs[0] else: - nn_count_dict[old_name] += 1 - new_name = old_name + str(nn_count_dict[old_name]) - if param_tree is not None: - param_node = param_tree.get_node(layer.outputs[0]) - nn_param_nodes.append(param_node) - param_node.new_name = new_name - layer.outputs[0] = new_name + old_name = layer.outputs[0].split("/")[0] + if old_name not in nn_count_dict: + nn_count_dict[old_name] = 0 + else: + nn_count_dict[old_name] += 1 + new_name = old_name + str(nn_count_dict[old_name]) + if param_tree is not None: + param_node = param_tree.get_node(layer.outputs[0]) + 
nn_param_nodes.append(param_node) + param_node.new_name = new_name + layer.outputs[0] = new_name else: old_name = layer.outputs[i] new_name = "x{}".format(count) @@ -196,6 +202,15 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): outputs.append(output_name) else: outputs.append(output_name) + if layer.kernel == "prim.dict": + is_set_item = True + for out_layer_id in graph.edges_out[layer_id]: + out_layer = sub_layers[out_layer_id] + if out_layer.kernel != "prim.set_item": + is_set_item = False + break + if is_set_item: + outputs.append(layer.outputs[0]) no_output_count = 0 for i, (layer_id, layer) in enumerate(sub_layers.items()): if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \ diff --git a/x2paddle/optimizer/code_optimizer/module_graph.py b/x2paddle/optimizer/code_optimizer/module_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..9045ba67a9fff7f8db51f43436fc99e26266dc4a --- /dev/null +++ b/x2paddle/optimizer/code_optimizer/module_graph.py @@ -0,0 +1,373 @@ +# -*- coding:UTF-8 -*- +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import os +import os.path as osp +from x2paddle.core.program import PaddleLayer +from x2paddle.optimizer.code_optimizer.subgraphs_union import construct_attrs_table, get_inputs_outputs +from x2paddle.optimizer.code_optimizer.layer_code_generator import gen_layer_code, rename_layers +from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode, PamareterTree + + +NoModuleStart = ["paddle.nn.ReLU"] + +class Apriori(object): + """ 使用Apriori算法挖掘频繁子图 + 1. 构建频繁1项集 + 2. 挖掘频繁k项集 + 3. 最终k项集和节点数满足最少节点数的子图组成集合GS + + Args: + min_support (int): 子图出现次数的最小值。 + """ + def __init__(self, min_support): + self.min_support = min_support + + def is_match(self, item, sublayers): + for i in range(len(item)): + if len(sublayers) <= i or item[i] != sublayers[i].kernel: + return False + return True + + def create_C1(self): + # 构建候选1-项集 + C1 = list() + for layer_id, layer in self.layers.items(): + if layer.kernel == "paddle.to_tensor" or \ + layer.kernel == "prim.if" or \ + layer.kernel == "prim.loop": #or \ +# layer.kernel == "prim.list" or \ +# layer.kernel == "prim.tuple" or \ +# layer.kernel == "prim.dict_construct": + continue + if self.pd_graph.edges_in.get(layer_id, 0) == 0 and \ + self.pd_graph.edges_out.get(layer_id, 0) == 0: + continue + if [layer.kernel] not in C1: + C1.append([layer.kernel]) + return C1 + + def create_Ck(self, Lk_last, C1): + # 构建候选k-项集 + Ck = list() + for item in Lk_last: + for item_one in C1: + new_item = copy.deepcopy(item) + new_item.append(item_one[0]) + if new_item[0] in NoModuleStart: + continue + Ck.append(new_item) + return Ck + + def generate_Lk_by_Ck(self, Ck): + # 生成频繁k-项集 + Lk = list() + for item in Ck: + count = 0 + for i in range(len(self.layers)): + sublayers = list(self.layers.values())[i:] + if self.is_match(item, sublayers): + count += 1 + if count >= self.min_support: + Lk.append(item) + return Lk + + def run(self, graph): + self.pd_graph = graph + self.layers = graph.layers + itemset = list() + C1 = 
self.create_C1() + L1 = self.generate_Lk_by_Ck(C1) + Lk = L1 + while len(Lk) > 0: + Ck = self.create_Ck(Lk, C1) + Lk = self.generate_Lk_by_Ck(Ck) + itemset.extend(Lk) + return itemset + + +class DP(object): + """ 使用动动态规划找到使代码最短的组合方式。 + """ + def __init__(self, combination_itemset): + self.combination_itemset = combination_itemset + + def get_combination_id(self, combination, layers): + combination_id = list() + for layer_obj in combination: + if len(layer_obj) > 1: + kernel_itemset = list() + for layer_id in layer_obj: + kernel_itemset.append(layers[layer_id].kernel) + id = self.combination_itemset.index(kernel_itemset) + combination_id.append(id) + else: + combination_id.append(-1) + return combination_id + + def run(self, graph): + layers = graph.layers + layer_combination_list = list() + for i, (layer_id, layer) in enumerate(layers.items()): + if i == 0: + layer_combination_list.append([[layer_id]]) + continue + current_itemset = [layer_id] + kernel_itemset = [layer.kernel] + candidate_itemset = list() + min_count = len(layers) + prefix_ids = list(range(i)) + prefix_ids.reverse() + for j in prefix_ids: + current_layer_id = list(layers.keys())[j] + current_layer = list(layers.values())[j] + current_itemset.insert(0, current_layer_id) + kernel_itemset.insert(0, current_layer.kernel) + if kernel_itemset in self.combination_itemset: + current_count = len(layer_combination_list[j - 1]) + all_count = current_count + 1 + if all_count < min_count: + min_count = all_count + candidate_itemset = copy.deepcopy(current_itemset) + if j - 1 < 0: + last_itemset = list() + else: + last_itemset = copy.deepcopy(layer_combination_list[j - 1]) + else: + if j == prefix_ids[0]: + min_count = len(layer_combination_list[j]) + 1 + current_itemset.pop(0) + candidate_itemset = copy.deepcopy(current_itemset) + last_itemset = copy.deepcopy(layer_combination_list[j]) + break + last_itemset.append(candidate_itemset) + layer_combination_list.append(last_itemset) + final_combination = 
layer_combination_list[-1] + combination_id = self.get_combination_id(final_combination, layers) + return final_combination, combination_id + + +class ModuleGraph(object): + """ 更新PaddleGraph,生成代码。 + """ + def __init__(self, graph): + self.pd_graph = graph + self.global_layers = graph.get_global_layers() + self.codes = list() + self.param_tree = PamareterTree() + + def get_updation_information(self): + aprior = Apriori(3) + combination_itemset = aprior.run(self.pd_graph) + dp = DP(combination_itemset) + combination, combination_id = dp.run(self.pd_graph) + return combination, combination_id + + def analyze_attrs_table(self, attrs_table): + """ 分析属性表格,哪些属性取值不一致。 + """ + diff_attrs_column = list() + for column in list(attrs_table.columns): + elements = list(attrs_table.get(column)) + base = elements[0] + for element in elements[1:]: + if isinstance(base, str) and "'" not in base: + break + if element != base: + diff_attrs_column.append(column) + break + return diff_attrs_column + + def analyze_graph(self, sub_layers_list): + def is_same(sub_layers1, sub_layers2, id1, id2): + inputs1, outputs1 = ipt_opt_list[id1] + inputs2, outputs2 = ipt_opt_list[id2] + if len(inputs1) != len(inputs2) or len(outputs1) != len(outputs2): + return False + layer_id_list1 = list(sub_layers1.keys()) + layer_id_list2 = list(sub_layers2.keys()) + for i, layer_id1 in enumerate(layer_id_list1): + layer_id2 = layer_id_list2[i] + if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]): + return False + for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]): + ipt_layer_id2 = self.pd_graph.edges_in[layer_id2][j] + if (ipt_layer_id1 in layer_id_list1) ^ (ipt_layer_id2 in layer_id_list2): + return False + if (layer_id1 in self.pd_graph.edges_out) ^ (layer_id2 in self.pd_graph.edges_out): + return False + if (layer_id1 in self.pd_graph.edges_out) and (layer_id2 in self.pd_graph.edges_out): + if (len(self.pd_graph.edges_out[layer_id1]) > 1 and 
len(self.pd_graph.edges_out[layer_id2]) == 1) or \ + (len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) > 1): + return False + for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[layer_id1]): + if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) == 1: + opt_layer_id2 = self.pd_graph.edges_out[layer_id2][j] + if (opt_layer_id1 in layer_id_list1) ^ (opt_layer_id2 in layer_id_list2): + return False + return True + sub_layers_list_list = list() + id_list = list() + ipt_opt_list = list() + sub_layers_list_list.append([sub_layers_list[0]]) + id_list.append(0) + for i, sub_layer in enumerate(sub_layers_list): + ipt_opt_list.append(get_inputs_outputs(self.pd_graph, sub_layer)) + if i == 0: + continue + id_list_cp = copy.deepcopy(id_list) + for j, index in enumerate(id_list_cp): + if is_same(sub_layers_list[index], sub_layer, index, i): + sub_layers_list_list[j].append(sub_layer) + break + if j == len(id_list_cp) - 1: + sub_layers_list_list.append(list()) + sub_layers_list_list[j + 1].append(sub_layer) + id_list.append(i) + return sub_layers_list_list + + + def merge_node(self, sub_layers_list, attrs_table, module_name): + sub_layers = sub_layers_list[0] + diff_attrs_column = self.analyze_attrs_table(attrs_table) + sub_layers, _, _ = rename_layers(sub_layers) + code_str = gen_layer_code(self.pd_graph, + sub_layers, + module_name, + different_attrs=diff_attrs_column) + self.codes.append(code_str) + for index, sub_layers in enumerate(sub_layers_list): + inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers) + inputs_dict = dict() + for i, input in enumerate(inputs): + inputs_dict["input_{}".format(i)] = input + mn = module_name.lower() + outputs = ["{}_{}".format(mn, index)] + outputs + node_name = "{}_{}".format(module_name, index) + diff_attrs = dict() + for column in diff_attrs_column: + diff_attrs[column] = attrs_table.get(column).loc[node_name] + new_layer = 
PaddleLayer(id=list(sub_layers.keys())[-1], + kernel="module", + inputs=inputs_dict, + outputs=outputs, + module=module_name, + **diff_attrs) + + _, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree) + param_node = PamareterNode(old_name=outputs[0]) + for node in nn_param_nodes: + param_node.add_child(node) + self.param_tree.add_node(param_node) + + for i, (layer_id, layer) in enumerate(sub_layers.items()): + if i == len(sub_layers) - 1: + self.pd_graph.layers[layer_id] = new_layer + else: + if len(layer_id.split(".")) > 1: + continue + self.pd_graph.layers.pop(layer_id) + + self.pd_graph.build() + + def convert_subgraph_to_layer(self, combination, combination_id): + combination_id_set = set(combination_id) + for s in list(combination_id_set): + if s == -1: + continue + module_name = "Block{}".format(s) + sub_layers_list = list() + for i, c in enumerate(combination): + if len(c) > 1 and combination_id[i] == s: + sub_layers = dict() + for layer_id in c: + sub_layers[layer_id] = self.global_layers[layer_id] + sub_layers_list.append(sub_layers) + sub_layers_list_list = self.analyze_graph(sub_layers_list) + for i, sub_layers_list in enumerate(sub_layers_list_list): + if i == 0: + real_module_name = module_name + else: + real_module_name = module_name + "__{}".format(i) + if len(sub_layers_list) > 1: + attrs_table = construct_attrs_table(sub_layers_list, module_name=real_module_name) + self.merge_node(sub_layers_list, attrs_table, real_module_name) + layers, nn_param_nodes, _ = rename_layers(self.pd_graph.layers, self.param_tree, is_rename_module=True) + code_str = gen_layer_code(self.pd_graph, + layers, + self.pd_graph.name) + self.codes.append(code_str) + param_node = PamareterNode(old_name="Module") + for node in nn_param_nodes: + param_node.add_child(node) + self.param_tree.add_node(param_node) + + def update_parameters(self): + """ 更新参数。 + """ + self.param_tree.traverse() + full_old_name_list = copy.deepcopy(list(self.pd_graph.parameters.keys())) + for 
old_name, new_name in self.param_tree.old2new.items(): + for full_old_name in full_old_name_list: + if full_old_name.startswith("{}.".format(old_name)): + full_new_name = full_old_name.replace("{}.".format(old_name), "{}.".format(new_name)) + params = self.pd_graph.parameters.pop(full_old_name) + self.pd_graph.parameters[full_new_name] = params + if full_old_name == old_name: + full_new_name = full_old_name.replace(old_name, new_name) + params = self.pd_graph.parameters.pop(full_old_name) + self.pd_graph.parameters[full_new_name] = params + + def save_source_files(self, save_dir): + def gen_main_code(): + input_data_name = ', '.join(self.pd_graph.inputs) + run_func_list = list() + run_func_list.append("def main({}):".format(input_data_name)) + run_func_list.append(" # 共{}个输入".format(len(self.pd_graph.inputs_info))) + for k, v in self.pd_graph.inputs_info.items(): + run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1])) + run_func_list.extend( + [" paddle.disable_static()", + " params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)), + " model = {}()".format(self.pd_graph.name), + " model.set_dict(params)", + " model.eval()", + " out = model({})".format(input_data_name), + " return out"]) + return "\n".join(run_func_list) + combination, combination_id = self.get_updation_information() + self.convert_subgraph_to_layer(combination, combination_id) + self.update_parameters() + import_list = ["import paddle", + "import paddle.fluid as fluid", + "from paddle.fluid.initializer import Constant", + "from paddle.fluid.param_attr import ParamAttr", + "import math", + "from x2paddle.op_mapper.dygraph.pytorch2paddle " + \ + "import pytorch_custom_layer as x2paddle_nn" + "\n",] + import_str = "\n".join(import_list) + if not osp.exists(save_dir): + os.makedirs(save_dir) + f = open(osp.join(save_dir, 'x2paddle_code.py'), 'w') + f.write(import_str) + for code in self.codes: + f.write(code) + f.write("\n") + run_func = gen_main_code() + f.write(run_func) 
+ f.close() + \ No newline at end of file diff --git a/x2paddle/optimizer/code_optimizer/subgraphs_union.py b/x2paddle/optimizer/code_optimizer/subgraphs_union.py index 5f16e66acfc917e02fac8aaa392a9832b5275dea..ee804eb3093caaaf99ef880acd2c65b85585714b 100644 --- a/x2paddle/optimizer/code_optimizer/subgraphs_union.py +++ b/x2paddle/optimizer/code_optimizer/subgraphs_union.py @@ -19,7 +19,7 @@ import pandas as pd from x2paddle.optimizer.code_optimizer.layer_code_generator import rename_layers -def construct_attrs_table(sub_layers_list, node_name2sub_layers): +def construct_attrs_table(sub_layers_list, node_name2sub_layers=None, module_name=None): """ 构造不同属性的表格。 """ def get_node_name(sub_layers): @@ -32,9 +32,12 @@ def construct_attrs_table(sub_layers_list, node_name2sub_layers): _, _, new_names = rename_layers(sub_layers) table = list() node_names = list() - for sub_layers in sub_layers_list: + for i, sub_layers in enumerate(sub_layers_list): attrs = dict() - node_names.append(get_node_name(sub_layers)) + if node_name2sub_layers is not None: + node_names.append(get_node_name(sub_layers)) + else: + node_names.append("{}_{}".format(module_name, i)) for i, (layer_id, layer) in enumerate(sub_layers.items()): for k, v in layer.attrs.items(): attrs[new_names[i] + "_{}".format(k)] = v diff --git a/x2paddle/optimizer/fusion/dygraph/__init__.py b/x2paddle/optimizer/fusion/dygraph/__init__.py index b74f9d29e1879609826c689766731e99582adcde..9309c8eb65b9613f1eddef8c77932e1086e51e16 100644 --- a/x2paddle/optimizer/fusion/dygraph/__init__.py +++ b/x2paddle/optimizer/fusion/dygraph/__init__.py @@ -26,6 +26,8 @@ from .dropout_fuser import DygraphDropoutFuser from .dropout_fuse_pass import DygraphDropoutFusePass from .fc_fuser import DygraphFcFuser from .fc_fuse_pass import DygraphFcFusePass +from .if_fuser import DygraphIfFuser +from .if_fuse_pass import DygraphIfFusePass from .interpolate_bilinear_fuser import DygraphInterpolateBilinearFuser from .interpolate_bilinear_fuse_pass 
import DygraphInterpolateBilinearFusePass
from .prelu_fuser import DygraphPReLUFuser
diff --git a/x2paddle/optimizer/fusion/dygraph/if_fuse_pass.py b/x2paddle/optimizer/fusion/dygraph/if_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..820b5e31939ef94f8ffdd27df42883810e83396d
--- /dev/null
+++ b/x2paddle/optimizer/fusion/dygraph/if_fuse_pass.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from x2paddle.optimizer.pass_ import Pass
+from x2paddle.optimizer.fusion.dygraph import DygraphIfFuser
+from x2paddle.optimizer.pass_manager import pass_register
+
+
+@pass_register
+class DygraphIfFusePass(Pass):
+    name = "dygraph_if_fuse_pass"
+
+    def __init__(self):
+        Pass.__init__(self)
+
+    def apply(self, graph):
+        fuser = DygraphIfFuser()
+        fuser.operate(graph, match_kind="op")
+
+
+# For registration: the registered object must be the Pass, not the Fuser.
+if_fuse_pass = DygraphIfFusePass()
diff --git a/x2paddle/optimizer/fusion/dygraph/if_fuser.py b/x2paddle/optimizer/fusion/dygraph/if_fuser.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cffa7f0fe0e7b4184407b7aeaf3b224a6a1615
--- /dev/null
+++ b/x2paddle/optimizer/fusion/dygraph/if_fuser.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from x2paddle.optimizer.pattern_matcher import FuseBase
+from x2paddle.core.program import PaddleGraph, PaddleLayer
+from x2paddle.core.util import *
+
+
+class DygraphIfFuser(FuseBase):
+    def __init__(self):
+        super(DygraphIfFuser, self).__init__(graph_type="dygraph")
+
+    def build_pattern(self):
+        """ Describe the if-graph structure to be replaced.
+        Example of the python code an if-layer pattern corresponds to:
+        x81 = 'relu' in {'layer4': 'out', 'layer3': 'aux'}
+        if x81 :
+            ...
+        """
+        self.pattern.add_layer(
+            "prim.if", inputs={"input": "if-input-0"}, outputs=["x0"])
+        self.pattern.build(inputs={"input-0": "if-input-0"})
+
+    def insert_new_layer(self, graph, parameters, matches):
+        # Inline the condition-producing layer's code line into the prim.if
+        # layer's attrs, so the generated code can test the expression directly.
+        layer_id = list(matches.keys())[0]
+        layer = list(matches.values())[0]
+        if "input" not in layer.inputs:
+            # Condition already inlined (or no explicit input): nothing to fuse.
+            matches.pop(layer_id)
+            return
+        for id in graph.edges_in[layer_id]:
+            input_layer = graph.layers[id]
+            if input_layer.outputs == [layer.inputs["input"]]:
+                if input_layer.kernel == "prim.if":
+                    # Do not inline a condition produced by another prim.if.
+                    matches.pop(layer_id)
+                    return
+                input_id = id
+                break
+        # NOTE(review): this assumes the loop above always finds a matching
+        # producer; if it ends without break, input_id is unbound (NameError at
+        # the last line) and input_layer is whatever edge was visited last -- confirm.
+        func_name = input_layer.kernel.replace(".", "_")
+        from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code
+        func = getattr(prim2code, func_name)
+        line = func(input_layer, is_return_line=True)
+        layer.attrs["input"] = line
+        layer.inputs.pop("input")
+        matches.pop(layer_id)
+        if len(input_layer.outputs) == 1:  # producer fed only the if: remove it too
+            matches[input_id] = input_layer
\ No newline at end of file
diff --git a/x2paddle/optimizer/optimizer.py b/x2paddle/optimizer/optimizer.py
index 30557f65dc7c2822e625e7cd4c4134dc525c30ad..850677b5fcd2b3556762bbec2e4bc611e273d4f0 100644
---
a/x2paddle/optimizer/optimizer.py
+++ b/x2paddle/optimizer/optimizer.py
@@ -31,7 +31,8 @@ class GraphOptimizer(object):
                 "dygraph_fc_fuse_pass",
                 "dygraph_adaptive_pool2d_fuse_pass",
                 "dygraph_reshape_fuse_pass",
-                "dygraph_dropout_fuse_pass"
+                "dygraph_dropout_fuse_pass",
+                "dygraph_if_fuse_pass"
             ]
         elif source_frame == "caffe":
             if paddle_type == "dygraph":
diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py
index c92f83d0c75439b2027e3d3a265f2de555b48d04..22e23cae425ebaf36d223a6e812a6899e498a935 100644
--- a/x2paddle/optimizer/pattern_matcher.py
+++ b/x2paddle/optimizer/pattern_matcher.py
@@ -28,6 +28,8 @@ class PatternMatcher(object):
             self.detect_patterns_by_topo(graph)
         elif match_kind == "edge":
             self.detect_patterns_by_edge(graph)
+        elif match_kind == "op":
+            self.detect_patterns_by_op(graph)
         self.remove_overlapped_match()
         return self.matches
@@ -228,6 +230,42 @@
             for j, block in enumerate(layer.blocks):
                 if len(block.layers) > 0:
                     self.detect_patterns_by_edge(layer.blocks[j])
+
+    def detect_patterns_by_op(self, graph):
+        """ Used when matching on the op (kernel) alone.
+        """
+        def get_subgraph(pattern, graph, start_index):
+            # Try to match the pattern starting at graph layer #start_index.
+            pattern_id2layers = pattern.get_global_layers()
+            pattern_ids = list(pattern_id2layers.keys())
+            pattern_layer_id = pattern_ids[0]
+            subgraph_id2layers = dict()
+            layer_id = list(graph.layers.keys())[start_index]
+            graph_layers = graph.layers
+
+            def update(layer_id, pattern_layer_id):
+                # Accept the layer iff its kernel equals the pattern layer's kernel.
+                layer = graph_layers[layer_id]
+                pattern_layer = pattern_id2layers[pattern_layer_id]
+                if layer.kernel != pattern_layer.kernel:
+                    return False
+                subgraph_id2layers[layer_id] = layer
+
+            # NOTE(review): update() never advances layer_id/pattern_layer_id, so
+            # this loop can only complete for single-layer patterns (e.g. the lone
+            # prim.if op); for multi-layer patterns it would not progress -- confirm.
+            while len(subgraph_id2layers) != len(pattern_id2layers):
+                out = update(layer_id, pattern_layer_id)
+                if out == False:
+                    return False
+                else:
+                    if len(subgraph_id2layers) == len(pattern_id2layers):
+                        return subgraph_id2layers
+                    else:
+                        return False
+        for i, (layer_id, layer) in enumerate(graph.layers.items()):
+            match_info = get_subgraph(self.pattern, graph, i)
+            if match_info:
+                self.matches.append(match_info)
+            for j, block in enumerate(layer.blocks):
+                if len(block.layers) > 0:
+                    # Recurse into sub-blocks (e.g. if/loop bodies).
+                    self.detect_patterns_by_op(layer.blocks[j])
+
     def remove_overlapped_match(self):
         """ If two subgraphs overlap, keep only the earlier one.
@@ -297,14 +335,11 @@
         """ Delete unneeded intermediate layers and their parameters.
         """
         for match in self.matches:
+            if len(match) == 0:
+                continue
             first_layer_id = list(match.keys())[0]
             subgraph = get_subgraph("", first_layer_id, graph)
             for layer_id, layer in match.items():
-                if layer.kernel == "fluid.dygraph.base.to_variable" and \
-                    layer.attrs["value"].startswith("params["):
-                    param_name = layer.attrs["value"][8:-2]
-                    if param_name in graph.parameters:
-                        graph.parameters.pop(param_name)
                 if layer_id in subgraph.layers:
                     # layer_id may belong to a sub-block; in that case delete the
                     # parent layer, i.e. remove the whole subgraph.
                     subgraph.layers.pop(layer_id)