Commit edd42662 authored by SunAhong1993

modify the prim2code

Parent f76d0121
......@@ -357,7 +357,8 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0 and layer.kernel != "prim.assert":
layer_id, 0) == 0 and layer.kernel != "prim.assert" \
and layer.kernel != "prim.exception":
continue
if "dygraph" in layer.kernel:
line = "{}".format(
......@@ -396,9 +397,9 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel:
func_name = layer.kernel.replace(".", "_")
from . import convert_prim
if hasattr(convert_prim, func_name):
func = getattr(convert_prim, func_name)
from x2paddle.op_mapper.pytorch2paddle import prim2code
if hasattr(prim2code, func_name):
func = getattr(prim2code, func_name)
func(
layer,
indent=indent,
......
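The hunk above swaps the old `convert_prim` module for `prim2code` but keeps the same name-based dispatch: a kernel string such as `prim.add` is turned into the function name `prim_add` and looked up with `hasattr`/`getattr`. A minimal standalone sketch of that dispatch follows; `FakePrim2Code` is a hypothetical stand-in for the real module, used only so the snippet runs on its own.

```python
# Hypothetical stand-in for x2paddle.op_mapper.pytorch2paddle.prim2code,
# exposing a single emitter so the dispatch below has something to find.
class FakePrim2Code:
    @staticmethod
    def prim_add(layer, indent=1, init_func=None, forward_func=None):
        forward_func.append("    " * indent + "c = a + b")

def dispatch(kernel, module):
    # "prim.add" -> "prim_add"; return None when no emitter exists for the kernel
    func_name = kernel.replace(".", "_")
    return getattr(module, func_name) if hasattr(module, func_name) else None

forward_lines = []
emitter = dispatch("prim.add", FakePrim2Code)
if emitter is not None:
    emitter(layer=None, indent=1, forward_func=forward_lines)
print(forward_lines)  # ['    c = a + b']
```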
......@@ -532,6 +532,41 @@ def aten_dropout(mapper, graph, node):
return current_inputs, current_outputs
def aten_dropout_(mapper, graph, node):
""" 构造Dropout的PaddleLayer。
TorchScript示例:
%119 : Tensor = aten::dropout_(%result.3, %117, %118)
参数含义:
%119 (Tensor): Dropout后的Tensor。
%result.3 (Tensor): 输入Tensor。
%118 (bool): 是否是训练阶段。
"""
if "dropout" in mapper.dygraph_name_id:
mapper.dygraph_name_id["dropout"] += 1
else:
mapper.dygraph_name_id["dropout"] = 0
dropout_name = "dropout" + str(mapper.dygraph_name_id["dropout"])
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [dropout_name, output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# Get the list of the current node's outputs
current_outputs = [output_name]
# Process input 0, i.e. %result.3
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["input"] = inputs_name[0]
# Get the lists of the current node's inputs and outputs
current_inputs = list(layer_inputs.values())
graph.add_layer(
"fluid.dygraph.Dropout",
inputs=layer_inputs,
outputs=layer_outputs,
p=0.0)
return current_inputs, current_outputs
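The `dygraph_name_id` bookkeeping at the top of the function gives every generated Dropout sublayer a unique name (dropout0, dropout1, ...). A tiny standalone sketch of that counter, assuming `dygraph_name_id` is a plain dict as the code above suggests:

```python
# Sketch of the per-op counter used to name dygraph sublayers uniquely.
dygraph_name_id = {}

def next_layer_name(op_name):
    if op_name in dygraph_name_id:
        dygraph_name_id[op_name] += 1
    else:
        dygraph_name_id[op_name] = 0
    return op_name + str(dygraph_name_id[op_name])

print(next_layer_name("dropout"))  # dropout0
print(next_layer_name("dropout"))  # dropout1
```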
def aten_eq(mapper, graph, node):
""" 构造判断数值是否相等的PaddleLayer。
......@@ -994,6 +1029,34 @@ def aten___not__(mapper, graph, node):
return current_inputs, current_outputs
def aten_relu(mapper, graph, node):
""" 构造ReLU激活的PaddleLayer。
TorchScript示例:
%result.3 : Tensor = aten::relu(%input.5)
参数含义:
%result.3 (Tensor): 输出,ReLU后的结果。
%result.5 (Tensor): 需要ReLU的Tensor。
注意: inplace这个参数在paddle中未实现
"""
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# Get the list of the current node's outputs
current_outputs = [output_name]
# Process input 0, i.e. %input.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["x"] = inputs_name[0]
# Get the list of the current node's inputs
current_inputs = list(layer_inputs.values())
graph.add_layer(
"fluid.layers.relu", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
def aten_relu_(mapper, graph, node):
""" 构造ReLU激活的PaddleLayer。
......
......@@ -24,26 +24,39 @@ def gen_codes(code_list, indent=0):
return codes
def get_value(layer, key):
""" 进行optimizer后可能把inputs的value直接用数值代替(ConstantFuser),
会把input换成attr,所以需要此处的操作。
"""
if key in layer.inputs:
return layer.inputs[key]
else:
return str(layer.attrs[key])
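get_value is the small shim that the rest of this commit threads through every prim_* emitter. A standalone sketch of the lookup order it implements, using a hypothetical `FakeLayer` stand-in for the real layer object:

```python
# After ConstantFuser, an input such as "y" may disappear from layer.inputs and
# reappear as a literal in layer.attrs; get_value hides that difference.
class FakeLayer:  # hypothetical stand-in for the real layer object
    def __init__(self, inputs, attrs):
        self.inputs = inputs
        self.attrs = attrs

def get_value(layer, key):
    if key in layer.inputs:
        return layer.inputs[key]
    else:
        return str(layer.attrs[key])

before_fuse = FakeLayer(inputs={"x": "x2paddle_a", "y": "x2paddle_b"}, attrs={})
after_fuse = FakeLayer(inputs={"x": "x2paddle_a"}, attrs={"y": 2})
print(get_value(before_fuse, "y"))  # x2paddle_b  (still a variable name)
print(get_value(after_fuse, "y"))   # 2           (constant folded into attrs)
```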
def prim_add(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} + {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} + {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_add_(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} + {} * {}".format(layer.outputs[0], layer.inputs["x"],
layer.attrs["alpha"], layer.inputs["y"])
line = "{} = {} + {} * {}".format(layer.outputs[0],
get_value(layer, "x"),
layer.attrs["alpha"],
get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_and(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} and {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} and {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_append(layer, indent=1, init_func=[], forward_func=[]):
line = "{}.append({})".format(layer.inputs["list"], layer.inputs["element"])
line = "{}.append({})".format(
get_value(layer, "list"), get_value(layer, "element"))
forward_func.extend(gen_codes([line], indent=indent))
......@@ -72,23 +85,23 @@ def prim_constant(layer, indent=1, init_func=[], forward_func=[]):
def prim_eq(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} == {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} == {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_equal(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}".format(layer.outputs[0], layer.inputs["input"])
line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_exception(layer, indent=1, init_func=[], forward_func=[]):
line = "raise RaiseException({})".format(layer.inputs["input"])
line = "raise RaiseException({})".format(get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_if(layer, indent=1, init_func=[], forward_func=[]):
line = "if {} :".format(list(layer.inputs.values())[0])
line = "if {} :".format(get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0]
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1)
......@@ -105,45 +118,47 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]):
def prim_getitem(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}[{}]".format(layer.outputs[0], layer.inputs["list"],
layer.inputs["index"])
line = "{} = {}[{}]".format(layer.outputs[0],
get_value(layer, "list"),
get_value(layer, "index"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_gt(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} > {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} > {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_le(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} <= {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} <= {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_len(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = len({})".format(layer.outputs[0], layer.inputs["input"])
line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_lt(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} < {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} < {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_list(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values())
input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list()
for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i)))
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent))
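Because constant folding can move any of the indexed inputs (`input0`, `input1`, ...) into `attrs`, the rewritten prim_list walks `len(inputs) + len(attrs)` keys through get_value instead of reading `layer.inputs.values()` directly. A standalone sketch of the effect, reusing the same hypothetical `FakeLayer`/`get_value` shapes as above:

```python
class FakeLayer:  # hypothetical stand-in; "input1" was folded into attrs
    inputs = {"input0": "x2paddle_a", "input2": "x2paddle_c"}
    attrs = {"input1": 3}
    outputs = ["x2paddle_list"]

def get_value(layer, key):
    return layer.inputs[key] if key in layer.inputs else str(layer.attrs[key])

count = len(FakeLayer.inputs) + len(FakeLayer.attrs)
elements = [get_value(FakeLayer, "input{}".format(i)) for i in range(count)]
print("{} = [{}]".format(FakeLayer.outputs[0], ", ".join(elements)))
# x2paddle_list = [x2paddle_a, 3, x2paddle_c]
```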
def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
loop_range = list(layer.inputs.values())[0]
if list(layer.inputs.values())[0] is None:
loop_range = str(layer.attrs[list(layer.inputs.keys())[0]])
loop_range = get_value(layer, "input")
line = "for {} in range({}):".format(layer.outputs[1], loop_range)
forward_func.extend(gen_codes([line], indent=indent))
block = layer.blocks[0]
......@@ -153,66 +168,71 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
def prim_min(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = min({})".format(layer.outputs[0], layer.inputs["input"])
line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_mul(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} * {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} * {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} < {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} < {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_neg(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = -{}".format(layer.outputs[0], layer.inputs["input"])
line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_not(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = not {}".format(layer.outputs[0], layer.inputs["input"])
line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = not {}.stop_gradient".format(layer.outputs[0],
layer.inputs["input"])
get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_select(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}[".format(layer.outputs[0], layer.inputs["input"])
line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input"))
for dim in range(layer.attrs["dim"]):
line += ":, "
line += (layer.inputs["index"] + "]")
line += (get_value(layer, "index") + "]")
forward_func.extend(gen_codes([line], indent=indent))
def prim_shape(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}.shape".format(layer.outputs[0], layer.inputs["input"])
line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_slice(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {}[{}: {}: {}]".format(
layer.outputs[0], layer.inputs["input"], layer.inputs["start"],
layer.inputs["end"], layer.inputs["step"])
line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
get_value(layer, "input"),
get_value(layer, "start"),
get_value(layer, "end"),
get_value(layer, "step"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_sub(layer, indent=1, init_func=[], forward_func=[]):
line = "{} = {} - {}".format(layer.outputs[0], layer.inputs["x"],
layer.inputs["y"])
line = "{} = {} - {}".format(layer.outputs[0],
get_value(layer, "x"), get_value(layer, "y"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values())
input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list()
for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i)))
inputs_str = ', '.join(inputs_list)
line = "{} = ({})".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent))
......@@ -220,13 +240,13 @@ def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[]):
outputs_str = ', '.join(layer.outputs)
line = "{} = {}".format(outputs_str, layer.inputs["input"])
line = "{} = {}".format(outputs_str, get_value(layer, "input"))
forward_func.extend(gen_codes([line], indent=indent))
def prim_warnings(layer, indent=1, init_func=[], forward_func=[]):
lines = ["import warnings"]
line = "warnings.warn({}, stacklevel={})".format(layer.inputs["input"],
layer.attrs["stacklevel"])
line = "warnings.warn({}, stacklevel={})".format(
get_value(layer, "input"), layer.attrs["stacklevel"])
lines.append(line)
forward_func.extend(gen_codes(lines, indent=indent))
......@@ -14,3 +14,5 @@
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
......@@ -12,21 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import ProgramPass
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion import FcFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class FcFusePass(ProgramPass):
class FcFusePass(Pass):
name = "fc_fuse_pass"
def __init__(self):
ProgramPass.__init__(self)
Pass.__init__(self)
def apply(self, graph):
fuser = FcFuser()
fuser.operate(graph)
fuser.operate(graph, match_kind="topo")
# Used for registration
......
......@@ -18,7 +18,7 @@ from x2paddle.optimizer.pass_manager import PassManager
class GraphOptimizer(object):
def __init__(self):
self.passes = ["fc_fuse_pass"]
self.passes = ["fc_fuse_pass", "constant_fuse_pass"]
def optimize(self, graph):
for pass_name in self.passes:
......
......@@ -12,19 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
class Kind(Enum):
Program = 1
Code = 2
class Pass(object):
name = "pass"
def __init__(self, kind):
self.kind = kind
def __init__(self):
pass
def apply(self, graph):
raise NotImplementedError("The apply function must be implemented!")
......@@ -32,13 +23,3 @@ class Pass(object):
@classmethod
def get_name(cls):
return cls.name
class ProgramPass(Pass):
def __init__(self):
super(ProgramPass, self).__init__(Kind.Program)
class CodePass(Pass):
def __init__(self):
super(CodePass, self).__init__(Kind.Code)
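With the Kind enum and the ProgramPass/CodePass subclasses removed, a concrete pass now derives from Pass directly and only has to implement apply(). A minimal sketch of what a subclass looks like under the simplified base; MyFusePass and its body are illustrative, not part of the commit:

```python
class Pass(object):
    name = "pass"

    def __init__(self):
        pass

    def apply(self, graph):
        raise NotImplementedError("The apply function must be implemented!")

    @classmethod
    def get_name(cls):
        return cls.name


class MyFusePass(Pass):  # hypothetical example subclass
    name = "my_fuse_pass"

    def apply(self, graph):
        # a real pass would run a fuser over the graph here
        return graph


print(MyFusePass.get_name())  # my_fuse_pass
```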
......@@ -18,14 +18,18 @@ from x2paddle.core.program import PaddleGraph
class PatternMatcher(object):
def __init__(self, pattern):
self.pattern = pattern
self.subgraphs = list()
# Each match in matches is a dict of layers in topological order
self.matches = list()
def operate(self, graph):
self.detect_patterns(graph)
def operate(self, graph, match_kind="topo"):
if match_kind == "topo":
self.detect_patterns_by_topo(graph)
elif match_kind == "edge":
self.detect_patterns_by_edge(graph)
self.remove_overlapped_match()
return self.subgraphs
return self.matches
def detect_patterns(self, graph):
def detect_patterns_by_topo(self, graph):
""" 找到与模式匹配的子图,
并将子图的id以拓扑排序存放到subgraph_id2layers。
"""
......@@ -101,49 +105,79 @@ class PatternMatcher(object):
for i, (layer_id, layer) in enumerate(graph.layers.items()):
match_info = get_subgraph(self.pattern, graph, i)
if match_info:
self.subgraphs.append(match_info)
self.matches.append(match_info)
for j, block in enumerate(layer.blocks):
if len(block.layers) > 0:
self.detect_patterns(layer.blocks[j])
self.detect_patterns_by_topo(layer.blocks[j])
def detect_patterns_by_edge(self, graph):
"""当遇见顺序没有强制规定的pattern时使用该方式
"""
pass
def remove_overlapped_match(self):
""" 如果2个子图有重叠,只取前一个子图。
"""
match_ids = []
for i, subgraph in enumerate(self.subgraphs):
for i, match in enumerate(self.matches):
is_overlapped = False
for id in subgraph.keys():
for id in match.keys():
if id in match_ids:
self.subgraphs.pop(i)
self.matches.pop(i)
is_overlapped = True
break
if not is_overlapped:
match_ids.extend(list(subgraph.keys()))
match_ids.extend(list(match.keys()))
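The keep-the-earlier-match policy that the docstring describes can also be stated as a simple filter over matches; the standalone reformulation below illustrates the intended result (it is not the in-place pop() loop used above):

```python
def keep_non_overlapping(matches):
    kept, used_ids = [], set()
    for match in matches:
        # keep a match only if none of its layer ids was claimed by an earlier match
        if used_ids.isdisjoint(match.keys()):
            kept.append(match)
            used_ids.update(match.keys())
    return kept

matches = [{"1": None, "2": None}, {"2": None, "3": None}, {"4": None}]
print([list(m.keys()) for m in keep_non_overlapping(matches)])
# [['1', '2'], ['4']]
```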
def get_subgraph(prefix_layer_id, suffix_layer_id, graph):
""" 根据prefix_layer_id和suffix_layer_id获取需要子图。
Args:
prefix_layer_id (str): 起初为一个空字符串,之后为suffix_layer_id分割出来的前缀。
suffix_layer_id (str): 起初为以一个layer的id,之后将分割部分给prefix_layer_id;例如”57.0.1“;
graph (x2paddle.core.program.PaddleGraph): 需要进行pass的子图。
"""
id_part = suffix_layer_id.split(".")
if len(id_part) == 1:
return graph
if prefix_layer_id == "":
layer_id = id_part[0]
prefix_layer_id += ".".join(id_part[:2])
else:
layer_id = prefix_layer_id + "." + id_part[0]
prefix_layer_id += ("." + ".".join(id_part[:2]))
subgraph = graph.layers[layer_id].blocks[int(id_part[1])]
suffix_layer_id = ".".join(id_part[2:])
return get_subgraph(prefix_layer_id, suffix_layer_id, subgraph)
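A runnable trace of the recursion on an id such as "57.0.1": the first split selects layer "57" and its block 0, and the remaining suffix "1" contains no dot, so that block is returned. FakeGraph and FakeLayer below are hypothetical stand-ins for the real PaddleGraph objects; get_subgraph is copied unchanged from above.

```python
class FakeGraph:  # hypothetical stand-in for PaddleGraph
    def __init__(self, layers):
        self.layers = layers

class FakeLayer:  # hypothetical stand-in for a layer with sub-blocks
    def __init__(self, blocks):
        self.blocks = blocks

def get_subgraph(prefix_layer_id, suffix_layer_id, graph):
    id_part = suffix_layer_id.split(".")
    if len(id_part) == 1:
        return graph
    if prefix_layer_id == "":
        layer_id = id_part[0]
        prefix_layer_id += ".".join(id_part[:2])
    else:
        layer_id = prefix_layer_id + "." + id_part[0]
        prefix_layer_id += ("." + ".".join(id_part[:2]))
    subgraph = graph.layers[layer_id].blocks[int(id_part[1])]
    suffix_layer_id = ".".join(id_part[2:])
    return get_subgraph(prefix_layer_id, suffix_layer_id, subgraph)

inner_block = FakeGraph(layers={"57.0.1": "matched layer lives here"})
top_graph = FakeGraph(layers={"57": FakeLayer(blocks=[inner_block])})
assert get_subgraph("", "57.0.1", top_graph) is inner_block
```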
class FuseBase(object):
def __init__(self):
self.pattern = PaddleGraph()
def operate(self, graph):
def operate(self, graph, match_kind="topo"):
self.build_pattern()
self.perform_pattern_matcher(graph)
for subgraph in self.subgraphs:
self.insert_new_layer(graph, subgraph)
self.perform_pattern_matcher(graph, match_kind)
for match in self.matches:
first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, match)
self.delete_inter_layer(graph)
graph.build()
def perform_pattern_matcher(self, graph):
def perform_pattern_matcher(self, graph, match_kind="topo"):
""" 执行模式匹配,找到匹配的子图。
"""
pattern_matcher = PatternMatcher(self.pattern)
self.subgraphs = pattern_matcher.operate(graph)
self.matches = pattern_matcher.operate(graph, match_kind)
def delete_inter_layer(self, graph):
""" 删除不需要的中间layer及其对应参数。
"""
for subgraph in self.subgraphs:
for layer_id, layer in subgraph.items():
for match in self.matches:
first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph)
for layer_id, layer in match.items():
if layer.kernel == "fluid.dygraph.base.to_variable" and \
layer.attrs["value"].startswith("params["):
param_name = layer.attrs["value"][8:-2]
......@@ -151,4 +185,4 @@ class FuseBase(object):
graph.parameters.pop(param_name)
if layer_id in graph.layers:
# layer_id may belong to a subgraph; in that case delete the parent layer, i.e. delete the entire subgraph
graph.layers.pop(layer_id)
subgraph.layers.pop(layer_id)