未验证 提交 69a8316b 编写于 作者: S SunAhong1993 提交者: GitHub

Merge pull request #17 from PaddlePaddle/develop

add
......@@ -278,6 +278,12 @@ class PaddleGraph(object):
hierarchical_tree.insert(layer)
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
if self.source_type == "pytorch":
from x2paddle.optimizer.code_optimizer import ModuleGraph
module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
......@@ -621,7 +627,7 @@ class PaddleGraph(object):
layer.outputs[0])], indent=indent))
else:
self.forward_func.extend(gen_codes([line], indent=indent))
if indent == 2:
if indent == 2 and code_dir is not None:
gen_main_code(code_dir)
write_code(code_dir)
else:
......
......@@ -205,7 +205,7 @@ class ONNXGraph(Graph):
shape = raw_input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
except:
except NameError:
shape = input(
"Shape of Input(e.g. -1,3,224,224), enter 'N' to skip: "
)
......@@ -302,6 +302,17 @@ class ONNXGraph(Graph):
if opt == in_node:
self.connect(nd.name, layer_name)
flag = 1
if nd.name in node.which_child:
for n_i, n_ipt in enumerate(node.inputs):
if first_i == n_i:
continue
if n_ipt == nd.name:
new_nd_name = "{}/{}".format(nd.name, n_i)
if new_nd_name not in node.which_child:
node.which_child[new_nd_name] = idx
break
else:
first_i = node.inputs.index(nd.name)
node.which_child[nd.name] = idx
self.node_map[nd.name].index = 0
break
......@@ -318,11 +329,15 @@ class ONNXGraph(Graph):
if len(node.which_child) == 0:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
return ipt_node
else:
ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
new_ipt_name = "{}/{}".format(ipt_node.layer_name, idx)
if new_ipt_name in node.which_child:
ipt_node.index = node.which_child[new_ipt_name]
else:
if ipt_node.layer_name in node.which_child:
ipt_node.index = node.which_child[ipt_node.layer_name]
return ipt_node
......
......@@ -250,15 +250,22 @@ class OpSet9():
def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'x': val_x.name}
attrs = dict()
if node.layer_type == 'Resize':
if len(node.layer.input) == 2:
# opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale_factor'] = val_scales.name
# TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 3:
# opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale_factor'] = val_scales.name
# TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 4:
# opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
......@@ -281,7 +288,7 @@ class OpSet9():
ipt = inputs.pop("x")
inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False}
attrs.update({"align_corners": False})
self.paddle_graph.add_layer(
kernel="fluid.layers.resize_nearest",
inputs=inputs,
......@@ -290,12 +297,12 @@ class OpSet9():
return
elif node.layer_type == 'Upsample':
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale'] = val_scales
inputs['scale_factor'] = val_scales
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False,
attrs.update({"align_corners": False,
"mode": string(mode),
"align_mode": 1}
"align_mode": 1})
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
......@@ -926,16 +933,17 @@ class OpSet9():
'max': max_value,
'min': min_value,
}
self.paddle_graph.add_layer(
'paddle.clip',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
min_value = _const_weight_or_none(min_ipt)
max_value = _const_weight_or_none(max_ipt)
if max_value.shape == (1, ):
max_value = max_value[0]
if min_value.shape == (1, ):
......@@ -1637,3 +1645,16 @@ class OpSet9():
inputs=inputs_dict,
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def ArgMax(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis')
keepdims = False if node.get_attr('keepdims') == 0 else True
layer_attrs = {'axis': axis,
'keepdim': keepdims}
self.paddle_graph.add_layer(
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
......@@ -663,6 +663,37 @@ def aten_batch_norm(mapper, graph, node):
return current_inputs, current_outputs
def aten_bmm(mapper, graph, node):
""" 构造矩阵相乘的PaddleLayer。
TorchScript示例:
%x.222 : Tensor = aten::bmm(%32, %7)
参数含义:
%x.222 (Tensor): 输出,矩阵相乘后的结果。
%i.12 (list): 输入1。
%7 (int): 输入2。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%i.12
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%288
mapper._check_input(
graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True)
layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.bmm", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs
def aten_cat(mapper, graph, node):
""" 构造连接Tensor的PaddleLayer。
......@@ -885,7 +916,7 @@ def aten_constant_pad_nd(mapper, graph, node):
outputs=[inputs_name[0] + "_if", output_name],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.sub",
inputs={"y": inputs_name[0] + "_len"},
......@@ -916,7 +947,7 @@ def aten_constant_pad_nd(mapper, graph, node):
outputs=[output_name],
scope_name=scope_name)
if_layer.add_block(block)
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
layer_inputs["input"] = inputs_name[0]
block.add_layer(
kernel, inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs)
......@@ -1525,6 +1556,32 @@ def aten_eq(mapper, graph, node):
return current_inputs, current_outputs
def aten_erf(mapper, graph, node):
""" 构造逐元素计算 Erf 激活函数的PaddleLayer。
TorchScript示例:
%94 : Tensor = aten::erf(%sinusoid_inp.1)
参数含义:
%94 (Tensor): 输出,erf之后的结果。
%sinusoid_inp.1 (Tensor): 需要进行erf的Tensor。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%sinusoid_inp.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.erf", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs
def aten_exp(mapper, graph, node):
""" 构造以自然数e为底指数运算的PaddleLayer。
......@@ -1639,7 +1696,7 @@ def aten_expand_as(mapper, graph, node):
outputs=[inputs_name[0] + "_if1"],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.type",
inputs={"input": inputs_name[1]},
......@@ -1652,7 +1709,7 @@ def aten_expand_as(mapper, graph, node):
scope_name=scope_name,
dtype=inputs_name[1] + "_type")
if_layer.add_block(block)
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[0]
if_layer.inputs["input-1"] = inputs_name[1]
......@@ -1663,7 +1720,7 @@ def aten_expand_as(mapper, graph, node):
outputs=[inputs_name[0] + "_if2"],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"fluid.layers.cast",
inputs={"x": layer_outputs[0]},
......@@ -1671,7 +1728,7 @@ def aten_expand_as(mapper, graph, node):
scope_name=scope_name,
dtype=string("bool"))
if_layer.add_block(block)
block = PaddleGraph(if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block)
if_layer.inputs["input-0"] = layer_outputs[0]
# TODO(syf): check expand_as
......@@ -1868,10 +1925,10 @@ def aten_floor(mapper, graph, node):
outputs=[inputs_name[0] + "_if"],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer("paddle.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name)
if_layer.add_block(block)
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer("prim.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name)
if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[0]
......@@ -2569,14 +2626,14 @@ def aten_masked_fill_(mapper, graph, node):
outputs=[inputs_name[2] + "_if"],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.equal",
inputs={"input": inputs_name[1] + "_mask"},
outputs=[inputs_name[2] + "_1"],
scope_name=scope_name)
if_layer.add_block(block)
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.mul",
inputs={"x": inputs_name[1] + "_mask",
......@@ -2677,14 +2734,14 @@ def aten_masked_fill(mapper, graph, node):
outputs=[inputs_name[2] + "_if"],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.equal",
inputs={"input": inputs_name[1] + "_mask"},
outputs=[inputs_name[2] + "_1"],
scope_name=scope_name)
if_layer.add_block(block)
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.mul",
inputs={"x": inputs_name[1] + "_mask",
......@@ -3986,16 +4043,18 @@ def aten_sub(mapper, graph, node):
""" 构造数值相减的PaddleLayer。
TorchScript示例:
%840 : int = aten::sub(%839, %836)
%840 : int = aten::sub(%839, %836, %3)
参数含义:
%840 (-): 相减结果。
%839 (-): 输入数值 x。
%836 (-): 输入数值 y。
%3 (-): alpha。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
......@@ -4006,13 +4065,37 @@ def aten_sub(mapper, graph, node):
mapper._check_input(
graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True)
layer_inputs["y"] = inputs_name[1]
# 处理输入2,即%3
if len(inputs_node) > 2:
if inputs_name[2] in mapper.attrs:
layer_attrs["alpha"] = mapper.attrs[inputs_name[2]]
else:
mapper._check_input(graph, inputs_node[2], inputs_name[2],
current_outputs, scope_name)
layer_inputs["alpha"] = inputs_name[2]
current_inputs.append(inputs_name[2])
else:
layer_attrs["alpha"] = 1.0
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer("prim.sub", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
graph.add_layer("prim.sub", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs)
return current_inputs, current_outputs
def aten_sub_(mapper, graph, node):
""" 构造数值相减的PaddleLayer。
TorchScript示例:
%840 : int = aten::sub_(%839, %836, %3)
参数含义:
%840 (-): 相减结果。
%839 (-): 输入数值 x。
%836 (-): 输入数值 y。
%3 (-): alpha。
"""
return aten_sub(mapper, graph, node)
def aten_t(mapper, graph, node):
""" 构造矩阵转置的PaddleLayer。
......@@ -4366,14 +4449,14 @@ def aten_upsample_bilinear2d(mapper, graph, node):
outputs=[inputs_name[0] + "_if1"],
scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.var2list",
inputs={"input": inputs_name[1]},
outputs=[inputs_name[1]],
scope_name=scope_name)
if_layer.add_block(block)
block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[1]
# 处理输入2,即%5421
......
......@@ -67,9 +67,11 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func.extend(gen_codes([line], indent=indent))
def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} and {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -91,15 +93,22 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
s += "{} == {} or ".format(get_value(layer, "key"), v)
if len(s) > 0:
s = s[:-4]
lc=locals()
exec("assert_result = {}".format(s))
assert_result = lc['assert_result']
line = "assert {}, \'The {} must be {}!\'".format(
s, get_value(layer, "key"), get_value(layer, "value"))
else:
line = "assert {} == {}, \'The {} must be {}!\'".format(
get_value(layer, "key"),
get_value(layer, "value"),
get_value(layer, "key"), get_value(layer, "value"))
s = "{} == {}".format(get_value(layer, "key"),
get_value(layer, "value"))
lc=locals()
exec("assert_result = {}".format(s))
assert_result = lc['assert_result']
line = "assert {}, \'The {} must be {}!\'".format(
s, get_value(layer, "key"), get_value(layer, "value"))
else:
raise Exception("Not implement yet!")
if not assert_result:
forward_func.extend(gen_codes([line], indent=indent))
......@@ -119,10 +128,12 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
forward_func.extend(gen_codes([line], indent=indent))
def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} in {}".format(layer.outputs[0],
get_value(layer, "element", different_attrs),
get_value(layer, "input", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -156,10 +167,12 @@ def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
forward_func.extend(gen_codes([line], indent=indent))
def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None,is_return_line=False):
line = "{} = {} == {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -198,14 +211,37 @@ def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
forward_func.extend(gen_codes([line], indent=indent))
def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} > {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
try:
exec_s = None
for line in forward_func:
s = line.replace(" ", "")
if s.startswith("{} = ".format(get_value(layer, "input", different_attrs))):
exec_s = s.split(" = ")[1]
lc=locals()
if exec_s is not None:
exec("if_result = {}".format(exec_s))
else:
exec("if_result = {}".format(get_value(layer, "input", different_attrs)))
if_result = lc['if_result']
if if_result:
block = layer.blocks[0]
else:
block = layer.blocks[1]
if len(block.layers) > 0:
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent)
init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines)
except:
line = "if {} :".format(get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0]
......@@ -232,31 +268,39 @@ def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
forward_func.extend(gen_codes([line], indent=indent))
def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} is {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = isinstance({}, {})".format(layer.outputs[0],
get_value(layer, "input", different_attrs),
layer.attrs["cls"])
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} is not {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} <= {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -273,10 +317,12 @@ def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
forward_func.extend(gen_codes(lines, indent=indent))
def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} < {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -317,10 +363,12 @@ def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
forward_func.extend(gen_codes([line], indent=indent))
def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} != {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -329,15 +377,19 @@ def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
forward_func.extend(gen_codes([line], indent=indent))
def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {} or {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent))
......@@ -419,9 +471,15 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
if int(get_value(layer, "alpha", different_attrs)) == 1:
line = "{} = {} - {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
else:
line = "{} = {} - {} * {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "alpha", different_attrs),
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
......
......@@ -84,7 +84,7 @@ class PyTorchOpMapper(OpMapper):
current_node_outputs.extend(outputs)
# 初始化
graph = PaddleGraph(parent_layer=parent_layer, graph_type="dygraph")
graph = PaddleGraph(source_type="pytorch", parent_layer=parent_layer, graph_type="dygraph")
if "TopLevelTracedModule" in str(type(self.script)):
graph.set_script(self.script)
current_node_outputs = []
......
......@@ -240,15 +240,22 @@ class OpSet9():
def _interpolate(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
inputs = {'x': val_x.name}
attrs = dict()
if node.layer_type == 'Resize':
if len(node.layer.input) == 2:
# opset 10
val_scales = self.graph.get_input_node(node, idx=1, copy=True)
inputs['scale_factor'] = val_scales.name
# TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 3:
# opset 11
val_scales = self.graph.get_input_node(node, idx=2, copy=True)
inputs['scale_factor'] = val_scales.name
# TODO(syf): paddle.nn.functional.interpolate will support the length
# which is the same as the rank of input.
# inputs['scale_factor'] = val_scales.name
attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
elif len(node.layer.input) == 4:
# opset 11
val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
......@@ -271,7 +278,7 @@ class OpSet9():
ipt = inputs.pop("x")
inputs["input"] = ipt
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False}
attrs.update({"align_corners": False})
self.paddle_graph.add_layer(
kernel="fluid.layers.resize_nearest",
inputs=inputs,
......@@ -283,9 +290,9 @@ class OpSet9():
inputs['scale'] = val_scales
mode = node.get_attr('mode', 'nearest')
attrs = {"align_corners": False,
attrs.update({"align_corners": False,
"mode": string(mode),
"align_mode": 1}
"align_mode": 1})
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
......@@ -917,10 +924,10 @@ class OpSet9():
outputs=[node.name],
**layer_attrs)
else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
min_value = _const_weight_or_none(min_ipt)
max_value = _const_weight_or_none(max_ipt)
if max_value.shape == (1, ):
max_value = max_value[0]
if min_value.shape == (1, ):
......@@ -1577,3 +1584,16 @@ class OpSet9():
inputs=layer_inputs,
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def ArgMax(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = node.get_attr('axis')
keepdims = False if node.get_attr('keepdims') == 0 else True
layer_attrs = {'axis': axis,
'keepdim': keepdims}
self.paddle_graph.add_layer(
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
\ No newline at end of file
......@@ -14,3 +14,4 @@
from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree
from x2paddle.optimizer.code_optimizer.module_graph import ModuleGraph
\ No newline at end of file
......@@ -38,7 +38,7 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.LeakyReLU": "leakly_relu"}
NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:6]
def rename_layers(layers, param_tree=None):
def rename_layers(layers, param_tree=None, is_rename_module=False):
""" 对子模块的输入输出等进行重命名。
"""
layers_cp = copy.deepcopy(layers)
......@@ -84,6 +84,12 @@ def rename_layers(layers, param_tree=None):
layer.outputs[0] = new_name
nn_count_dict[layer.kernel] += 1
elif i == 0 and layer.kernel == "module":
if is_rename_module:
if param_tree is not None:
param_node = param_tree.get_node(layer.outputs[0])
nn_param_nodes.append(param_node)
param_node.new_name = layer.outputs[0]
else:
old_name = layer.outputs[0].split("/")[0]
if old_name not in nn_count_dict:
nn_count_dict[old_name] = 0
......@@ -196,6 +202,15 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
outputs.append(output_name)
else:
outputs.append(output_name)
if layer.kernel == "prim.dict":
is_set_item = True
for out_layer_id in graph.edges_out[layer_id]:
out_layer = sub_layers[out_layer_id]
if out_layer.kernel != "prim.set_item":
is_set_item = False
break
if is_set_item:
outputs.append(layer.outputs[0])
no_output_count = 0
for i, (layer_id, layer) in enumerate(sub_layers.items()):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
......
# -*- coding:UTF-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import os.path as osp
from x2paddle.core.program import PaddleLayer
from x2paddle.optimizer.code_optimizer.subgraphs_union import construct_attrs_table, get_inputs_outputs
from x2paddle.optimizer.code_optimizer.layer_code_generator import gen_layer_code, rename_layers
from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode, PamareterTree
NoModuleStart = ["paddle.nn.ReLU"]
class Apriori(object):
""" 使用Apriori算法挖掘频繁子图
1. 构建频繁1项集
2. 挖掘频繁k项集
3. 最终k项集和节点数满足最少节点数的子图组成集合GS
Args:
min_support (int): 子图出现次数的最小值。
"""
def __init__(self, min_support):
self.min_support = min_support
def is_match(self, item, sublayers):
for i in range(len(item)):
if len(sublayers) <= i or item[i] != sublayers[i].kernel:
return False
return True
def create_C1(self):
# 构建候选1-项集
C1 = list()
for layer_id, layer in self.layers.items():
if layer.kernel == "paddle.to_tensor" or \
layer.kernel == "prim.if" or \
layer.kernel == "prim.loop": #or \
# layer.kernel == "prim.list" or \
# layer.kernel == "prim.tuple" or \
# layer.kernel == "prim.dict_construct":
continue
if self.pd_graph.edges_in.get(layer_id, 0) == 0 and \
self.pd_graph.edges_out.get(layer_id, 0) == 0:
continue
if [layer.kernel] not in C1:
C1.append([layer.kernel])
return C1
def create_Ck(self, Lk_last, C1):
# 构建候选k-项集
Ck = list()
for item in Lk_last:
for item_one in C1:
new_item = copy.deepcopy(item)
new_item.append(item_one[0])
if new_item[0] in NoModuleStart:
continue
Ck.append(new_item)
return Ck
def generate_Lk_by_Ck(self, Ck):
# 生成频繁k-项集
Lk = list()
for item in Ck:
count = 0
for i in range(len(self.layers)):
sublayers = list(self.layers.values())[i:]
if self.is_match(item, sublayers):
count += 1
if count >= self.min_support:
Lk.append(item)
return Lk
def run(self, graph):
self.pd_graph = graph
self.layers = graph.layers
itemset = list()
C1 = self.create_C1()
L1 = self.generate_Lk_by_Ck(C1)
Lk = L1
while len(Lk) > 0:
Ck = self.create_Ck(Lk, C1)
Lk = self.generate_Lk_by_Ck(Ck)
itemset.extend(Lk)
return itemset
class DP(object):
""" 使用动动态规划找到使代码最短的组合方式。
"""
def __init__(self, combination_itemset):
self.combination_itemset = combination_itemset
def get_combination_id(self, combination, layers):
combination_id = list()
for layer_obj in combination:
if len(layer_obj) > 1:
kernel_itemset = list()
for layer_id in layer_obj:
kernel_itemset.append(layers[layer_id].kernel)
id = self.combination_itemset.index(kernel_itemset)
combination_id.append(id)
else:
combination_id.append(-1)
return combination_id
def run(self, graph):
layers = graph.layers
layer_combination_list = list()
for i, (layer_id, layer) in enumerate(layers.items()):
if i == 0:
layer_combination_list.append([[layer_id]])
continue
current_itemset = [layer_id]
kernel_itemset = [layer.kernel]
candidate_itemset = list()
min_count = len(layers)
prefix_ids = list(range(i))
prefix_ids.reverse()
for j in prefix_ids:
current_layer_id = list(layers.keys())[j]
current_layer = list(layers.values())[j]
current_itemset.insert(0, current_layer_id)
kernel_itemset.insert(0, current_layer.kernel)
if kernel_itemset in self.combination_itemset:
current_count = len(layer_combination_list[j - 1])
all_count = current_count + 1
if all_count < min_count:
min_count = all_count
candidate_itemset = copy.deepcopy(current_itemset)
if j - 1 < 0:
last_itemset = list()
else:
last_itemset = copy.deepcopy(layer_combination_list[j - 1])
else:
if j == prefix_ids[0]:
min_count = len(layer_combination_list[j]) + 1
current_itemset.pop(0)
candidate_itemset = copy.deepcopy(current_itemset)
last_itemset = copy.deepcopy(layer_combination_list[j])
break
last_itemset.append(candidate_itemset)
layer_combination_list.append(last_itemset)
final_combination = layer_combination_list[-1]
combination_id = self.get_combination_id(final_combination, layers)
return final_combination, combination_id
class ModuleGraph(object):
""" 更新PaddleGraph,生成代码。
"""
def __init__(self, graph):
self.pd_graph = graph
self.global_layers = graph.get_global_layers()
self.codes = list()
self.param_tree = PamareterTree()
def get_updation_information(self):
aprior = Apriori(3)
combination_itemset = aprior.run(self.pd_graph)
dp = DP(combination_itemset)
combination, combination_id = dp.run(self.pd_graph)
return combination, combination_id
def analyze_attrs_table(self, attrs_table):
""" 分析属性表格,哪些属性取值不一致。
"""
diff_attrs_column = list()
for column in list(attrs_table.columns):
elements = list(attrs_table.get(column))
base = elements[0]
for element in elements[1:]:
if isinstance(base, str) and "'" not in base:
break
if element != base:
diff_attrs_column.append(column)
break
return diff_attrs_column
def analyze_graph(self, sub_layers_list):
def is_same(sub_layers1, sub_layers2, id1, id2):
inputs1, outputs1 = ipt_opt_list[id1]
inputs2, outputs2 = ipt_opt_list[id2]
if len(inputs1) != len(inputs2) or len(outputs1) != len(outputs2):
return False
layer_id_list1 = list(sub_layers1.keys())
layer_id_list2 = list(sub_layers2.keys())
for i, layer_id1 in enumerate(layer_id_list1):
layer_id2 = layer_id_list2[i]
if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]):
return False
for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]):
ipt_layer_id2 = self.pd_graph.edges_in[layer_id2][j]
if (ipt_layer_id1 in layer_id_list1) ^ (ipt_layer_id2 in layer_id_list2):
return False
if (layer_id1 in self.pd_graph.edges_out) ^ (layer_id2 in self.pd_graph.edges_out):
return False
if (layer_id1 in self.pd_graph.edges_out) and (layer_id2 in self.pd_graph.edges_out):
if (len(self.pd_graph.edges_out[layer_id1]) > 1 and len(self.pd_graph.edges_out[layer_id2]) == 1) or \
(len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) > 1):
return False
for j, opt_layer_id1 in enumerate(self.pd_graph.edges_out[layer_id1]):
if len(self.pd_graph.edges_out[layer_id1]) == 1 and len(self.pd_graph.edges_out[layer_id2]) == 1:
opt_layer_id2 = self.pd_graph.edges_out[layer_id2][j]
if (opt_layer_id1 in layer_id_list1) ^ (opt_layer_id2 in layer_id_list2):
return False
return True
sub_layers_list_list = list()
id_list = list()
ipt_opt_list = list()
sub_layers_list_list.append([sub_layers_list[0]])
id_list.append(0)
for i, sub_layer in enumerate(sub_layers_list):
ipt_opt_list.append(get_inputs_outputs(self.pd_graph, sub_layer))
if i == 0:
continue
id_list_cp = copy.deepcopy(id_list)
for j, index in enumerate(id_list_cp):
if is_same(sub_layers_list[index], sub_layer, index, i):
sub_layers_list_list[j].append(sub_layer)
break
if j == len(id_list_cp) - 1:
sub_layers_list_list.append(list())
sub_layers_list_list[j + 1].append(sub_layer)
id_list.append(i)
return sub_layers_list_list
def merge_node(self, sub_layers_list, attrs_table, module_name):
sub_layers = sub_layers_list[0]
diff_attrs_column = self.analyze_attrs_table(attrs_table)
sub_layers, _, _ = rename_layers(sub_layers)
code_str = gen_layer_code(self.pd_graph,
sub_layers,
module_name,
different_attrs=diff_attrs_column)
self.codes.append(code_str)
for index, sub_layers in enumerate(sub_layers_list):
inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
inputs_dict = dict()
for i, input in enumerate(inputs):
inputs_dict["input_{}".format(i)] = input
mn = module_name.lower()
outputs = ["{}_{}".format(mn, index)] + outputs
node_name = "{}_{}".format(module_name, index)
diff_attrs = dict()
for column in diff_attrs_column:
diff_attrs[column] = attrs_table.get(column).loc[node_name]
new_layer = PaddleLayer(id=list(sub_layers.keys())[-1],
kernel="module",
inputs=inputs_dict,
outputs=outputs,
module=module_name,
**diff_attrs)
_, nn_param_nodes, _ = rename_layers(sub_layers, self.param_tree)
param_node = PamareterNode(old_name=outputs[0])
for node in nn_param_nodes:
param_node.add_child(node)
self.param_tree.add_node(param_node)
for i, (layer_id, layer) in enumerate(sub_layers.items()):
if i == len(sub_layers) - 1:
self.pd_graph.layers[layer_id] = new_layer
else:
if len(layer_id.split(".")) > 1:
continue
self.pd_graph.layers.pop(layer_id)
self.pd_graph.build()
def convert_subgraph_to_layer(self, combination, combination_id):
combination_id_set = set(combination_id)
for s in list(combination_id_set):
if s == -1:
continue
module_name = "Block{}".format(s)
sub_layers_list = list()
for i, c in enumerate(combination):
if len(c) > 1 and combination_id[i] == s:
sub_layers = dict()
for layer_id in c:
sub_layers[layer_id] = self.global_layers[layer_id]
sub_layers_list.append(sub_layers)
sub_layers_list_list = self.analyze_graph(sub_layers_list)
for i, sub_layers_list in enumerate(sub_layers_list_list):
if i == 0:
real_module_name = module_name
else:
real_module_name = module_name + "__{}".format(i)
if len(sub_layers_list) > 1:
attrs_table = construct_attrs_table(sub_layers_list, module_name=real_module_name)
self.merge_node(sub_layers_list, attrs_table, real_module_name)
layers, nn_param_nodes, _ = rename_layers(self.pd_graph.layers, self.param_tree, is_rename_module=True)
code_str = gen_layer_code(self.pd_graph,
layers,
self.pd_graph.name)
self.codes.append(code_str)
param_node = PamareterNode(old_name="Module")
for node in nn_param_nodes:
param_node.add_child(node)
self.param_tree.add_node(param_node)
def update_parameters(self):
""" 更新参数。
"""
self.param_tree.traverse()
full_old_name_list = copy.deepcopy(list(self.pd_graph.parameters.keys()))
for old_name, new_name in self.param_tree.old2new.items():
for full_old_name in full_old_name_list:
if full_old_name.startswith("{}.".format(old_name)):
full_new_name = full_old_name.replace("{}.".format(old_name), "{}.".format(new_name))
params = self.pd_graph.parameters.pop(full_old_name)
self.pd_graph.parameters[full_new_name] = params
if full_old_name == old_name:
full_new_name = full_old_name.replace(old_name, new_name)
params = self.pd_graph.parameters.pop(full_old_name)
self.pd_graph.parameters[full_new_name] = params
def save_source_files(self, save_dir):
def gen_main_code():
input_data_name = ', '.join(self.pd_graph.inputs)
run_func_list = list()
run_func_list.append("def main({}):".format(input_data_name))
run_func_list.append(" # 共{}个输入".format(len(self.pd_graph.inputs_info)))
for k, v in self.pd_graph.inputs_info.items():
run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1]))
run_func_list.extend(
[" paddle.disable_static()",
" params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)),
" model = {}()".format(self.pd_graph.name),
" model.set_dict(params)",
" model.eval()",
" out = model({})".format(input_data_name),
" return out"])
return "\n".join(run_func_list)
combination, combination_id = self.get_updation_information()
self.convert_subgraph_to_layer(combination, combination_id)
self.update_parameters()
import_list = ["import paddle",
"import paddle.fluid as fluid",
"from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"import math",
"from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn"
"\n",]
import_str = "\n".join(import_list)
if not osp.exists(save_dir):
os.makedirs(save_dir)
f = open(osp.join(save_dir, 'x2paddle_code.py'), 'w')
f.write(import_str)
for code in self.codes:
f.write(code)
f.write("\n")
run_func = gen_main_code()
f.write(run_func)
f.close()
\ No newline at end of file
......@@ -19,7 +19,7 @@ import pandas as pd
from x2paddle.optimizer.code_optimizer.layer_code_generator import rename_layers
def construct_attrs_table(sub_layers_list, node_name2sub_layers):
def construct_attrs_table(sub_layers_list, node_name2sub_layers=None, module_name=None):
""" 构造不同属性的表格。
"""
def get_node_name(sub_layers):
......@@ -32,9 +32,12 @@ def construct_attrs_table(sub_layers_list, node_name2sub_layers):
_, _, new_names = rename_layers(sub_layers)
table = list()
node_names = list()
for sub_layers in sub_layers_list:
for i, sub_layers in enumerate(sub_layers_list):
attrs = dict()
if node_name2sub_layers is not None:
node_names.append(get_node_name(sub_layers))
else:
node_names.append("{}_{}".format(module_name, i))
for i, (layer_id, layer) in enumerate(sub_layers.items()):
for k, v in layer.attrs.items():
attrs[new_names[i] + "_{}".format(k)] = v
......
......@@ -26,6 +26,8 @@ from .dropout_fuser import DygraphDropoutFuser
from .dropout_fuse_pass import DygraphDropoutFusePass
from .fc_fuser import DygraphFcFuser
from .fc_fuse_pass import DygraphFcFusePass
from .if_fuser import DygraphIfFuser
from .if_fuse_pass import DygraphIfFusePass
from .interpolate_bilinear_fuser import DygraphInterpolateBilinearFuser
from .interpolate_bilinear_fuse_pass import DygraphInterpolateBilinearFusePass
from .prelu_fuser import DygraphPReLUFuser
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import DygraphIfFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class DygraphIfFusePass(Pass):
name = "dygraph_if_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = DygraphIfFuser()
fuser.operate(graph, match_kind="op")
# 用于注册
if_fuse_pass = DygraphIfFuser()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from x2paddle.optimizer.pattern_matcher import FuseBase
from x2paddle.core.program import PaddleGraph, PaddleLayer
from x2paddle.core.util import *
class DygraphIfFuser(FuseBase):
def __init__(self):
super(DygraphIfFuser, self).__init__(graph_type="dygraph")
def build_pattern(self):
""" 描述需要替换的if图结构。
if层模式python实现代码示例:
x81 = 'relu' in {'layer4': 'out', 'layer3': 'aux'}
if x81 :
...
"""
self.pattern.add_layer(
"prim.if", inputs={"input": "if-input-0"}, outputs=["x0"])
self.pattern.build(inputs={"input-0": "if-input-0"})
def insert_new_layer(self, graph, parameters, matches):
layer_id = list(matches.keys())[0]
layer = list(matches.values())[0]
if "input" not in layer.inputs:
matches.pop(layer_id)
return
for id in graph.edges_in[layer_id]:
input_layer = graph.layers[id]
if input_layer.outputs == [layer.inputs["input"]]:
if input_layer.kernel == "prim.if":
matches.pop(layer_id)
return
input_id = id
break
func_name = input_layer.kernel.replace(".", "_")
from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code
func = getattr(prim2code, func_name)
line = func(input_layer, is_return_line=True)
layer.attrs["input"] = line
layer.inputs.pop("input")
matches.pop(layer_id)
if len(input_layer.outputs) == 1:
matches[input_id] = input_layer
\ No newline at end of file
......@@ -31,7 +31,8 @@ class GraphOptimizer(object):
"dygraph_fc_fuse_pass",
"dygraph_adaptive_pool2d_fuse_pass",
"dygraph_reshape_fuse_pass",
"dygraph_dropout_fuse_pass"
"dygraph_dropout_fuse_pass",
"dygraph_if_fuse_pass"
]
elif source_frame == "caffe":
if paddle_type == "dygraph":
......
......@@ -28,6 +28,8 @@ class PatternMatcher(object):
self.detect_patterns_by_topo(graph)
elif match_kind == "edge":
self.detect_patterns_by_edge(graph)
elif match_kind == "op":
self.detect_patterns_by_op(graph)
self.remove_overlapped_match()
return self.matches
......@@ -229,6 +231,42 @@ class PatternMatcher(object):
if len(block.layers) > 0:
self.detect_patterns_by_edge(layer.blocks[j])
def detect_patterns_by_op(self, graph):
""" 当只匹配op时使用此方式。
"""
def get_subgraph(pattern, graph, start_index):
pattern_id2layers = pattern.get_global_layers()
pattern_ids = list(pattern_id2layers.keys())
pattern_layer_id = pattern_ids[0]
subgraph_id2layers = dict()
layer_id = list(graph.layers.keys())[start_index]
graph_layers = graph.layers
def update(layer_id, pattern_layer_id):
layer = graph_layers[layer_id]
pattern_layer = pattern_id2layers[pattern_layer_id]
if layer.kernel != pattern_layer.kernel:
return False
subgraph_id2layers[layer_id] = layer
while len(subgraph_id2layers) != len(pattern_id2layers):
out = update(layer_id, pattern_layer_id)
if out == False:
return False
else:
if len(subgraph_id2layers) == len(pattern_id2layers):
return subgraph_id2layers
else:
return False
for i, (layer_id, layer) in enumerate(graph.layers.items()):
match_info = get_subgraph(self.pattern, graph, i)
if match_info:
self.matches.append(match_info)
for j, block in enumerate(layer.blocks):
if len(block.layers) > 0:
self.detect_patterns_by_op(layer.blocks[j])
def remove_overlapped_match(self):
""" 如果2个子图有重叠,只取前一个子图。
"""
......@@ -297,14 +335,11 @@ class FuseBase(object):
""" 删除不需要的中间layer及其对应参数。
"""
for match in self.matches:
if len(match) == 0:
continue
first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
for layer_id, layer in match.items():
if layer.kernel == "fluid.dygraph.base.to_variable" and \
layer.attrs["value"].startswith("params["):
param_name = layer.attrs["value"][8:-2]
if param_name in graph.parameters:
graph.parameters.pop(param_name)
if layer_id in subgraph.layers:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图
subgraph.layers.pop(layer_id)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册