提交 130e7682 编写于 作者: S SunAhong1993

add pytorch

上级 dfb5a46e
...@@ -209,7 +209,7 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False): ...@@ -209,7 +209,7 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False):
mapper.save_inference_model(save_dir, params_merge) mapper.save_inference_model(save_dir, params_merge)
def pytorch2paddle(model_path, save_dir, jit_type, input_files): def pytorch2paddle(module, save_dir, jit_type, input_examples):
# check pytorch installation and version # check pytorch installation and version
try: try:
import torch import torch
...@@ -227,19 +227,20 @@ def pytorch2paddle(model_path, save_dir, jit_type, input_files): ...@@ -227,19 +227,20 @@ def pytorch2paddle(model_path, save_dir, jit_type, input_files):
print("Now translating model from pytorch to paddle.") print("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper from x2paddle.op_mapper.dygraph.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper
if jit_type == "trace": if jit_type == "trace":
model = TraceDecoder(model_path, input_files) model = TraceDecoder(module, input_examples)
else: else:
model = ScriptDecoder(model_path) model = ScriptDecoder(module)
mapper = pytorch_op_mapper.PyTorchOpMapper(model) mapper = PyTorchOpMapper(model)
mapper.graph.build() mapper.paddle_graph.build()
print("Model optimizing ...") print("Model optimizing ...")
from x2paddle.optimizer.pytorch_optimizer.optimizer import GraphOptimizer from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer() graph_opt = GraphOptimizer(source_frame="pytorch", paddle_type="dygraph", jit_type=jit_type)
graph_opt.optimize(mapper.graph) graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.") print("Model optimized.")
mapper.graph.gen_model(save_dir, jit_type, input_files) mapper.paddle_graph.gen_model(save_dir, jit_type=jit_type)
def paddle2onnx(model_path, save_dir, opset_version=10): def paddle2onnx(model_path, save_dir, opset_version=10):
......
...@@ -272,6 +272,8 @@ class PaddleGraph(object): ...@@ -272,6 +272,8 @@ class PaddleGraph(object):
def gen_dygraph_model(self, save_dir, jit_type=None): def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace": if jit_type == "trace":
# self.gen_dygraph_code(save_dir)
# self.dump_dygraph_parameter(save_dir)
from x2paddle.optimizer.code_optimizer import HierarchicalTree from x2paddle.optimizer.code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self) hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
...@@ -488,12 +490,11 @@ class PaddleGraph(object): ...@@ -488,12 +490,11 @@ class PaddleGraph(object):
gen_codes( gen_codes(
comment_list, comment_list,
indent=1)) indent=1))
use_structured_name = False if self.source_type in ["tf", "onnx"] else True
self.run_func.extend( self.run_func.extend(
gen_codes(["paddle.disable_static()", gen_codes(["paddle.disable_static()",
"params, _ = fluid.load_dygraph('{}/model')".format(code_dir), "params, _ = fluid.load_dygraph('{}/model')".format(code_dir),
"model = {}()".format(self.name), "model = {}()".format(self.name),
"model.set_dict(params, use_structured_name={})".format(use_structured_name), "model.set_dict(params)",
"model.eval()", "model.eval()",
"out = model({})".format(input_data_name), "out = model({})".format(input_data_name),
"return out"], indent=1)) "return out"], indent=1))
...@@ -525,7 +526,6 @@ class PaddleGraph(object): ...@@ -525,7 +526,6 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
) or layer.kernel == "paddle.to_tensor" or \ ) or layer.kernel == "paddle.to_tensor" or \
"paddle.fluid.dygraph" in layer.kernel or \
layer.kernel.startswith("custom_layer"): layer.kernel.startswith("custom_layer"):
line = "{}".format( line = "{}".format(
layer.outputs[0] layer.outputs[0]
...@@ -566,7 +566,7 @@ class PaddleGraph(object): ...@@ -566,7 +566,7 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
func_name = layer.kernel.replace(".", "_") func_name = layer.kernel.replace(".", "_")
from x2paddle.op_mapper.dygraph import prim2code from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code
if hasattr(prim2code, func_name): if hasattr(prim2code, func_name):
func = getattr(prim2code, func_name) func = getattr(prim2code, func_name)
func( func(
...@@ -614,7 +614,6 @@ class PaddleGraph(object): ...@@ -614,7 +614,6 @@ class PaddleGraph(object):
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
sepc_list = list() sepc_list = list()
for i, name in enumerate(self.inputs): for i, name in enumerate(self.inputs):
input_shapes[i][0] = -1
sepc_list.append( sepc_list.append(
paddle.static.InputSpec( paddle.static.InputSpec(
shape=input_shapes[i], name=name, dtype=input_types[i])) shape=input_shapes[i], name=name, dtype=input_types[i]))
...@@ -625,10 +624,16 @@ class PaddleGraph(object): ...@@ -625,10 +624,16 @@ class PaddleGraph(object):
paddle.disable_static() paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
if self.source_type in ["tf", "onnx"]: if self.source_type == "tf":
model.set_dict(restore, use_structured_name=False) model.set_dict(restore, use_structured_name=False)
else: else:
model.set_dict(restore) model.set_dict(restore)
model.eval() model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list) static_model = paddle.jit.to_static(model, input_spec=sepc_list)
try:
paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model")) paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model"))
except ValueError as e:
if str(e) == "'target_vars' should be a list of Variable.":
print("[DyGraph2StaticGraph Error] Can not convert the dygraph to static! The output of PyTorch mustbe Variable or a list of Variable.")
else:
print(e)
\ No newline at end of file
...@@ -41,9 +41,10 @@ class ScriptDecoder(Decoder): ...@@ -41,9 +41,10 @@ class ScriptDecoder(Decoder):
script_path (str): ScriptModule保存路径。 script_path (str): ScriptModule保存路径。
model_path (str): PyTorchModule保存路径。 model_path (str): PyTorchModule保存路径。
""" """
def __init__(self, script_path=None): def __init__(self, module):
self.script = torch.jit.load(script_path) self.script = torch.jit.script(module)
self.graph = self._optimize_graph(self.script.inlined_graph) self.graph = self._optimize_graph(self.script.inlined_graph)
self.input_examples = None
class TraceDecoder(Decoder): class TraceDecoder(Decoder):
""" PyTorchModule后使用trace方式转换为ScriptModule。 """ PyTorchModule后使用trace方式转换为ScriptModule。
...@@ -53,14 +54,15 @@ class TraceDecoder(Decoder): ...@@ -53,14 +54,15 @@ class TraceDecoder(Decoder):
input_files (list): 输入网络的numpy,每个numpy保存成.npy文件, input_files (list): 输入网络的numpy,每个numpy保存成.npy文件,
文件路径存储在input_files中。 文件路径存储在input_files中。
""" """
def __init__(self, model_path, input_files=list()): def __init__(self, module, input_examples):
# TODO(syf): 传入pytorch的Module(即import),否则出错 try:
model = torch.load(model_path) self.script = torch.jit.trace(module, input_examples)
model.eval() except RuntimeError as e:
input_list = list() if "strict" in str(e):
for npy_file in input_files: self.script = torch.jit.trace(module, input_examples, strict=False)
input_list.append(torch.tensor(np.load(npy_file))) else:
self.script = torch.jit.trace(model, input_list, strict=False) print(e)
exit(0)
self.graph = self._optimize_graph(self.script.inlined_graph) self.graph = self._optimize_graph(self.script.inlined_graph)
# print(self.graph) self.input_examples = input_examples
# print(getattr(getattr(self.script.decoder.block, "5").layer, "2"))
...@@ -25,6 +25,7 @@ def prim_Constant(mapper, graph, node): ...@@ -25,6 +25,7 @@ def prim_Constant(mapper, graph, node):
参数含义: 参数含义:
%2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。 %2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
output = list(node.outputs())[0] output = list(node.outputs())[0]
value = output.toIValue() value = output.toIValue()
...@@ -32,7 +33,10 @@ def prim_Constant(mapper, graph, node): ...@@ -32,7 +33,10 @@ def prim_Constant(mapper, graph, node):
if isinstance(value, str): if isinstance(value, str):
value = string(value) value = string(value)
if str(output_type) == "Tensor": if str(output_type) == "Tensor":
tensor_value = value
value = "{}".format(value) value = "{}".format(value)
if "tensor" in value:
mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
if "inf" in str(value): if "inf" in str(value):
t = str(type(value)).split("'")[1] t = str(type(value)).split("'")[1]
...@@ -45,7 +49,7 @@ def prim_Constant(mapper, graph, node): ...@@ -45,7 +49,7 @@ def prim_Constant(mapper, graph, node):
value = int(math.pow(2, 31) - 1) value = int(math.pow(2, 31) - 1)
mapper.attrs[output_name] = value mapper.attrs[output_name] = value
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=value) "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value)
return [], [output_name] return [], [output_name]
...@@ -60,6 +64,7 @@ def prim_data(mapper, graph, node): ...@@ -60,6 +64,7 @@ def prim_data(mapper, graph, node):
【注意】Paddle中无此用法,所以此处翻译成赋值。 【注意】Paddle中无此用法,所以此处翻译成赋值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -68,15 +73,53 @@ def prim_data(mapper, graph, node): ...@@ -68,15 +73,53 @@ def prim_data(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4336 # 处理输入0,即%4336
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
def prim_DictConstruct(mapper, graph, node):
""" 构建dict。
TorchScript示例:
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
参数含义:
%32 (dict): 组成的字典。
%30 (str): key。
%23 (-): value。
%31 (str): key。
%29 (-): value。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理每个输入
for i, input_name in enumerate(inputs_name):
if i%2 == 0:
layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name]
else:
layer_inputs["value{}".format(int(i/2))] = input_name
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer("prim.dict_construct",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def prim_GetAttr(mapper, graph, node): def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。 """ 获取attribute信息。
...@@ -86,6 +129,7 @@ def prim_GetAttr(mapper, graph, node): ...@@ -86,6 +129,7 @@ def prim_GetAttr(mapper, graph, node):
%7 (Tensor): 输入Tensor。 %7 (Tensor): 输入Tensor。
%27 (Tensor): 输入Tensor。 %27 (Tensor): 输入Tensor。
""" """
scope_name = mapper.normalize_scope_name(node)
current_node = node current_node = node
field_name_list = [node.s('name')] field_name_list = [node.s('name')]
while True: while True:
...@@ -102,7 +146,7 @@ def prim_GetAttr(mapper, graph, node): ...@@ -102,7 +146,7 @@ def prim_GetAttr(mapper, graph, node):
if hasattr(part_script, field_name): if hasattr(part_script, field_name):
param = getattr(part_script, field_name) param = getattr(part_script, field_name)
if isinstance(param, torch.Tensor): if isinstance(param, torch.Tensor):
param = param.detach().numpy() param = param.cpu().detach().numpy()
if len(param.shape) == 0: if len(param.shape) == 0:
param = np.reshape(param, 1) param = np.reshape(param, 1)
if str(param.dtype) == "uint8": if str(param.dtype) == "uint8":
...@@ -129,14 +173,15 @@ def prim_If(mapper, graph, node): ...@@ -129,14 +173,15 @@ def prim_If(mapper, graph, node):
%107 (bool): if判断条件。 %107 (bool): if判断条件。
%input.5 (Tensor): if控制流的输出,与%output.4对应。 %input.5 (Tensor): if控制流的输出,与%output.4对应。
""" """
scope_name = mapper.normalize_scope_name(node)
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
node_outputs = outputs_name.copy() node_outputs = outputs_name.copy()
current_outputs = outputs_name.copy() current_outputs = outputs_name.copy()
input_node = list(node.inputs())[0].node() input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique() script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id] input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, current_outputs) mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name)
graph.add_layer("prim.if", {'input': input_node_name}, node_outputs) graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0] block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer) block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
...@@ -163,6 +208,7 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -163,6 +208,7 @@ def prim_ListConstruct(mapper, graph, node):
%84 (int/其他): list第一个元素信息。 %84 (int/其他): list第一个元素信息。
%85 (int/其他): list第二个元素信息。 %85 (int/其他): list第二个元素信息。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -175,7 +221,7 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -175,7 +221,7 @@ def prim_ListConstruct(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -189,6 +235,7 @@ def prim_ListUnpack(mapper, graph, node): ...@@ -189,6 +235,7 @@ def prim_ListUnpack(mapper, graph, node):
%x2.4 (Tensor): 输出,list的第二个元素。 %x2.4 (Tensor): 输出,list的第二个元素。
%4354 (list): 列表。 %4354 (list): 列表。
""" """
scope_name = mapper.normalize_scope_name(node)
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
layer_outputs = outputs_name.copy() layer_outputs = outputs_name.copy()
layer_inputs = {} layer_inputs = {}
...@@ -196,13 +243,13 @@ def prim_ListUnpack(mapper, graph, node): ...@@ -196,13 +243,13 @@ def prim_ListUnpack(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = layer_outputs.copy() current_outputs = layer_outputs.copy()
# 处理输入0,即%4354 # 处理输入0,即%4354
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs) "prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs) mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -223,6 +270,7 @@ def prim_Loop(mapper, graph, node): ...@@ -223,6 +270,7 @@ def prim_Loop(mapper, graph, node):
%x.3 (Tensor): 循环中修改的Tensor。 %x.3 (Tensor): 循环中修改的Tensor。
%x (Tensor): loop循环的输出,与%x.5对应。 %x (Tensor): loop循环的输出,与%x.5对应。
""" """
scope_name = mapper.normalize_scope_name(node)
node_outputs = mapper._get_outputs_name(node) node_outputs = mapper._get_outputs_name(node)
loop_inputs = {} loop_inputs = {}
block = list(node.blocks())[0] block = list(node.blocks())[0]
...@@ -242,7 +290,7 @@ def prim_Loop(mapper, graph, node): ...@@ -242,7 +290,7 @@ def prim_Loop(mapper, graph, node):
loop_input_node_name = mapper.outputs_info[ loop_input_node_name = mapper.outputs_info[
script_loop_input_unique_id] script_loop_input_unique_id]
mapper._check_input(graph, loop_input_node, loop_input_node_name, mapper._check_input(graph, loop_input_node, loop_input_node_name,
node_outputs) node_outputs, scope_name)
loop_inputs['input'] = loop_input_node_name loop_inputs['input'] = loop_input_node_name
loop_outputs.append(block_input_node_name) loop_outputs.append(block_input_node_name)
node_outputs.append(block_input_node_name) node_outputs.append(block_input_node_name)
...@@ -252,14 +300,15 @@ def prim_Loop(mapper, graph, node): ...@@ -252,14 +300,15 @@ def prim_Loop(mapper, graph, node):
loop_input_node_name = mapper.outputs_info[ loop_input_node_name = mapper.outputs_info[
script_loop_input_unique_id] script_loop_input_unique_id]
mapper._check_input(graph, loop_input_node, loop_input_node_name, mapper._check_input(graph, loop_input_node, loop_input_node_name,
node_outputs) node_outputs, scope_name)
graph.add_layer( graph.add_layer(
"prim.equal", "prim.equal",
inputs={'input': loop_input_node_name}, inputs={'input': loop_input_node_name},
outputs=[block_input_node_name]) outputs=[block_input_node_name],
scope_name=scope_name)
node_outputs.append(block_input_node_name) node_outputs.append(block_input_node_name)
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs) graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, current_layer) block_graph, graph_inputs = mapper.traverse(block, current_layer)
for i, input_name in enumerate(graph_inputs): for i, input_name in enumerate(graph_inputs):
...@@ -279,6 +328,7 @@ def prim_min(mapper, graph, node): ...@@ -279,6 +328,7 @@ def prim_min(mapper, graph, node):
%86 (list): 输入。 %86 (list): 输入。
%87 (int): 输出。 %87 (int): 输出。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -286,12 +336,12 @@ def prim_min(mapper, graph, node): ...@@ -286,12 +336,12 @@ def prim_min(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -304,6 +354,7 @@ def prim_NumToTensor(mapper, graph, node): ...@@ -304,6 +354,7 @@ def prim_NumToTensor(mapper, graph, node):
%other.2 (Tensor): 输出。 %other.2 (Tensor): 输出。
%1736 (-): 输入。 %1736 (-): 输入。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -312,25 +363,26 @@ def prim_NumToTensor(mapper, graph, node): ...@@ -312,25 +363,26 @@ def prim_NumToTensor(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
if inputs_node[0].kind() == "aten::size": inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0])
if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1:
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim_equal", inputs=layer_inputs, outputs=layer_outputs) "prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
else: else:
layer_inputs["value"] = inputs_name[0] layer_inputs["fill_value"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
input_type = list(node.inputs())[0].type() input_type = list(node.inputs())[0].type()
layer_attrs["dtype"] = input_type layer_attrs["dtype"] = input_type
layer_attrs["persistable"] = True
layer_attrs["shape"] = [1] layer_attrs["shape"] = [1]
graph.add_layer( graph.add_layer(
"fluid.layers.create_global_var", "paddle.full",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -343,6 +395,7 @@ def prim_RaiseException(mapper, graph, node): ...@@ -343,6 +395,7 @@ def prim_RaiseException(mapper, graph, node):
参数含义: 参数含义:
%76 (str): 异常信息。 %76 (str): 异常信息。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -350,13 +403,13 @@ def prim_RaiseException(mapper, graph, node): ...@@ -350,13 +403,13 @@ def prim_RaiseException(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%76 # 处理输入0,即%76
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.exception", inputs=layer_inputs, outputs=layer_outputs) "prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -369,6 +422,7 @@ def prim_requires_grad(mapper, graph, node): ...@@ -369,6 +422,7 @@ def prim_requires_grad(mapper, graph, node):
%356 (bool): 输出,当前Tensor是否计算梯度。 %356 (bool): 输出,当前Tensor是否计算梯度。
%tensor.31 (Tensor): 输入的Tensor。 %tensor.31 (Tensor): 输入的Tensor。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -376,13 +430,13 @@ def prim_requires_grad(mapper, graph, node): ...@@ -376,13 +430,13 @@ def prim_requires_grad(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs) "prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -395,6 +449,7 @@ def prim_SetAttr(mapper, graph, node): ...@@ -395,6 +449,7 @@ def prim_SetAttr(mapper, graph, node):
%260 (-): 属性名前缀。 %260 (-): 属性名前缀。
%277 (-): 需要设置的值。 %277 (-): 需要设置的值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
field_name_list = [] field_name_list = []
tmp_node = node tmp_node = node
...@@ -416,7 +471,8 @@ def prim_SetAttr(mapper, graph, node): ...@@ -416,7 +471,8 @@ def prim_SetAttr(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.set_attr", "prim.set_attr",
inputs={"input": inputs_name[1]}, inputs={"input": inputs_name[1]},
outputs=["self." + ".".join(field_name_list).replace(".", "_")]) outputs=["self." + ".".join(field_name_list).replace(".", "_")],
scope_name=scope_name)
return [], [output_name] return [], [output_name]
...@@ -429,6 +485,7 @@ def prim_shape(mapper, graph, node): ...@@ -429,6 +485,7 @@ def prim_shape(mapper, graph, node):
%4701 (list): 输出,shape信息。 %4701 (list): 输出,shape信息。
%result.1 (Tensor): 需要获取shape的值。 %result.1 (Tensor): 需要获取shape的值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -436,13 +493,13 @@ def prim_shape(mapper, graph, node): ...@@ -436,13 +493,13 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"fluid.layers.shape", inputs=layer_inputs, outputs=layer_outputs) "paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -456,6 +513,7 @@ def prim_TupleConstruct(mapper, graph, node): ...@@ -456,6 +513,7 @@ def prim_TupleConstruct(mapper, graph, node):
%x.46 (Tensor/其他): tuple第一个元素信息。 %x.46 (Tensor/其他): tuple第一个元素信息。
%aux (Tensor/其他): tuple第二个元素信息。 %aux (Tensor/其他): tuple第二个元素信息。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -468,7 +526,7 @@ def prim_TupleConstruct(mapper, graph, node): ...@@ -468,7 +526,7 @@ def prim_TupleConstruct(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -482,6 +540,7 @@ def prim_TupleUnpack(mapper, graph, node): ...@@ -482,6 +540,7 @@ def prim_TupleUnpack(mapper, graph, node):
%aux.3 (Tensor/其他): 输出,tuple第二个元素信息。 %aux.3 (Tensor/其他): 输出,tuple第二个元素信息。
%4492 (tuple): 需要获取元素的tuple。 %4492 (tuple): 需要获取元素的tuple。
""" """
scope_name = mapper.normalize_scope_name(node)
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
layer_outputs = outputs_name layer_outputs = outputs_name
layer_inputs = {} layer_inputs = {}
...@@ -493,7 +552,7 @@ def prim_TupleUnpack(mapper, graph, node): ...@@ -493,7 +552,7 @@ def prim_TupleUnpack(mapper, graph, node):
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs) "prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -508,6 +567,7 @@ def prim_unchecked_cast(mapper, graph, node): ...@@ -508,6 +567,7 @@ def prim_unchecked_cast(mapper, graph, node):
【注意】Paddle中无此用法,所以此处翻译成赋值。 【注意】Paddle中无此用法,所以此处翻译成赋值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -516,12 +576,12 @@ def prim_unchecked_cast(mapper, graph, node): ...@@ -516,12 +576,12 @@ def prim_unchecked_cast(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size.63 # 处理输入0,即%size.63
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -533,9 +593,10 @@ def prim_Uninitialized(mapper, graph, node): ...@@ -533,9 +593,10 @@ def prim_Uninitialized(mapper, graph, node):
参数含义: 参数含义:
%345 (bool): 输出,为赋值的bool。 %345 (bool): 输出,为赋值的bool。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
output = list(node.outputs())[0] output = list(node.outputs())[0]
mapper.attrs[output_name] = None mapper.attrs[output_name] = None
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=None) "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None)
return [], [output_name] return [], [output_name]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -32,3 +32,5 @@ from .reshape_fuser import DygraphReshapeFuser ...@@ -32,3 +32,5 @@ from .reshape_fuser import DygraphReshapeFuser
from .reshape_fuse_pass import DygraphReshapeFusePass from .reshape_fuse_pass import DygraphReshapeFusePass
from .tf_batchnorm_fuser import DygraphTFBatchNormFuser from .tf_batchnorm_fuser import DygraphTFBatchNormFuser
from .tf_batchnorm_fuse_pass import DygraphTFBatchNormFusePass from .tf_batchnorm_fuse_pass import DygraphTFBatchNormFusePass
from .trace_fc_fuser import TraceFcFuser
from .trace_fc_fuse_pass import TraceFcFusePass
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册