diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 358986911c64ba39864d7f4eb6f023ed1926e938..94eab65c2395f8aa41add379b2f8823f8633bef8 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -174,6 +174,36 @@ def onnx2paddle(model_path, save_dir, params_merge=False): print("Paddle model and code generated.") +def pytorch2paddle(model_path, save_dir): + # check pytorch installation and version + try: + import torch + version = torch.__version__ + ver_part = version.split('.') + print(ver_part) + if int(ver_part[1]) < 5: + print("[ERROR] pytorch>=1.5.0 is required") + return + except: + print( + "[ERROR] Pytorch is not installed, use \"pip install torch==1.5.0 torchvision\"." + ) + return + print("Now translating model from pytorch to paddle.") + + from x2paddle.decoder.pytorch_decoder import PyTorchDecoder + from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper + model = PyTorchDecoder(model_path) + mapper = pytorch_op_mapper.PyTorchOpMapper(model) + mapper.graph.build() + print("Model optimizing ...") + from x2paddle.optimizer.optimizer import GraphOptimizer + graph_opt = GraphOptimizer() + graph_opt.optimize(mapper.graph) + print("Model optimized.") + mapper.graph.gen_model(save_dir) + + def paddle2onnx(model_path, save_dir, opset_version=10): from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper @@ -243,6 +273,9 @@ def main(): if args.params_merge: params_merge = True onnx2paddle(args.model, args.save_dir, params_merge) + elif args.framework == "pytorch": + assert args.model is not None, "--model should be defined while translating pytorch model" + pytorch2paddle(args.model, args.save_dir) elif args.framework == "paddle2onnx": assert args.model is not None, "--model should be defined while translating paddle model to onnx" diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 75c4be504f77f7bbfb22f25a001d30f88da4c691..5358c80c7a9e56e7690179366e5064ee67b216b9 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -132,7 +132,8 @@ class PaddleGraph(object): if self.graph_type == "dygraph": self.get_dygraph_inputs() - self.get_dygraph_outputs() + if len(self.outputs) == 0: + self.get_dygraph_outputs() def get_global_layers(self): # 该全局layers的信息是按照拓扑排序组成的 @@ -164,8 +165,8 @@ class PaddleGraph(object): f, [ "from paddle.fluid.initializer import Constant", "from paddle.fluid.param_attr import ParamAttr", - "import paddle.fluid as fluid" - "", "def x2paddle_net():" + "import paddle.fluid as fluid", "import math", "", + "def x2paddle_net():" ], indent=0) for layer_id, layer in self.layers.items(): @@ -204,6 +205,8 @@ class PaddleGraph(object): f.close() def gen_model(self, save_dir): + if not os.path.exists(save_dir): + os.makedirs(save_dir) if self.graph_type == "static": code_dir = os.path.join(save_dir, 'model_with_code') infer_dir = os.path.join(save_dir, 'inference_model') diff --git a/x2paddle/op_mapper/pytorch2paddle/aten.py b/x2paddle/op_mapper/pytorch2paddle/aten.py index f4c4a21ac8023592e1bfa4d8a2795e538d547c95..97f1685a73a23aaed56b26291c66e7be02be020b 100644 --- a/x2paddle/op_mapper/pytorch2paddle/aten.py +++ b/x2paddle/op_mapper/pytorch2paddle/aten.py @@ -451,6 +451,35 @@ def aten_chunk(mapper, graph, node): return current_inputs, current_outputs +def aten___contains__(mapper, graph, node): + """ 构造in的PaddleLayer。 + + TorchScript示例: + %51 : bool = aten::__contains__(%50, %name.1) + 参数含义: + %51 (bool): 输出,第一个元素是否包含第二个元素。 + %50 (-): 需对比的输入1。 + %name.1 (-): 需对比的输入2。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%50 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["input"] = inputs_name[0] + # 处理输入1,即%name.1 + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs) + layer_inputs["element"] = inputs_name[1] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.contain", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + def aten_contiguous(mapper, graph, node): """ 构造在内存中连续存储的PaddleLayer。 @@ -545,6 +574,25 @@ def aten_conv2d(mapper, graph, node): return current_inputs, current_outputs +def aten_dict(mapper, graph, node): + """ 构造初始化dict的PaddleLayer。 + + TorchScript示例: + %features.1 : Dict(str, Tensor) = aten::dict() + 参数含义: + %features.1: 输出,初始化的dict。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + current_inputs = {} + # 获取当前节点输出的list + current_outputs = [output_name] + + graph.add_layer("prim.dict", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + def aten_dim(mapper, graph, node): """ 构造获取维度的PaddleLayer。 @@ -720,6 +768,56 @@ def aten_flatten(mapper, graph, node): return current_inputs, current_outputs +def aten_Float(mapper, graph, node): + """ 构造取浮点型的PaddleLayer。 + + TorchScript示例: + %3992 : float = aten::Float(%3991) + 参数含义: + %3992 (int): 向上取整后的整数。 + %3991 (float): 需要取整的浮点数。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%3991 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["input"] = inputs_name[0] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.float", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + +def aten_floor(mapper, graph, node): + """ 构造向上取整的PaddleLayer。 + + TorchScript示例: + %3978 : int = aten::floor(%scale.18) + 参数含义: + %3978 (int): 向上取整后的整数。 + %scale.18 (float): 需要取整的浮点数。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%scale.18 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["input"] = inputs_name[0] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.floor", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + def aten_floordiv(mapper, graph, node): """ 构造向上取整除法的PaddleLayer。 @@ -727,7 +825,7 @@ def aten_floordiv(mapper, graph, node): %channels_per_group.2 : int = aten::floordiv(%num_channels.2, %3690) 参数含义: %channels_per_group.2 (-): 除后的结果。 - %%num_channels.2 (-): 被除数。 + %num_channels.2 (-): 被除数。 %2 (int): 除数。 """ output_name = mapper._get_outputs_name(node)[0] @@ -854,6 +952,64 @@ def aten_hardtanh_(mapper, graph, node): return current_inputs, current_outputs +def aten___is__(mapper, graph, node): + """ 构造is not的PaddleLayer。 + + TorchScript示例: + %3949 : bool = aten::__isnot__(%size.122, %3931) + 参数含义: + %3949 (bool): 输出,第一个元素是否不是第二个元素。 + %size.122 (-): 需对比的输入1。 + %3931 (-): 需对比的输入2。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%size.122 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["x"] = inputs_name[0] + # 处理输入1,即%3931 + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs) + layer_inputs["y"] = inputs_name[1] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.is", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + +def aten___isnot__(mapper, graph, node): + """ 构造is not的PaddleLayer。 + + TorchScript示例: + %3949 : bool = aten::__isnot__(%size.122, %3931) + 参数含义: + %3949 (bool): 输出,第一个元素是否不是第二个元素。 + %size.122 (-): 需对比的输入1。 + %3931 (-): 需对比的输入2。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%size.122 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["x"] = inputs_name[0] + # 处理输入1,即%3931 + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs) + layer_inputs["y"] = inputs_name[1] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.isnot", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + def aten_le(mapper, graph, node): """ 构造对比大小的PaddleLayer。 @@ -1344,6 +1500,36 @@ def aten_select(mapper, graph, node): return current_inputs, current_outputs +def aten__set_item(mapper, graph, node): + """ 构造对dict加入元素的PaddleLayer。 + + TorchScript示例: + = aten::_set_item(%features.1, %out_name.1, %x.3) + 参数含义: + %features.1 (list): dict。 + %out_name.1 (-): dict的key。 + %x.3 (-): dict的value。 + """ + layer_inputs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [] + # 处理输入0,即%features.1 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["dict"] = inputs_name[0] + # 处理输入1,即%out_name.1 + mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs) + layer_inputs["key"] = inputs_name[1] + # 处理输入2,即%x.3 + mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs) + layer_inputs["value"] = inputs_name[2] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.set_item", inputs=layer_inputs, outputs=[]) + return current_inputs, current_outputs + + def aten_size(mapper, graph, node): """ 构造获取shape的PaddleLayer。 @@ -1569,6 +1755,70 @@ def aten_unsqueeze(mapper, graph, node): return current_inputs, current_outputs +def aten_upsample_bilinear2d(mapper, graph, node): + """ 构造使用bilinear上采样的PaddleLayer。 + + TorchScript示例: + %4997 : Tensor = aten::upsample_bilinear2d(%x.13, %4963, %5421, %4995, %4996) + 参数含义: + %4997 (Tensor): 输出,上采样后的Tensor。 + %x.13 (Tensor): 需要上采样的Tensor。 + %4963 (list): 上采样后的大小。 + %5421 (bool): 若为True,则将输入和输出张量的4个角落像素的中心对齐,并保留角点像素的值。 + %4995 (float): 高度的乘数因子。 + %4995 (float): 宽度的乘数因子。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + layer_attrs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%x.13 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["input"] = inputs_name[0] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + # 处理输入1,即%4963 + if inputs_name[1] in mapper.attrs: + layer_attrs["out_shape"] = mapper.attrs[inputs_name[1]] + else: + mapper._check_input(graph, inputs_node[1], inputs_name[1], + current_outputs) + layer_inputs["out_shape"] = inputs_name[1] + current_inputs.append(inputs_name[1]) + # 处理输入2,即%5421 + if inputs_name[2] in mapper.attrs: + layer_attrs["align_corners"] = mapper.attrs[inputs_name[2]] + else: + mapper._check_input(graph, inputs_node[2], inputs_name[2], + current_outputs) + layer_inputs["align_corners"] = inputs_name[2] + current_inputs.append(inputs_name[2]) + # 处理输入3和4,构造assert + list_layer_inputs = {} + mapper._check_input(graph, inputs_node[3], inputs_name[3], current_outputs) + list_layer_inputs["key"] = inputs_name[3] + current_inputs.append(inputs_name[3]) + mapper._check_input(graph, inputs_node[4], inputs_name[4], current_outputs) + list_layer_inputs["value"] = inputs_name[4] + current_inputs.append(inputs_name[4]) + graph.add_layer( + "prim.assert", + inputs=list_layer_inputs, + outputs=[output_name + "_assert"], + type="eq") + layer_inputs["scale"] = inputs_name[3] + layer_attrs["align_mode"] = 0 + graph.add_layer( + "fluid.layers.interpolate", + inputs=layer_inputs, + outputs=layer_outputs, + **layer_attrs) + return current_inputs, current_outputs + + def aten_view(mapper, graph, node): """ 构造调整大小的PaddleLayer。 diff --git a/x2paddle/op_mapper/pytorch2paddle/prim.py b/x2paddle/op_mapper/pytorch2paddle/prim.py index 6d1ef467c9f4aa05e3f73b057c28e597ed944d57..2305f2f5a2a9dbc019cbdf0e10a571d447e0d7ca 100644 --- a/x2paddle/op_mapper/pytorch2paddle/prim.py +++ b/x2paddle/op_mapper/pytorch2paddle/prim.py @@ -111,13 +111,14 @@ def prim_If(mapper, graph, node): %107 (bool): if判断条件。 %input.5 (Tensor): if控制流的输出,与%output.4对应。 """ - output_name = mapper._get_outputs_name(node)[0] - node_outputs = [output_name] + outputs_name = mapper._get_outputs_name(node) + node_outputs = outputs_name.copy() + current_outputs = outputs_name.copy() input_node = list(node.inputs())[0].node() script_input_unique_id = list(node.inputs())[0].unique() input_node_name = mapper.outputs_info[script_input_unique_id] - mapper._check_input(graph, input_node, input_node_name, node_outputs) - graph.add_layer("prim.if", {'input': input_node_name}, [output_name]) + mapper._check_input(graph, input_node, input_node_name, current_outputs) + graph.add_layer("prim.if", {'input': input_node_name}, node_outputs) current_layer = list(graph.layers.values())[-1] block0 = list(node.blocks())[0] block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer) @@ -131,7 +132,7 @@ def prim_If(mapper, graph, node): for i, input_name in enumerate(graph_inputs1): current_layer.inputs['input-{}'.format(len0 + 1 + i)] = input_name current_layer.add_block(block1_graph) - return list(current_layer.inputs.values()), node_outputs + return list(current_layer.inputs.values()), current_outputs def prim_ListConstruct(mapper, graph, node): @@ -436,6 +437,34 @@ def prim_TupleUnpack(mapper, graph, node): return current_inputs, current_outputs +def prim_unchecked_cast(mapper, graph, node): + """ 构造确认类型的PaddleLayer。 + + TorchScript示例: + %size.64 : int[] = prim::unchecked_cast(%size.63) + 参数含义: + %size.64 (-): 输出。 + %size.63 (-): 输入。 + + 【注意】Paddle中无此用法,所以此处翻译成赋值。 + """ + output_name = mapper._get_outputs_name(node)[0] + layer_outputs = [output_name] + layer_inputs = {} + layer_attrs = {} + inputs_name, inputs_node = mapper._get_inputs_name(node) + # 获取当前节点输出的list + current_outputs = [output_name] + # 处理输入0,即%size.63 + mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) + layer_inputs["input"] = inputs_name[0] + # 获取当前节点输入的list + current_inputs = list(layer_inputs.values()) + + graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs) + return current_inputs, current_outputs + + def prim_Uninitialized(mapper, graph, node): """ 构造表示编译器永远不会使用的值的PaddleLayer,该节点转换为None。 diff --git a/x2paddle/op_mapper/pytorch2paddle/prim2code.py b/x2paddle/op_mapper/pytorch2paddle/prim2code.py index f178fa0fb9321b4dafd55eee9528a11083b0630c..c1a342e5c2e243aa6b0b84f2c49efa61bd5db403 100644 --- a/x2paddle/op_mapper/pytorch2paddle/prim2code.py +++ b/x2paddle/op_mapper/pytorch2paddle/prim2code.py @@ -62,18 +62,22 @@ def prim_append(layer, indent=1, init_func=[], forward_func=[]): def prim_assert(layer, indent=1, init_func=[], forward_func=[]): if layer.attrs["type"] == "eq": - if isinstance(layer.attrs["value"], list): + values = get_value(layer, "key") + if "value" in layer.attrs: + values = layer.attrs["value"] + if isinstance(values, list): s = "" - for v in layer.attrs["value"]: - s += "{} == {} or ".format(layer.attrs["key"], v) + for v in values: + s += "{} == {} or ".format(get_value(layer, "key"), v) if len(s) > 0: s = s[:-4] line = "assert {}, \'The {} must be {}!\'".format( - s, layer.attrs["key"], layer.attrs["value"]) + s, get_value(layer, "key"), get_value(layer, "value")) else: line = "assert {} == {}, \'The {} must be {}!\'".format( - layer.attrs["key"], layer.attrs["value"], layer.attrs["key"], - layer.attrs["value"]) + get_value(layer, "key"), + get_value(layer, "value"), + get_value(layer, "key"), get_value(layer, "value")) else: raise Exception("Not implement yet!") forward_func.extend(gen_codes([line], indent=indent)) @@ -84,6 +88,18 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[]): forward_func.extend(gen_codes([line], indent=indent)) +def prim_contain(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = {} in {}".format(layer.outputs[0], + get_value(layer, "element"), + get_value(layer, "input")) + forward_func.extend(gen_codes([line], indent=indent)) + + +def prim_dict(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = dict()".format(layer.outputs[0]) + forward_func.extend(gen_codes([line], indent=indent)) + + def prim_eq(layer, indent=1, init_func=[], forward_func=[]): line = "{} = {} == {}".format(layer.outputs[0], get_value(layer, "x"), get_value(layer, "y")) @@ -100,12 +116,36 @@ def prim_exception(layer, indent=1, init_func=[], forward_func=[]): forward_func.extend(gen_codes([line], indent=indent)) +def prim_float(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = float({})".format(layer.outputs[0], get_value(layer, "input")) + forward_func.extend(gen_codes([line], indent=indent)) + + +def prim_floor(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = math.floor({})".format(layer.outputs[0], + get_value(layer, "input")) + forward_func.extend(gen_codes([line], indent=indent)) + + def prim_floordiv(layer, indent=1, init_func=[], forward_func=[]): line = "{} = {} // {}".format(layer.outputs[0], get_value(layer, "x"), get_value(layer, "y")) forward_func.extend(gen_codes([line], indent=indent)) +def prim_getitem(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = {}[{}]".format(layer.outputs[0], + get_value(layer, "list"), + get_value(layer, "index")) + forward_func.extend(gen_codes([line], indent=indent)) + + +def prim_gt(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = {} > {}".format(layer.outputs[0], + get_value(layer, "x"), get_value(layer, "y")) + forward_func.extend(gen_codes([line], indent=indent)) + + def prim_if(layer, indent=1, init_func=[], forward_func=[]): line = "if {} :".format(get_value(layer, "input")) forward_func.extend(gen_codes([line], indent=indent)) @@ -123,16 +163,16 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]): forward_func.extend(b_forward_lines) -def prim_getitem(layer, indent=1, init_func=[], forward_func=[]): - line = "{} = {}[{}]".format(layer.outputs[0], - get_value(layer, "list"), - get_value(layer, "index")) +def prim_is(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = {} is {}".format(layer.outputs[0], + get_value(layer, "x"), get_value(layer, "y")) forward_func.extend(gen_codes([line], indent=indent)) -def prim_gt(layer, indent=1, init_func=[], forward_func=[]): - line = "{} = {} > {}".format(layer.outputs[0], - get_value(layer, "x"), get_value(layer, "y")) +def prim_isnot(layer, indent=1, init_func=[], forward_func=[]): + line = "{} = {} is not {}".format(layer.outputs[0], + get_value(layer, "x"), + get_value(layer, "y")) forward_func.extend(gen_codes([line], indent=indent)) @@ -239,6 +279,13 @@ def prim_set_attr(layer, indent=1, init_func=[], forward_func=[]): forward_func.extend(gen_codes([line], indent=indent)) +def prim_set_item(layer, indent=1, init_func=[], forward_func=[]): + line = "{}[{}] = {}".format( + get_value(layer, "dict"), + get_value(layer, "key"), get_value(layer, "value")) + forward_func.extend(gen_codes([line], indent=indent)) + + def prim_shape(layer, indent=1, init_func=[], forward_func=[]): line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input")) forward_func.extend(gen_codes([line], indent=indent)) diff --git a/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py b/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py index 95396eb3efad544daea6ae1b2d5fecdbce4656c3..2b95218142790a1f181e46ba0c547403492f8813 100644 --- a/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py +++ b/x2paddle/op_mapper/pytorch2paddle/pytorch_op_mapper.py @@ -108,9 +108,14 @@ class PyTorchOpMapper(OpMapper): parent_layer=parent_layer, index=i) _update_graph_inputs("equal", inputs, outputs) - # 设置graph的参数 + + # 设置graph的参数和输出节点 if isinstance(script_graph, torch._C.Graph): graph.set_parameters(self.paddle_params) + if hasattr(script_graph, 'return_node'): + inputs_name, inputs_node = self._get_inputs_name( + script_graph.return_node()) + graph.outputs = inputs_name return graph, graph_inputs def _get_outputs_name(self, node, attr_name=None): diff --git a/x2paddle/optimizer/fusion/__init__.py b/x2paddle/optimizer/fusion/__init__.py index 13ef304f5d927b80a5643a3f39b8bf1a7e303fc1..96303eab3d075bef5eb87fda9fc4199a5c81be2a 100644 --- a/x2paddle/optimizer/fusion/__init__.py +++ b/x2paddle/optimizer/fusion/__init__.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .fc_fuser import FcFuser -from .fc_fuse_pass import FcFusePass from .adaptive_pool2d_fuser import AdaptivePool2dFuser from .adaptive_pool2d_fuse_pass import AdaptivePool2dFusePass -from .constant_fuser import ConstantFuser -from .constant_fuse_pass import ConstantFusePass from .batchnorm2d_fuser import BatchNorm2dFuser from .batchnorm2d_fuse_pass import BatchNorm2dFusePass +from .constant_fuser import ConstantFuser +from .constant_fuse_pass import ConstantFusePass +from .fc_fuser import FcFuser +from .fc_fuse_pass import FcFusePass +from .interpolate_bilinear_fuser import InterpolateBilinearFuser +from .interpolate_bilinear_fuse_pass import InterpolateBilinearFusePass diff --git a/x2paddle/optimizer/fusion/interpolate_bilinear_fuse_pass.py b/x2paddle/optimizer/fusion/interpolate_bilinear_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..79947d6aebf7229208ae10ab9163807275a667c0 --- /dev/null +++ b/x2paddle/optimizer/fusion/interpolate_bilinear_fuse_pass.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from x2paddle.optimizer.pass_ import Pass +from x2paddle.optimizer.fusion import InterpolateBilinearFuser +from x2paddle.optimizer.pass_manager import pass_register + + +@pass_register +class InterpolateBilinearFusePass(Pass): + name = "interpolate_bilinear_fuse_pass" + + def __init__(self): + Pass.__init__(self) + + def apply(self, graph): + fuser = InterpolateBilinearFuser() + fuser.operate(graph, match_kind="topo") + + +# 用于注册 +interpolate_bilinear_fuse_pass = InterpolateBilinearFusePass() diff --git a/x2paddle/optimizer/fusion/interpolate_bilinear_fuser.py b/x2paddle/optimizer/fusion/interpolate_bilinear_fuser.py new file mode 100644 index 0000000000000000000000000000000000000000..f368670c8a79dd03ae0603bdda553b063906fab8 --- /dev/null +++ b/x2paddle/optimizer/fusion/interpolate_bilinear_fuser.py @@ -0,0 +1,978 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from x2paddle.optimizer.pattern_matcher import FuseBase +from x2paddle.core.program import PaddleGraph, PaddleLayer +from x2paddle.core.util import * + + +class InterpolateBilinearFuser(FuseBase): + def __init__(self): + super(InterpolateBilinearFuser, self).__init__(graph_type="dygraph") + + def build_pattern(self): + """ 描述需要替换的双线性插值图结构。 + interpolate_bilinear层模式python实现代码示例: + x2834 = 'Exception' + x2835 = None + x2836 = 2 + x2837 = 3 + x2838 = 1 + x2839 = 0 + x2840 = 4 + x2841 = 5 + x2842 = None + x2843 = x2832.shape + x2843 = len(x2843) + x2844 = x2843 - x2836 + x2845 = [] + for _x2847 in range(x2844): + x2845.append(x2835) + x2848 = (x2832, x9, x3, x3) + x2849 = x2832.shape + x2849 = len(x2849) + x2850 = x2849 == x2837 + if x2850 : + raise RaiseException(x2834) + x2851 = x2842 + else: + x2853 = x2832.shape + x2853 = len(x2853) + x2854 = x2853 == x2840 + if x2854 : + x2857 = True + x2858 = 'Exception' + x2859 = False + x2860 = None + x2861 = 'The default behavior for interpolate/upsample with float scale_factor will change in 1.6.0 to align with other frameworks/libraries, and use scale_factor directly, instead of relying on the computed output size. If you wish to keep the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details. ' + x2862 = 0 + x2863 = 9223372036854775807 + x2864 = 1 + x2865 = 2 + x2866 = None + x2867 = None + x2868 = None + x2869 = None + x2870 = None + x2871, x2872, x2873, x2874 = x2848 + x2875 = x2872 is x2860 + if x2875 : + x2878 = x2873 is x2860 + x2876 = x2878 + x2877 = x2872 + else: + x2879 = x2872 + x2876 = x2859 + x2877 = x2879 + if x2876 : + raise RaiseException(x2858) + x2882 = x2877 is not x2860 + if x2882 : + x2885 = x2877 + x2886 = x2873 is not x2860 + x2883 = x2886 + x2884 = x2885 + else: + x2883 = x2859 + x2884 = x2877 + if x2883 : + raise RaiseException(x2858) + x2887 = x2868 + x2888 = x2869 + else: + x2887 = x2873 + x2888 = x2884 + x2890 = x2887 is not x2860 + if x2890 : + x2892 = x2887 + x2893 = len(x2892) + x2894 = x2893 != x2836 + if x2894 : + raise RaiseException(x2858) + x2891 = x2892 + else: + x2891 = x2887 + x2897 = x2888 is not x2860 + if x2897 : + x2899 = x2888 + x2898 = x2899 + else: + x2898 = x2866 + if x2897 : + x2900 = x2898 + else: + x2901 = x2891 is not x2860 + if x2901 : + x2903 = x2891 + x2902 = x2903 + else: + raise RaiseException(x2858) + x2902 = x2867 + x2905 = x2874 is x2860 + if x2905 : + x2907 = len(x2902) + x2908 = x2907 > x2862 + x2912 = x2859 + x2913 = x2862 + for x2910 in range(x2863): + x2914 = x2902[x2913] + x2915 = math.floor(x2914) + x2916 = x2915 != x2914 + if x2916 : + x2917 = x2859 + x2918 = x2916 + else: + x2917 = x2870 + x2918 = x2870 + if x2916 : + x2919 = x2917 + x2920 = x2918 + else: + x2919 = x2857 + x2920 = x2916 + x2921 = x2913 + x2864 + x2922 = x2921 < x2907 + x2923 = x2922 and x2919 + x2909 = x2920 + x2910 = x2921 + if x2909 : + import warnings + warnings.warn(x2861, stacklevel=2) + x2926 = [] + for _x2928 in range(x2836): + x2929 = _x2928 + x2865 + x2930 = x2871.shape + x2931 = float(x2930) + x2932 = x2902[_x2928] + x2933 = x2931 * x2932 + x2934 = math.floor(x2933) + x2926.append(x2934) + x2900 = x2926 + x2935 = x2845[x2839] + x2936 = x2845[x2838] + assert x2935 == x2936, 'The x2935 must be x2936!' + x2937 = fluid.layers.interpolate( + input=x2832, out_shape=x2900, scale=x2935, align_corners=False, align_mode=0) + x2855 = x2937 + else: + x2938 = x2832.shape + x2938 = len(x2938) + x2939 = x2938 == x2841 + if x2939 : + raise RaiseException(x2834) + else: + raise RaiseException(x2834) + x2855 = x2842 + x2851 = x2855 + """ + + def gen_name(id): + return "x" + str(id) + + self.pattern.add_layer( + "prim.constant", + inputs={}, + outputs=[gen_name(0)], + value="Exception") + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(1)], value=None) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(2)], value=2) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(3)], value=3) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(4)], value=1) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(5)], value=0) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(6)], value=4) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(7)], value=5) + self.pattern.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(8)], value=None) + self.pattern.add_layer( + "prim.shape", + inputs={"input": "interpolate-input-0"}, + outputs=[gen_name(9)]) + self.pattern.add_layer( + "prim.len", inputs={"input": gen_name(9)}, outputs=[gen_name(9)]) + self.pattern.add_layer( + "prim.sub", + inputs={"x": gen_name(9), + "y": gen_name(2)}, + outputs=[gen_name(10)]) + self.pattern.add_layer("prim.list", inputs={}, outputs=[gen_name(11)]) + self.pattern.add_layer( + "prim.loop", + inputs={"input": gen_name(10)}, + outputs=[gen_name(12.1), gen_name(12.2)]) + loop_layer = self.pattern.layers[list(self.pattern.layers.keys())[-1]] + pattern_block = PaddleGraph(loop_layer, graph_type="dygraph") + pattern_block.add_layer( + "prim.append", + inputs={"list": gen_name(11), + "element": gen_name(1)}, + outputs=[]) + loop_layer.inputs["input-0"] = gen_name(11) + loop_layer.inputs["input-1"] = gen_name(1) + loop_layer.add_block(pattern_block) + self.pattern.add_layer( + "prim.tuple", + inputs={ + "input0": "interpolate-input-0", + "input1": "interpolate-input-1", + "input2": "interpolate-input-2", + "input3": "interpolate-input-2" + }, + outputs=[gen_name(13)]) + self.pattern.add_layer( + "prim.shape", + inputs={"input": "interpolate-input-0"}, + outputs=[gen_name(14)]) + self.pattern.add_layer( + "prim.len", inputs={"input": gen_name(14)}, outputs=[gen_name(14)]) + self.pattern.add_layer( + "prim.eq", + inputs={"x": gen_name(14), + "y": gen_name(3)}, + outputs=[gen_name(15)]) + self.pattern.add_layer( + "prim.if", inputs={"input": gen_name(15)}, outputs=[gen_name(16)]) + if_layer1 = self.pattern.layers[list(self.pattern.layers.keys())[-1]] + pattern_block = PaddleGraph(if_layer1, graph_type="dygraph") + pattern_block.add_layer( + "prim.exception", + inputs={"input": gen_name(0)}, + outputs=[gen_name(17)]) + pattern_block.add_layer( + "prim.equal", inputs={"input": gen_name(8)}, + outputs=[gen_name(16)]) + if_layer1.inputs["input-0"] = gen_name(0) + if_layer1.inputs["input-1"] = gen_name(8) + if_layer1.add_block(pattern_block) + pattern_block = PaddleGraph(if_layer1, graph_type="dygraph") + pattern_block.add_layer( + "prim.shape", + inputs={"input": "interpolate-input-0"}, + outputs=[gen_name(18)]) + pattern_block.add_layer( + "prim.len", inputs={"input": gen_name(18)}, outputs=[gen_name(18)]) + pattern_block.add_layer( + "prim.eq", + inputs={"x": gen_name(18), + "y": gen_name(6)}, + outputs=[gen_name(19)]) + pattern_block.add_layer( + "prim.if", inputs={"input": gen_name(19)}, outputs=[gen_name(20)]) + if_layer2 = pattern_block.layers[list(pattern_block.layers.keys())[-1]] + pattern_block_block = PaddleGraph(if_layer2, graph_type="dygraph") + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(21)], value=False) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(22)], value=True) + pattern_block_block.add_layer( + "prim.constant", + inputs={}, + outputs=[gen_name(23)], + value="Exception") + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(24)], value=False) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(25)], value=None) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(26)], value="") + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(26.1)], value=0) + pattern_block_block.add_layer( + "prim.constant", + inputs={}, + outputs=[gen_name(26.2)], + value=9223372036854775807) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(27)], value=1) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(28)], value=2) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(29)], value=None) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(30)], value=None) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(31)], value=None) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(32)], value=None) + pattern_block_block.add_layer( + "prim.constant", inputs={}, outputs=[gen_name(33)], value=None) + pattern_block_block.add_layer( + "prim.tuple_unpack", + inputs={"input": gen_name(13)}, + outputs=[gen_name(34), gen_name(35), gen_name(36), gen_name(37)]) + pattern_block_block.add_layer( + "prim.is", + inputs={"x": gen_name(35), + "y": gen_name(25)}, + outputs=[gen_name(38)]) + pattern_block_block.add_layer( + "prim.if", + inputs={"input": gen_name(38)}, + outputs=[gen_name(39), gen_name(40)]) + if_layer3 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph(if_layer3, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.is", + inputs={"x": gen_name(36), + "y": gen_name(25)}, + outputs=[gen_name(41)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(41)}, + outputs=[gen_name(39)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(35)}, + outputs=[gen_name(40)]) + if_layer3.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph(if_layer3, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(35)}, + outputs=[gen_name(42)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(24)}, + outputs=[gen_name(39)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(35)}, + outputs=[gen_name(40)]) + if_layer3.add_block(pattern_block_block_block) + if_layer3.inputs.update({ + "input-0": gen_name(36), + 'input-1': gen_name(25), + 'input-2': gen_name(35), + 'input-3': gen_name(35), + 'input-4': gen_name(24) + }) + pattern_block_block.add_layer( + "prim.if", inputs={"input": gen_name(39)}, outputs=[gen_name(43)]) + if_layer4 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph(if_layer4, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.exception", + inputs={"input": gen_name(23)}, + outputs=[gen_name(44)]) + if_layer4.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph(if_layer4, graph_type="dygraph") + if_layer4.add_block(pattern_block_block_block) + if_layer4.inputs["input-0"] = gen_name(23) + pattern_block_block.add_layer( + "prim.isnot", + inputs={"x": gen_name(40), + "y": gen_name(25)}, + outputs=[gen_name(45)]) + pattern_block_block.add_layer( + "prim.if", + inputs={"input": gen_name(45)}, + outputs=[gen_name(46), gen_name(47)]) + if_layer5 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph(if_layer5, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(40)}, + outputs=[gen_name(48)]) + pattern_block_block_block.add_layer( + "prim.isnot", + inputs={"x": gen_name(36), + "y": gen_name(25)}, + outputs=[gen_name(49)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(49)}, + outputs=[gen_name(46)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(48)}, + outputs=[gen_name(47)]) + if_layer5.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph(if_layer5, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(24)}, + outputs=[gen_name(46)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(40)}, + outputs=[gen_name(47)]) + if_layer5.add_block(pattern_block_block_block) + if_layer5.inputs.update({ + "input-0": gen_name(40), + "input-1": gen_name(36), + "input-2": gen_name(25), + "input-3": gen_name(24), + "input-4": gen_name(40) + }) + pattern_block_block.add_layer( + "prim.if", + inputs={"input": gen_name(46)}, + outputs=[gen_name(50), gen_name(51)]) + if_layer6 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph(if_layer6, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.exception", + inputs={"input": gen_name(23)}, + outputs=[gen_name(52)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(31)}, + outputs=[gen_name(50)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(32)}, + outputs=[gen_name(51)]) + if_layer6.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph(if_layer6, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(36)}, + outputs=[gen_name(50)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(47)}, + outputs=[gen_name(51)]) + if_layer6.add_block(pattern_block_block_block) + if_layer6.inputs.update({ + "input-0": gen_name(23), + "input-1": gen_name(31), + "input-2": gen_name(32), + "input-3": gen_name(36), + "input-4": gen_name(47) + }) + pattern_block_block.add_layer( + "prim.isnot", + inputs={"x": gen_name(50), + "y": gen_name(25)}, + outputs=[gen_name(53)]) + pattern_block_block.add_layer( + "prim.if", inputs={"input": gen_name(53)}, outputs=[gen_name(54)]) + if_layer7 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph(if_layer7, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(50)}, + outputs=[gen_name(55)]) + pattern_block_block_block.add_layer( + "prim.len", inputs={"input": gen_name(55)}, outputs=[gen_name(56)]) + pattern_block_block_block.add_layer( + "prim.ne", + inputs={"x": gen_name(56), + "y": gen_name(2)}, + outputs=[gen_name(57)]) + pattern_block_block_block.add_layer( + "prim.if", inputs={"input": gen_name(57)}, outputs=[gen_name(58)]) + if_layer8 = pattern_block_block_block.layers[list( + pattern_block_block_block.layers.keys())[-1]] + pattern_block_block_block_block = PaddleGraph( + if_layer8, graph_type="dygraph") + pattern_block_block_block_block.add_layer( + "prim.exception", + inputs={"input": gen_name(23)}, + outputs=[gen_name(59)]) + if_layer8.add_block(pattern_block_block_block_block) + pattern_block_block_block_block = PaddleGraph( + if_layer8, graph_type="dygraph") + if_layer8.add_block(pattern_block_block_block_block) + if_layer8.inputs["input-0"] = gen_name(23) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(55)}, + outputs=[gen_name(54)]) + if_layer7.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph(if_layer7, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(50)}, + outputs=[gen_name(54)]) + if_layer7.add_block(pattern_block_block_block) + if_layer7.inputs.update({ + "input-0": gen_name(50), + "input-1": gen_name(2), + "input-2": gen_name(23), + "input-3": gen_name(50) + }) + pattern_block_block.add_layer( + "prim.isnot", + inputs={"x": gen_name(51), + "y": gen_name(25)}, + outputs=[gen_name(60)]) + pattern_block_block.add_layer( + "prim.if", inputs={"input": gen_name(60)}, outputs=[gen_name(61)]) + if_layer9 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph(if_layer9, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(51)}, + outputs=[gen_name(62)]) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(62)}, + outputs=[gen_name(61)]) + if_layer9.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph(if_layer9, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(29)}, + outputs=[gen_name(61)]) + if_layer9.add_block(pattern_block_block_block) + if_layer9.inputs.update({ + "input-0": gen_name(51), + "input-1": gen_name(29) + }) + pattern_block_block.add_layer( + "prim.if", inputs={"input": gen_name(60)}, outputs=[gen_name(63)]) + if_layer10 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph( + if_layer10, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(61)}, + outputs=[gen_name(63)]) + if_layer10.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph( + if_layer10, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.isnot", + inputs={"x": gen_name(54), + "y": gen_name(25)}, + outputs=[gen_name(64)]) + pattern_block_block_block.add_layer( + "prim.if", inputs={"input": gen_name(64)}, outputs=[gen_name(65)]) + if_layer11 = pattern_block_block_block.layers[list( + pattern_block_block_block.layers.keys())[-1]] + pattern_block_block_block_block = PaddleGraph( + if_layer11, graph_type="dygraph") + pattern_block_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(54)}, + outputs=[gen_name(66)]) + pattern_block_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(66)}, + outputs=[gen_name(65)]) + if_layer11.add_block(pattern_block_block_block_block) + pattern_block_block_block_block = PaddleGraph( + if_layer11, graph_type="dygraph") + pattern_block_block_block_block.add_layer( + "prim.exception", + inputs={"input": gen_name(23)}, + outputs=[gen_name(67)]) + pattern_block_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(30)}, + outputs=[gen_name(65)]) + if_layer11.add_block(pattern_block_block_block_block) + if_layer11.inputs.update({ + "input-0": gen_name(54), + "input-1": gen_name(23), + "input-2": gen_name(30) + }) + pattern_block_block_block.add_layer( + "prim.is", + inputs={"x": gen_name(37), + "y": gen_name(25)}, + outputs=[gen_name(68)]) + pattern_block_block_block.add_layer( + "prim.if", inputs={"input": gen_name(68)}, outputs=[gen_name(69)]) + if_layer12 = pattern_block_block_block.layers[list( + pattern_block_block_block.layers.keys())[-1]] + pattern_block_block_block_block = PaddleGraph( + if_layer12, graph_type="dygraph") + pattern_block_block_block_block.add_layer( + "prim.len", inputs={"input": gen_name(65)}, outputs=[gen_name(70)]) + pattern_block_block_block_block.add_layer( + "prim.gt", + inputs={"x": gen_name(70), + "y": gen_name(26.1)}, + outputs=[gen_name(71)]) + pattern_block_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(24)}, + outputs=[gen_name(72)]) + pattern_block_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(26.1)}, + outputs=[gen_name(73)]) + pattern_block_block_block_block.add_layer( + "prim.loop", + inputs={"input": gen_name(26.2)}, + outputs=[gen_name(74), gen_name(75), gen_name(76)]) + loop_layer = pattern_block_block_block_block.layers[list( + pattern_block_block_block_block.layers.keys())[-1]] + pattern_loop_block = PaddleGraph(loop_layer, graph_type="dygraph") + pattern_loop_block.add_layer( + "prim.getitem", + inputs={"list": gen_name(65), + "element": gen_name(73)}, + outputs=[gen_name(74.1)]) + pattern_loop_block.add_layer( + "prim.floor", + inputs={"input": gen_name(74.1)}, + outputs=[gen_name(75.1)]) + pattern_loop_block.add_layer( + "prim.ne", + inputs={"x": gen_name(75.1), + "y": gen_name(74.1)}, + outputs=[gen_name(76)]) + pattern_loop_block.add_layer( + "prim.if", + inputs={"input": gen_name(76)}, + outputs=[gen_name(77), gen_name(78)]) + if_layer13 = pattern_loop_block.layers[list( + pattern_loop_block.layers.keys())[-1]] + pattern_loop_block_block = PaddleGraph(if_layer13, graph_type="dygraph") + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(24)}, + outputs=[gen_name(77)]) + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(76)}, + outputs=[gen_name(78)]) + if_layer13.add_block(pattern_loop_block_block) + pattern_loop_block_block = PaddleGraph(if_layer13, graph_type="dygraph") + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(33)}, + outputs=[gen_name(77)]) + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(33)}, + outputs=[gen_name(78)]) + if_layer13.add_block(pattern_loop_block_block) + if_layer13.inputs.update({ + "input-0": gen_name(24), + "input-1": gen_name(76), + "input-2": gen_name(33), + "input-3": gen_name(33) + }) + pattern_loop_block.add_layer( + "prim.if", + inputs={"input": gen_name(76)}, + outputs=[gen_name(79), gen_name(80)]) + if_layer14 = pattern_loop_block.layers[list( + pattern_loop_block.layers.keys())[-1]] + pattern_loop_block_block = PaddleGraph(if_layer14, graph_type="dygraph") + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(77)}, + outputs=[gen_name(79)]) + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(78)}, + outputs=[gen_name(80)]) + if_layer14.add_block(pattern_loop_block_block) + pattern_loop_block_block = PaddleGraph(if_layer14, graph_type="dygraph") + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(22)}, + outputs=[gen_name(79)]) + pattern_loop_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(76)}, + outputs=[gen_name(80)]) + if_layer14.add_block(pattern_loop_block_block) + if_layer14.inputs.update({ + "input-0": gen_name(77), + "input-1": gen_name(78), + "input-2": gen_name(22), + "input-3": gen_name(76) + }) + pattern_loop_block.add_layer( + "prim.add", + inputs={"x": gen_name(73), + "y": gen_name(27)}, + outputs=[gen_name(81)]) + pattern_loop_block.add_layer( + "prim.lt", + inputs={"x": gen_name(81), + "y": gen_name(70)}, + outputs=[gen_name(82)]) + pattern_loop_block.add_layer( + "prim.and", + inputs={"x": gen_name(82), + "y": gen_name(79)}, + outputs=[gen_name(83)]) + pattern_loop_block.add_layer( + "prim.equal", + inputs={"input": gen_name(80)}, + outputs=[gen_name(74)]) + pattern_loop_block.add_layer( + "prim.equal", + inputs={"input": gen_name(81)}, + outputs=[gen_name(75)]) + loop_layer.add_block(pattern_loop_block) + loop_layer.inputs.update({ + "input-0": gen_name(65), + "input-1": gen_name(73), + "input-2": gen_name(24), + "input-3": gen_name(33), + "input-4": gen_name(33), + "input-5": gen_name(22), + "input-6": gen_name(73), + "input-7": gen_name(27), + "input-8": gen_name(70) + }) + pattern_block_block_block_block.add_layer( + "prim.if", inputs={"input": gen_name(74)}, outputs=[gen_name(84)]) + if_layer15 = pattern_block_block_block_block.layers[list( + pattern_block_block_block_block.layers.keys())[-1]] + pattern_block_block_block_block_block = PaddleGraph( + if_layer15, graph_type="dygraph") + pattern_block_block_block_block_block.add_layer( + "prim.warnings", + inputs={"input": gen_name(26)}, + outputs=[gen_name(85)], + stacklevel=2) + if_layer15.add_block(pattern_block_block_block_block_block) + pattern_block_block_block_block_block = PaddleGraph( + if_layer15, graph_type="dygraph") + if_layer15.add_block(pattern_block_block_block_block_block) + if_layer15.inputs["input-0"] = gen_name(26) + if_layer12.add_block(pattern_block_block_block_block) + pattern_block_block_block_block = PaddleGraph( + if_layer12, graph_type="dygraph") + if_layer12.add_block(pattern_block_block_block_block) + if_layer12.inputs.update({ + "input-0": gen_name(65), + "input-1": gen_name(26.1), + "input-2": gen_name(26.2), + "input-3": gen_name(65), + "input-4": gen_name(24), + "input-5": gen_name(33), + "input-6": gen_name(33), + "input-7": gen_name(22), + "input-8": gen_name(27), + "input-9": gen_name(26) + }) + pattern_block_block_block.add_layer( + "prim.list", inputs={}, outputs=[gen_name(86)]) + pattern_block_block_block.add_layer( + "prim.loop", + inputs={"input": gen_name(2)}, + outputs=[gen_name(87), gen_name(88)]) + loop_layer = pattern_block_block_block.layers[list( + pattern_block_block_block.layers.keys())[-1]] + pattern_loop_block = PaddleGraph(loop_layer, graph_type="dygraph") + pattern_loop_block.add_layer( + "prim.add", + inputs={"x": gen_name(88), + "y": gen_name(28)}, + outputs=[gen_name(89)]) + pattern_loop_block.add_layer( + "prim.shape", + inputs={"input": gen_name(34)}, + outputs=[gen_name(90)]) + pattern_loop_block.add_layer( + "prim.float", + inputs={"input": gen_name(90)}, + outputs=[gen_name(91)]) + pattern_loop_block.add_layer( + "prim.getitem", + inputs={"list": gen_name(65), + "element": gen_name(88)}, + outputs=[gen_name(92)]) + pattern_loop_block.add_layer( + "prim.mul", + inputs={"x": gen_name(91), + "y": gen_name(92)}, + outputs=[gen_name(93)]) + pattern_loop_block.add_layer( + "prim.floor", + inputs={"input": gen_name(93)}, + outputs=[gen_name(94)]) + pattern_loop_block.add_layer( + "prim.append", + inputs={"list": gen_name(86), + "element": gen_name(94)}, + outputs=[]) + loop_layer.add_block(pattern_loop_block) + loop_layer.inputs.update({ + "input-1": gen_name(28), + "input-2": gen_name(34), + "input-3": gen_name(65), + "input-5": gen_name(86) + }) + pattern_block_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(86)}, + outputs=[gen_name(63)]) + if_layer10.add_block(pattern_block_block_block) + if_layer10.inputs.update({ + "input-0": gen_name(61), + "input-1": gen_name(54), + "input-2": gen_name(25), + "input-3": gen_name(54), + "input-4": gen_name(23), + "input-5": gen_name(30), + "input-6": gen_name(37), + "input-7": gen_name(25), + "input-8": gen_name(26.1), + "input-9": gen_name(26.2), + "input-10": gen_name(24), + "input-11": gen_name(33), + "input-12": gen_name(33), + "input-13": gen_name(22), + "input-14": gen_name(27), + "input-15": gen_name(26), + "input-16": gen_name(2), + "input-17": gen_name(28), + "input-18": gen_name(34) + }) + pattern_block_block.add_layer( + "prim.getitem", + inputs={"list": gen_name(11), + "element": gen_name(5)}, + outputs=[gen_name(95)]) + pattern_block_block.add_layer( + "prim.getitem", + inputs={"list": gen_name(11), + "element": gen_name(4)}, + outputs=[gen_name(96)]) + pattern_block_block.add_layer( + "prim.assert", + inputs={"key": gen_name(95), + "value": gen_name(96)}, + outputs=[gen_name(97) + "_assert"], + type="eq") + pattern_block_block.add_layer( + "fluid.layers.interpolate", + inputs={ + "input": "interpolate-input-0", + "out_shape": gen_name(63), + "scale": gen_name(95) + }, + outputs=[gen_name(97)], + align_corners=False, + align_mode=0) + pattern_block_block.add_layer( + "prim.equal", + inputs={"input": gen_name(97)}, + outputs=[gen_name(20)]) + if_layer2.add_block(pattern_block_block) + pattern_block_block = PaddleGraph(if_layer2, graph_type="dygraph") + pattern_block_block.add_layer( + "prim.shape", + inputs={"input": "interpolate-input-0"}, + outputs=[gen_name(98)]) + pattern_block_block.add_layer( + "prim.len", inputs={"input": gen_name(98)}, outputs=[gen_name(98)]) + pattern_block_block.add_layer( + "prim.eq", + inputs={"x": gen_name(98), + "y": gen_name(7)}, + outputs=[gen_name(99)]) + pattern_block_block.add_layer( + "prim.if", inputs={"input": gen_name(99)}, outputs=[gen_name(100)]) + if_layer16 = pattern_block_block.layers[list( + pattern_block_block.layers.keys())[-1]] + pattern_block_block_block = PaddleGraph( + if_layer16, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.exception", + inputs={"input": gen_name(0)}, + outputs=[gen_name(101)]) + if_layer16.add_block(pattern_block_block_block) + pattern_block_block_block = PaddleGraph( + if_layer16, graph_type="dygraph") + pattern_block_block_block.add_layer( + "prim.exception", + inputs={"input": gen_name(0)}, + outputs=[gen_name(102)]) + if_layer16.add_block(pattern_block_block_block) + if_layer16.inputs.update({ + "input-0": gen_name(0), + "input-1": gen_name(0) + }) + pattern_block_block.add_layer( + "prim.equal", inputs={"input": gen_name(8)}, + outputs=[gen_name(20)]) + if_layer2.add_block(pattern_block_block) + if_layer2.inputs.update({ + "input-0": gen_name(13), + "input-1": gen_name(2), + "input-2": gen_name(2), + "input-3": gen_name(11), + "input-4": gen_name(5), + "input-5": gen_name(11), + "input-6": gen_name(4), + "input-7": "interpolate-input-0", + "input-8": "interpolate-input-0", + "input-9": gen_name(7), + "input-10": gen_name(0), + "input-11": gen_name(0), + "input-12": gen_name(8) + }) + pattern_block.add_layer( + "prim.equal", + inputs={"input": gen_name(20)}, + outputs=[gen_name(16)]) + if_layer1.add_block(pattern_block) + if_layer1.inputs.update({ + "input-0": gen_name(0), + "input-1": gen_name(8), + "input-2": "interpolate-input-0", + "input-3": gen_name(6), + "input-4": gen_name(13), + "input-5": gen_name(2), + "input-6": gen_name(2), + "input-7": gen_name(11), + "input-8": gen_name(5), + "input-9": gen_name(11), + "input-10": gen_name(4), + "input-11": "interpolate-input-0", + "input-12": "interpolate-input-0", + "input-13": gen_name(7), + "input-14": gen_name(0), + "input-15": gen_name(0), + "input-16": gen_name(8) + }) + self.pattern.build(inputs={ + "input-0": "interpolate-input-0", + "input-1": "interpolate-input-1", + "input-2": "interpolate-input-2", + }) + + def insert_new_layer(self, graph, parameters, matches): + new_layer = self.gen_new_layer(parameters, matches) + new_layer_id = list(matches.keys())[0] + graph.layers[new_layer_id] = new_layer + matches.pop(new_layer_id) + + def gen_new_layer(self, parameters, matches): + layers_id = list(matches.keys()) + layer = matches[layers_id[15]] + out_shape = layer.inputs["input1"] + layer = matches[layers_id[21]] + outputs = layer.outputs + layer = matches[layers_id[128]] + layer.inputs.pop("scale") + layer.inputs["out_shape"] = out_shape + layer.outputs = outputs + return layer diff --git a/x2paddle/optimizer/optimizer.py b/x2paddle/optimizer/optimizer.py index b6211d980b84ef0008f6d0dc8af505ce459d591a..6a69ec66445e401aa40222ff2f644f42e2aa54bd 100644 --- a/x2paddle/optimizer/optimizer.py +++ b/x2paddle/optimizer/optimizer.py @@ -19,8 +19,9 @@ from x2paddle.optimizer.pass_manager import PassManager class GraphOptimizer(object): def __init__(self): self.passes = [ - "fc_fuse_pass", "adaptive_pool2d_fuse_pass", - "batchnorm2d_fuse_pass", "constant_fuse_pass" + "interpolate_bilinear_fuse_pass", "fc_fuse_pass", + "adaptive_pool2d_fuse_pass", "batchnorm2d_fuse_pass", + "constant_fuse_pass" ] def optimize(self, graph):