未验证 提交 c05a67a4 编写于 作者: J Jason 提交者: GitHub

Merge pull request #435 from SunAhong1993/paddle-2.0

add pytorch2paddle
...@@ -16,6 +16,7 @@ paddlepaddle >= 1.8.0 ...@@ -16,6 +16,7 @@ paddlepaddle >= 1.8.0
tensorflow : tensorflow == 1.14.0 tensorflow : tensorflow == 1.14.0
caffe : 无 caffe : 无
onnx : onnx >= 1.6.0 onnx : onnx >= 1.6.0
pytorch:torch >=1.5.0 (script方式暂不支持1.7.0)
## 安装 ## 安装
### 安装方式一(推荐) ### 安装方式一(推荐)
...@@ -45,11 +46,10 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel ...@@ -45,11 +46,10 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
``` ```
### Paddle2ONNX ### PyTorch
``` > PyTorch不支持命令行使用方式,详见[PyTorch2Paddle](pytorch2paddle.md)
# 注意:paddle_infer_model_dir下需包含__model__和__params__两个文件
x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_model
```
### 参数选项 ### 参数选项
| 参数 | | | 参数 | |
|----------|--------------| |----------|--------------|
......
# PyTorch2Paddle
PyTorch2Paddle支持trace和script两种方式的转换,均是PyTorch动态图到Paddle动态图的转换,转换后的Paddle动态图运用动转静可转换为静态图模型。trace方式生成的代码可读性较强,较为接近原版PyTorch代码的组织结构;script方式不需要知道输入数据的类型和大小即可转换,使用上较为方便,但目前PyTorch支持的script代码方式有所限制,所以支持转换的代码也有所限制。用户可根据自身需求,选择转换方式。
## 环境依赖
python == 2.7 | python >= 3.5
paddlepaddle >= 1.8.0
pytorch:torch >=1.5.0 (script方式暂不支持1.7.0)
**使用trace方式需安装以下依赖**
pandas
treelib
## 使用方式
```
from x2paddle.convert import pytorch2paddle
pytorch2paddle(module=torch_module,
save_dir="./pd_model",
jit_type="trace",
input_examples=[torch_input])
# module (torch.nn.Module): PyTorch的Module。
# save_dir (str): 转换后模型的保存路径。
# jit_type (str): 转换方式。默认为"trace"。
# input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。
```
**注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静;
当jit_type为"script"时",input_examples不为None时,才可以进行动转静。
## 使用示例
```
import torch
import numpy as np
from torchvision.models import AlexNet
from torchvision.models.utils import load_state_dict_from_url
# 构建输入
input_data = np.random.rand(1, 3, 224, 224).astype("float32")
# 获取PyTorch Module
torch_module = AlexNet()
torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_model,
save_dir="pd_model_trace",
jit_type="trace",
input_examples=[torch.tensor(input_data)])
```
\ No newline at end of file
...@@ -209,7 +209,7 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False): ...@@ -209,7 +209,7 @@ def onnx2paddle(model_path, save_dir, paddle_type, params_merge=False):
mapper.save_inference_model(save_dir, params_merge) mapper.save_inference_model(save_dir, params_merge)
def pytorch2paddle(model_path, save_dir, jit_type, input_files): def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None):
# check pytorch installation and version # check pytorch installation and version
try: try:
import torch import torch
...@@ -225,21 +225,22 @@ def pytorch2paddle(model_path, save_dir, jit_type, input_files): ...@@ -225,21 +225,22 @@ def pytorch2paddle(model_path, save_dir, jit_type, input_files):
) )
return return
print("Now translating model from pytorch to paddle.") print("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder from x2paddle.decoder.pytorch_decoder import ScriptDecoder, TraceDecoder
from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper from x2paddle.op_mapper.dygraph.pytorch2paddle.pytorch_op_mapper import PyTorchOpMapper
if jit_type == "trace": if jit_type == "trace":
model = TraceDecoder(model_path, input_files) model = TraceDecoder(module, input_examples)
else: else:
model = ScriptDecoder(model_path) model = ScriptDecoder(module, input_examples)
mapper = pytorch_op_mapper.PyTorchOpMapper(model) mapper = PyTorchOpMapper(model)
mapper.graph.build() mapper.paddle_graph.build()
print("Model optimizing ...") print("Model optimizing ...")
from x2paddle.optimizer.pytorch_optimizer.optimizer import GraphOptimizer from x2paddle.optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer() graph_opt = GraphOptimizer(source_frame="pytorch", paddle_type="dygraph", jit_type=jit_type)
graph_opt.optimize(mapper.graph) graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.") print("Model optimized.")
mapper.graph.gen_model(save_dir, jit_type, input_files) mapper.paddle_graph.gen_model(save_dir, jit_type=jit_type)
def paddle2onnx(model_path, save_dir, opset_version=10): def paddle2onnx(model_path, save_dir, opset_version=10):
...@@ -323,10 +324,6 @@ def main(): ...@@ -323,10 +324,6 @@ def main():
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
onnx2paddle(args.model, args.save_dir, args.paddle_type, params_merge) onnx2paddle(args.model, args.save_dir, args.paddle_type, params_merge)
elif args.framework == "pytorch":
assert args.model is not None, "--model should be defined while translating pytorch model"
pytorch2paddle(args.model, args.save_dir, args.jit_type, args.input_files)
elif args.framework == "paddle2onnx": elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx" assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset) paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
......
...@@ -281,13 +281,14 @@ class PaddleGraph(object): ...@@ -281,13 +281,14 @@ class PaddleGraph(object):
else: else:
self.gen_dygraph_code(save_dir) self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
input_shapes = list() # 动转静
input_types = list() if len(self.inputs_info) > 0:
for input_name in self.inputs: input_shapes = list()
input_shapes.append(self.inputs_info[input_name][0]) input_types = list()
input_types.append(self.inputs_info[input_name][1]) for input_name in self.inputs:
# 如果input_files非空,则导出推理模型;其值类似[[None, 3, 224, 224]] input_shapes.append(self.inputs_info[input_name][0])
self.dygraph2static(save_dir, input_shapes, input_types) input_types.append(self.inputs_info[input_name][1])
self.dygraph2static(save_dir, input_shapes, input_types)
def gen_static_code(self, code_dir): def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0): def write_code(f, code_list, indent=0):
...@@ -424,9 +425,7 @@ class PaddleGraph(object): ...@@ -424,9 +425,7 @@ class PaddleGraph(object):
if self.edges_out.get(layer_id, 0) == 0: if self.edges_out.get(layer_id, 0) == 0:
for i, output_name in enumerate(layer.outputs): for i, output_name in enumerate(layer.outputs):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \ if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel):
(layer.kernel == "paddle.to_tensor" and layer.attrs["data"].startswith("params["))or \
"paddle.fluid.dygraph" in layer.kernel:
if i == 0: if i == 0:
continue continue
if output_name not in self.outputs: if output_name not in self.outputs:
...@@ -512,6 +511,8 @@ class PaddleGraph(object): ...@@ -512,6 +511,8 @@ class PaddleGraph(object):
return_code = "return {}".format(", ".join(self.outputs)) return_code = "return {}".format(", ".join(self.outputs))
self.forward_func.extend(gen_codes([return_code], indent=2)) self.forward_func.extend(gen_codes([return_code], indent=2))
for code_line in self.forward_func: for code_line in self.forward_func:
if "assert [1, 1] == 1 or [1, 1] == [1, 1], 'The [1, 1] must be [1, [1, 1]]!'" in code_line:
continue
f.write(code_line) f.write(code_line)
for code_line in self.run_func: for code_line in self.run_func:
f.write(code_line) f.write(code_line)
...@@ -525,7 +526,6 @@ class PaddleGraph(object): ...@@ -525,7 +526,6 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
) or layer.kernel == "paddle.to_tensor" or \ ) or layer.kernel == "paddle.to_tensor" or \
"paddle.fluid.dygraph" in layer.kernel or \
layer.kernel.startswith("custom_layer"): layer.kernel.startswith("custom_layer"):
line = "{}".format( line = "{}".format(
layer.outputs[0] layer.outputs[0]
...@@ -566,7 +566,7 @@ class PaddleGraph(object): ...@@ -566,7 +566,7 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
func_name = layer.kernel.replace(".", "_") func_name = layer.kernel.replace(".", "_")
from x2paddle.op_mapper.dygraph import prim2code from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code
if hasattr(prim2code, func_name): if hasattr(prim2code, func_name):
func = getattr(prim2code, func_name) func = getattr(prim2code, func_name)
func( func(
...@@ -594,7 +594,7 @@ class PaddleGraph(object): ...@@ -594,7 +594,7 @@ class PaddleGraph(object):
line = line.strip(", ") line = line.strip(", ")
line += ")" line += ")"
if layer.kernel == "self.create_parameter": if layer.kernel == "self.create_parameter":
self.init_func.extend(gen_codes(["self." + line], indent=indent)) self.init_func.extend(gen_codes(["self." + line], indent=2))
self.forward_func.extend(gen_codes(["{} = self.{}".format(layer.outputs[0], self.forward_func.extend(gen_codes(["{} = self.{}".format(layer.outputs[0],
layer.outputs[0])], indent=indent)) layer.outputs[0])], indent=indent))
else: else:
...@@ -614,7 +614,6 @@ class PaddleGraph(object): ...@@ -614,7 +614,6 @@ class PaddleGraph(object):
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
sepc_list = list() sepc_list = list()
for i, name in enumerate(self.inputs): for i, name in enumerate(self.inputs):
input_shapes[i][0] = -1
sepc_list.append( sepc_list.append(
paddle.static.InputSpec( paddle.static.InputSpec(
shape=input_shapes[i], name=name, dtype=input_types[i])) shape=input_shapes[i], name=name, dtype=input_types[i]))
...@@ -631,4 +630,11 @@ class PaddleGraph(object): ...@@ -631,4 +630,11 @@ class PaddleGraph(object):
model.set_dict(restore) model.set_dict(restore)
model.eval() model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list) static_model = paddle.jit.to_static(model, input_spec=sepc_list)
paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model")) try:
\ No newline at end of file paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model"))
except ValueError as e:
if str(e) == "'target_vars' should be a list of Variable.":
print("[DyGraph2StaticGraph Error] Can not convert the dygraph to static! The output of PyTorch mustbe Variable or a list of Variable.")
else:
print(e)
exit(0)
...@@ -41,9 +41,10 @@ class ScriptDecoder(Decoder): ...@@ -41,9 +41,10 @@ class ScriptDecoder(Decoder):
script_path (str): ScriptModule保存路径。 script_path (str): ScriptModule保存路径。
model_path (str): PyTorchModule保存路径。 model_path (str): PyTorchModule保存路径。
""" """
def __init__(self, script_path=None): def __init__(self, module, input_examples=None):
self.script = torch.jit.load(script_path) self.script = torch.jit.script(module)
self.graph = self._optimize_graph(self.script.inlined_graph) self.graph = self._optimize_graph(self.script.inlined_graph)
self.input_examples = input_examples
class TraceDecoder(Decoder): class TraceDecoder(Decoder):
""" PyTorchModule后使用trace方式转换为ScriptModule。 """ PyTorchModule后使用trace方式转换为ScriptModule。
...@@ -53,14 +54,16 @@ class TraceDecoder(Decoder): ...@@ -53,14 +54,16 @@ class TraceDecoder(Decoder):
input_files (list): 输入网络的numpy,每个numpy保存成.npy文件, input_files (list): 输入网络的numpy,每个numpy保存成.npy文件,
文件路径存储在input_files中。 文件路径存储在input_files中。
""" """
def __init__(self, model_path, input_files=list()): def __init__(self, module, input_examples):
# TODO(syf): 传入pytorch的Module(即import),否则出错 try:
model = torch.load(model_path) self.script = torch.jit.trace(module, input_examples)
model.eval() except RuntimeError as e:
input_list = list() if "strict" in str(e):
for npy_file in input_files: self.script = torch.jit.trace(module, input_examples, strict=False)
input_list.append(torch.tensor(np.load(npy_file))) else:
self.script = torch.jit.trace(model, input_list, strict=False) print(e)
exit(0)
self.graph = self._optimize_graph(self.script.inlined_graph) self.graph = self._optimize_graph(self.script.inlined_graph)
# print(self.graph) self.input_examples = input_examples
# print(getattr(getattr(self.script.decoder.block, "5").layer, "2"))
...@@ -1180,7 +1180,7 @@ class OpSet9(): ...@@ -1180,7 +1180,7 @@ class OpSet9():
scale=beta) scale=beta)
add_inputs = {"x": val_mm, "y": var_beta} add_inputs = {"x": val_mm, "y": var_beta}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.addd", "paddle.add",
inputs=add_inputs, inputs=add_inputs,
outputs=[node.layer_name]) outputs=[node.layer_name])
......
...@@ -25,6 +25,7 @@ def prim_Constant(mapper, graph, node): ...@@ -25,6 +25,7 @@ def prim_Constant(mapper, graph, node):
参数含义: 参数含义:
%2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。 %2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
output = list(node.outputs())[0] output = list(node.outputs())[0]
value = output.toIValue() value = output.toIValue()
...@@ -32,7 +33,10 @@ def prim_Constant(mapper, graph, node): ...@@ -32,7 +33,10 @@ def prim_Constant(mapper, graph, node):
if isinstance(value, str): if isinstance(value, str):
value = string(value) value = string(value)
if str(output_type) == "Tensor": if str(output_type) == "Tensor":
tensor_value = value
value = "{}".format(value) value = "{}".format(value)
if "tensor" in value:
mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
if "inf" in str(value): if "inf" in str(value):
t = str(type(value)).split("'")[1] t = str(type(value)).split("'")[1]
...@@ -45,7 +49,7 @@ def prim_Constant(mapper, graph, node): ...@@ -45,7 +49,7 @@ def prim_Constant(mapper, graph, node):
value = int(math.pow(2, 31) - 1) value = int(math.pow(2, 31) - 1)
mapper.attrs[output_name] = value mapper.attrs[output_name] = value
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=value) "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value)
return [], [output_name] return [], [output_name]
...@@ -60,6 +64,7 @@ def prim_data(mapper, graph, node): ...@@ -60,6 +64,7 @@ def prim_data(mapper, graph, node):
【注意】Paddle中无此用法,所以此处翻译成赋值。 【注意】Paddle中无此用法,所以此处翻译成赋值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -68,15 +73,53 @@ def prim_data(mapper, graph, node): ...@@ -68,15 +73,53 @@ def prim_data(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4336 # 处理输入0,即%4336
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
def prim_DictConstruct(mapper, graph, node):
""" 构建dict。
TorchScript示例:
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
参数含义:
%32 (dict): 组成的字典。
%30 (str): key。
%23 (-): value。
%31 (str): key。
%29 (-): value。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理每个输入
for i, input_name in enumerate(inputs_name):
if i%2 == 0:
layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name]
else:
layer_inputs["value{}".format(int(i/2))] = input_name
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer("prim.dict_construct",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def prim_GetAttr(mapper, graph, node): def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。 """ 获取attribute信息。
...@@ -86,6 +129,7 @@ def prim_GetAttr(mapper, graph, node): ...@@ -86,6 +129,7 @@ def prim_GetAttr(mapper, graph, node):
%7 (Tensor): 输入Tensor。 %7 (Tensor): 输入Tensor。
%27 (Tensor): 输入Tensor。 %27 (Tensor): 输入Tensor。
""" """
scope_name = mapper.normalize_scope_name(node)
current_node = node current_node = node
field_name_list = [node.s('name')] field_name_list = [node.s('name')]
while True: while True:
...@@ -102,7 +146,7 @@ def prim_GetAttr(mapper, graph, node): ...@@ -102,7 +146,7 @@ def prim_GetAttr(mapper, graph, node):
if hasattr(part_script, field_name): if hasattr(part_script, field_name):
param = getattr(part_script, field_name) param = getattr(part_script, field_name)
if isinstance(param, torch.Tensor): if isinstance(param, torch.Tensor):
param = param.detach().numpy() param = param.cpu().detach().numpy()
if len(param.shape) == 0: if len(param.shape) == 0:
param = np.reshape(param, 1) param = np.reshape(param, 1)
if str(param.dtype) == "uint8": if str(param.dtype) == "uint8":
...@@ -129,14 +173,15 @@ def prim_If(mapper, graph, node): ...@@ -129,14 +173,15 @@ def prim_If(mapper, graph, node):
%107 (bool): if判断条件。 %107 (bool): if判断条件。
%input.5 (Tensor): if控制流的输出,与%output.4对应。 %input.5 (Tensor): if控制流的输出,与%output.4对应。
""" """
scope_name = mapper.normalize_scope_name(node)
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
node_outputs = outputs_name.copy() node_outputs = outputs_name.copy()
current_outputs = outputs_name.copy() current_outputs = outputs_name.copy()
input_node = list(node.inputs())[0].node() input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique() script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id] input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, current_outputs) mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name)
graph.add_layer("prim.if", {'input': input_node_name}, node_outputs) graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0] block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer) block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
...@@ -163,6 +208,7 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -163,6 +208,7 @@ def prim_ListConstruct(mapper, graph, node):
%84 (int/其他): list第一个元素信息。 %84 (int/其他): list第一个元素信息。
%85 (int/其他): list第二个元素信息。 %85 (int/其他): list第二个元素信息。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -175,7 +221,7 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -175,7 +221,7 @@ def prim_ListConstruct(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -189,6 +235,7 @@ def prim_ListUnpack(mapper, graph, node): ...@@ -189,6 +235,7 @@ def prim_ListUnpack(mapper, graph, node):
%x2.4 (Tensor): 输出,list的第二个元素。 %x2.4 (Tensor): 输出,list的第二个元素。
%4354 (list): 列表。 %4354 (list): 列表。
""" """
scope_name = mapper.normalize_scope_name(node)
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
layer_outputs = outputs_name.copy() layer_outputs = outputs_name.copy()
layer_inputs = {} layer_inputs = {}
...@@ -196,13 +243,13 @@ def prim_ListUnpack(mapper, graph, node): ...@@ -196,13 +243,13 @@ def prim_ListUnpack(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = layer_outputs.copy() current_outputs = layer_outputs.copy()
# 处理输入0,即%4354 # 处理输入0,即%4354
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs) "prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs) mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -223,6 +270,7 @@ def prim_Loop(mapper, graph, node): ...@@ -223,6 +270,7 @@ def prim_Loop(mapper, graph, node):
%x.3 (Tensor): 循环中修改的Tensor。 %x.3 (Tensor): 循环中修改的Tensor。
%x (Tensor): loop循环的输出,与%x.5对应。 %x (Tensor): loop循环的输出,与%x.5对应。
""" """
scope_name = mapper.normalize_scope_name(node)
node_outputs = mapper._get_outputs_name(node) node_outputs = mapper._get_outputs_name(node)
loop_inputs = {} loop_inputs = {}
block = list(node.blocks())[0] block = list(node.blocks())[0]
...@@ -242,7 +290,7 @@ def prim_Loop(mapper, graph, node): ...@@ -242,7 +290,7 @@ def prim_Loop(mapper, graph, node):
loop_input_node_name = mapper.outputs_info[ loop_input_node_name = mapper.outputs_info[
script_loop_input_unique_id] script_loop_input_unique_id]
mapper._check_input(graph, loop_input_node, loop_input_node_name, mapper._check_input(graph, loop_input_node, loop_input_node_name,
node_outputs) node_outputs, scope_name)
loop_inputs['input'] = loop_input_node_name loop_inputs['input'] = loop_input_node_name
loop_outputs.append(block_input_node_name) loop_outputs.append(block_input_node_name)
node_outputs.append(block_input_node_name) node_outputs.append(block_input_node_name)
...@@ -252,14 +300,15 @@ def prim_Loop(mapper, graph, node): ...@@ -252,14 +300,15 @@ def prim_Loop(mapper, graph, node):
loop_input_node_name = mapper.outputs_info[ loop_input_node_name = mapper.outputs_info[
script_loop_input_unique_id] script_loop_input_unique_id]
mapper._check_input(graph, loop_input_node, loop_input_node_name, mapper._check_input(graph, loop_input_node, loop_input_node_name,
node_outputs) node_outputs, scope_name)
graph.add_layer( graph.add_layer(
"prim.equal", "prim.equal",
inputs={'input': loop_input_node_name}, inputs={'input': loop_input_node_name},
outputs=[block_input_node_name]) outputs=[block_input_node_name],
scope_name=scope_name)
node_outputs.append(block_input_node_name) node_outputs.append(block_input_node_name)
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs) graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, current_layer) block_graph, graph_inputs = mapper.traverse(block, current_layer)
for i, input_name in enumerate(graph_inputs): for i, input_name in enumerate(graph_inputs):
...@@ -279,6 +328,7 @@ def prim_min(mapper, graph, node): ...@@ -279,6 +328,7 @@ def prim_min(mapper, graph, node):
%86 (list): 输入。 %86 (list): 输入。
%87 (int): 输出。 %87 (int): 输出。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -286,12 +336,12 @@ def prim_min(mapper, graph, node): ...@@ -286,12 +336,12 @@ def prim_min(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -304,6 +354,7 @@ def prim_NumToTensor(mapper, graph, node): ...@@ -304,6 +354,7 @@ def prim_NumToTensor(mapper, graph, node):
%other.2 (Tensor): 输出。 %other.2 (Tensor): 输出。
%1736 (-): 输入。 %1736 (-): 输入。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -312,25 +363,26 @@ def prim_NumToTensor(mapper, graph, node): ...@@ -312,25 +363,26 @@ def prim_NumToTensor(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
if inputs_node[0].kind() == "aten::size": inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0])
if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1:
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim_equal", inputs=layer_inputs, outputs=layer_outputs) "prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
else: else:
layer_inputs["value"] = inputs_name[0] layer_inputs["fill_value"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
input_type = list(node.inputs())[0].type() input_type = list(node.inputs())[0].type()
layer_attrs["dtype"] = input_type layer_attrs["dtype"] = input_type
layer_attrs["persistable"] = True
layer_attrs["shape"] = [1] layer_attrs["shape"] = [1]
graph.add_layer( graph.add_layer(
"fluid.layers.create_global_var", "paddle.full",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -343,6 +395,7 @@ def prim_RaiseException(mapper, graph, node): ...@@ -343,6 +395,7 @@ def prim_RaiseException(mapper, graph, node):
参数含义: 参数含义:
%76 (str): 异常信息。 %76 (str): 异常信息。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -350,13 +403,13 @@ def prim_RaiseException(mapper, graph, node): ...@@ -350,13 +403,13 @@ def prim_RaiseException(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%76 # 处理输入0,即%76
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.exception", inputs=layer_inputs, outputs=layer_outputs) "prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -369,6 +422,7 @@ def prim_requires_grad(mapper, graph, node): ...@@ -369,6 +422,7 @@ def prim_requires_grad(mapper, graph, node):
%356 (bool): 输出,当前Tensor是否计算梯度。 %356 (bool): 输出,当前Tensor是否计算梯度。
%tensor.31 (Tensor): 输入的Tensor。 %tensor.31 (Tensor): 输入的Tensor。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -376,13 +430,13 @@ def prim_requires_grad(mapper, graph, node): ...@@ -376,13 +430,13 @@ def prim_requires_grad(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs) "prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -395,6 +449,7 @@ def prim_SetAttr(mapper, graph, node): ...@@ -395,6 +449,7 @@ def prim_SetAttr(mapper, graph, node):
%260 (-): 属性名前缀。 %260 (-): 属性名前缀。
%277 (-): 需要设置的值。 %277 (-): 需要设置的值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
field_name_list = [] field_name_list = []
tmp_node = node tmp_node = node
...@@ -416,7 +471,8 @@ def prim_SetAttr(mapper, graph, node): ...@@ -416,7 +471,8 @@ def prim_SetAttr(mapper, graph, node):
graph.add_layer( graph.add_layer(
"prim.set_attr", "prim.set_attr",
inputs={"input": inputs_name[1]}, inputs={"input": inputs_name[1]},
outputs=["self." + ".".join(field_name_list).replace(".", "_")]) outputs=["self." + ".".join(field_name_list).replace(".", "_")],
scope_name=scope_name)
return [], [output_name] return [], [output_name]
...@@ -429,6 +485,7 @@ def prim_shape(mapper, graph, node): ...@@ -429,6 +485,7 @@ def prim_shape(mapper, graph, node):
%4701 (list): 输出,shape信息。 %4701 (list): 输出,shape信息。
%result.1 (Tensor): 需要获取shape的值。 %result.1 (Tensor): 需要获取shape的值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -436,13 +493,13 @@ def prim_shape(mapper, graph, node): ...@@ -436,13 +493,13 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"fluid.layers.shape", inputs=layer_inputs, outputs=layer_outputs) "paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -456,6 +513,7 @@ def prim_TupleConstruct(mapper, graph, node): ...@@ -456,6 +513,7 @@ def prim_TupleConstruct(mapper, graph, node):
%x.46 (Tensor/其他): tuple第一个元素信息。 %x.46 (Tensor/其他): tuple第一个元素信息。
%aux (Tensor/其他): tuple第二个元素信息。 %aux (Tensor/其他): tuple第二个元素信息。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -468,7 +526,7 @@ def prim_TupleConstruct(mapper, graph, node): ...@@ -468,7 +526,7 @@ def prim_TupleConstruct(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -482,6 +540,7 @@ def prim_TupleUnpack(mapper, graph, node): ...@@ -482,6 +540,7 @@ def prim_TupleUnpack(mapper, graph, node):
%aux.3 (Tensor/其他): 输出,tuple第二个元素信息。 %aux.3 (Tensor/其他): 输出,tuple第二个元素信息。
%4492 (tuple): 需要获取元素的tuple。 %4492 (tuple): 需要获取元素的tuple。
""" """
scope_name = mapper.normalize_scope_name(node)
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
layer_outputs = outputs_name layer_outputs = outputs_name
layer_inputs = {} layer_inputs = {}
...@@ -493,7 +552,7 @@ def prim_TupleUnpack(mapper, graph, node): ...@@ -493,7 +552,7 @@ def prim_TupleUnpack(mapper, graph, node):
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs) "prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -508,6 +567,7 @@ def prim_unchecked_cast(mapper, graph, node): ...@@ -508,6 +567,7 @@ def prim_unchecked_cast(mapper, graph, node):
【注意】Paddle中无此用法,所以此处翻译成赋值。 【注意】Paddle中无此用法,所以此处翻译成赋值。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name] layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
...@@ -516,12 +576,12 @@ def prim_unchecked_cast(mapper, graph, node): ...@@ -516,12 +576,12 @@ def prim_unchecked_cast(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size.63 # 处理输入0,即%size.63
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs) graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -533,9 +593,10 @@ def prim_Uninitialized(mapper, graph, node): ...@@ -533,9 +593,10 @@ def prim_Uninitialized(mapper, graph, node):
参数含义: 参数含义:
%345 (bool): 输出,为赋值的bool。 %345 (bool): 输出,为赋值的bool。
""" """
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
output = list(node.outputs())[0] output = list(node.outputs())[0]
mapper.attrs[output_name] = None mapper.attrs[output_name] = None
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=None) "prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None)
return [], [output_name] return [], [output_name]
...@@ -17,24 +17,29 @@ import numpy as np ...@@ -17,24 +17,29 @@ import numpy as np
from x2paddle.core.op_mapper import OpMapper from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import * from x2paddle.core.util import *
from x2paddle.core.program import PaddleGraph from x2paddle.core.program import PaddleGraph
from x2paddle.op_mapper.pytorch2paddle import prim from x2paddle.op_mapper.dygraph.pytorch2paddle import prim
from x2paddle.op_mapper.pytorch2paddle import aten from x2paddle.op_mapper.dygraph.pytorch2paddle import aten
class PyTorchOpMapper(OpMapper): class PyTorchOpMapper(OpMapper):
def __init__(self, decoder): def __init__(self, decoder):
super(PyTorchOpMapper, self).__init__() super(PyTorchOpMapper, self).__init__()
self.script = decoder.script self.script = decoder.script
self.input_examples = decoder.input_examples
self.paddle_params = dict() self.paddle_params = dict()
self.outputs_info = {} # key为output unique id,value为当前节点的输出名字 self.outputs_info = {} # key为output unique id,value为当前节点的输出名字
self.pytorch_params = {} # key为节点名,value为参数 self.pytorch_params = {} # key为节点名,value为参数
self.attrs = {} # key为节点名,value为属性值 self.attrs = {} # key为节点名,value为属性值
self.output_index = 0 self.output_index = 0
self.dygraph_name_id = {} # 动态图__init__输出名字中的id,key为kernel类型,value为id self.nn_name2id = {} # 动态图__init__输出名字中的id,key为kernel类型,value为id
self.split_len = {} # split的长度 self.split_len = {} # split的长度
self.scope_name_list = list()
self.scope_name2id = dict()
self.inputs_info = dict()
# 转换 # 转换
self.check_op(decoder.graph) self.check_op(decoder.graph)
self.graph, _ = self.traverse(decoder.graph) self.paddle_graph, _ = self.traverse(decoder.graph)
self.paddle_graph.set_inputs_info(self.inputs_info)
def check_op(self, script_graph): def check_op(self, script_graph):
def _update_op_list(graph): def _update_op_list(graph):
...@@ -72,17 +77,21 @@ class PyTorchOpMapper(OpMapper): ...@@ -72,17 +77,21 @@ class PyTorchOpMapper(OpMapper):
current_node_outputs.extend(outputs) current_node_outputs.extend(outputs)
# 初始化 # 初始化
graph = PaddleGraph(parent_layer, graph_type="dygraph") graph = PaddleGraph(parent_layer=parent_layer, graph_type="dygraph")
if "TopLevelTracedModule" in str(type(self.script)):
graph.set_script(self.script)
current_node_outputs = [] current_node_outputs = []
graph_inputs = [] graph_inputs = []
# 转换输入节点 # 转换输入节点
if isinstance(script_graph, torch._C.Graph): if isinstance(script_graph, torch._C.Graph):
input_ct = 0
for i, ivalue in enumerate(script_graph.inputs()): for i, ivalue in enumerate(script_graph.inputs()):
node = ivalue.node() node = ivalue.node()
if str(ivalue.type()) != "Tensor": if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]:
graph.set_name(str(ivalue.type()).split(".")[-1]) graph.set_name(str(ivalue.type()).split(".")[-1])
continue continue
inputs, outputs = self.data(graph, node, ivalue.unique()) inputs, outputs = self.data(graph, node, ivalue.unique(), input_ct)
input_ct += 1
# 转换中间节点 # 转换中间节点
for node in script_graph.nodes(): for node in script_graph.nodes():
kind = node.kind() kind = node.kind()
...@@ -120,7 +129,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -120,7 +129,7 @@ class PyTorchOpMapper(OpMapper):
graph.outputs = inputs_name graph.outputs = inputs_name
# 更新split参数 # 更新split参数
for layer in graph.layers.values(): for layer in graph.layers.values():
if layer.kernel == "fluid.layers.split" and "num_or_sections" in layer.attrs: if layer.kernel == "paddle.split" and "num_or_sections" in layer.attrs:
layer.attrs["num_or_sections"] = self.split_len[layer.outputs[ layer.attrs["num_or_sections"] = self.split_len[layer.outputs[
0]] 0]]
return graph, graph_inputs return graph, graph_inputs
...@@ -151,6 +160,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -151,6 +160,7 @@ class PyTorchOpMapper(OpMapper):
node, node,
output_name, output_name,
node_outputs, node_outputs,
scope_name,
add_dim=False): add_dim=False):
if node.kind() == "prim::GetAttr": if node.kind() == "prim::GetAttr":
param = self.pytorch_params[output_name] param = self.pytorch_params[output_name]
...@@ -159,10 +169,13 @@ class PyTorchOpMapper(OpMapper): ...@@ -159,10 +169,13 @@ class PyTorchOpMapper(OpMapper):
param = param[np.newaxis, :] param = param[np.newaxis, :]
self.paddle_params[output_name] = param self.paddle_params[output_name] = param
graph.add_layer( graph.add_layer(
"fluid.dygraph.base.to_variable", "self.create_parameter",
inputs={}, inputs={},
outputs=[output_name], outputs=[output_name],
value="params[{}]".format(string(output_name))) scope_name=scope_name,
dtype=string(str(param.dtype)),
shape = param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
else: else:
if isinstance(param, dict) and "Tensor" in param and \ if isinstance(param, dict) and "Tensor" in param and \
"parent_layer_id" in param: "parent_layer_id" in param:
...@@ -183,11 +196,13 @@ class PyTorchOpMapper(OpMapper): ...@@ -183,11 +196,13 @@ class PyTorchOpMapper(OpMapper):
param = param[np.newaxis, :] param = param[np.newaxis, :]
self.paddle_params[output_name] = param self.paddle_params[output_name] = param
graph.add_layer( graph.add_layer(
"fluid.dygraph.base.to_variable", "self.create_parameter",
inputs={}, inputs={},
outputs=[output_name], outputs=[output_name],
value="params[{}]".format( scope_name=scope_name,
string(output_name))) dtype=string(str(param.dtype)),
shape = param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
node_outputs.append(output_name) node_outputs.append(output_name)
return return
# 若if-else外,则可直接引用if-else中的赋值结果 # 若if-else外,则可直接引用if-else中的赋值结果
...@@ -195,16 +210,30 @@ class PyTorchOpMapper(OpMapper): ...@@ -195,16 +210,30 @@ class PyTorchOpMapper(OpMapper):
"prim.constant", "prim.constant",
inputs={}, inputs={},
outputs=[output_name], outputs=[output_name],
scope_name=scope_name,
value=param["Tensor"]) value=param["Tensor"])
else: else:
graph.add_layer( graph.add_layer(
"prim.constant", "prim.constant",
inputs={}, inputs={},
outputs=[output_name], outputs=[output_name],
scope_name=scope_name,
value=string(param) value=string(param)
if isinstance(param, str) else param) if isinstance(param, str) else param)
node_outputs.append(output_name) node_outputs.append(output_name)
elif node.kind() == "prim::Constant" and output_name in self.pytorch_params:
param = self.pytorch_params[output_name]
self.paddle_params[output_name] = param
graph.add_layer(
"self.create_parameter",
inputs={},
outputs=[output_name],
scope_name=scope_name,
dtype=string(str(param.dtype)),
shape = param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
def _get_inputs_name(self, node): def _get_inputs_name(self, node):
inputs_name = [] inputs_name = []
inputs_node = [] inputs_node = []
...@@ -215,8 +244,10 @@ class PyTorchOpMapper(OpMapper): ...@@ -215,8 +244,10 @@ class PyTorchOpMapper(OpMapper):
inputs_node.append(script_input_node) inputs_node.append(script_input_node)
inputs_name.append(input_name) inputs_name.append(input_name)
return inputs_name, inputs_node return inputs_name, inputs_node
def data(self, graph, node, uid): def data(self, graph, node, uid, input_ct):
scope_name = self.normalize_scope_name(node)
for output_ivalue in node.outputs(): for output_ivalue in node.outputs():
script_unique_id = output_ivalue.unique() script_unique_id = output_ivalue.unique()
if script_unique_id in self.outputs_info or script_unique_id != uid: if script_unique_id in self.outputs_info or script_unique_id != uid:
...@@ -226,13 +257,18 @@ class PyTorchOpMapper(OpMapper): ...@@ -226,13 +257,18 @@ class PyTorchOpMapper(OpMapper):
self.output_index += 1 self.output_index += 1
output_name = self.outputs_info[uid] output_name = self.outputs_info[uid]
graph.add_layer( graph.add_layer(
"fluid.dygraph.base.to_variable", "paddle.to_tensor",
inputs={}, inputs={},
outputs=[node_name], outputs=[node_name],
value=output_name) scope_name=scope_name,
data=output_name)
if self.input_examples is not None:
input_np = self.input_examples[input_ct].detach().numpy()
self.inputs_info[output_name] = [list(input_np.shape), str(input_np.dtype)]
return [], [output_name] return [], [output_name]
def equal(self, graph, node, uid=None, parent_layer=None, index=None): def equal(self, graph, node, uid=None, parent_layer=None, index=None):
scope_name = self.normalize_scope_name(node)
if parent_layer is not None and index is not None: if parent_layer is not None and index is not None:
# block的输出 # block的输出
input_node_name = self.outputs_info[uid] input_node_name = self.outputs_info[uid]
...@@ -241,9 +277,61 @@ class PyTorchOpMapper(OpMapper): ...@@ -241,9 +277,61 @@ class PyTorchOpMapper(OpMapper):
control_output_id = index - 1 control_output_id = index - 1
output_node_name = parent_layer.outputs[control_output_id] output_node_name = parent_layer.outputs[control_output_id]
current_outputs = [output_node_name] current_outputs = [output_node_name]
self._check_input(graph, node, input_node_name, current_outputs) self._check_input(graph, node, input_node_name, current_outputs, scope_name)
graph.add_layer( graph.add_layer(
"prim.equal", "prim.equal",
inputs={'input': input_node_name}, inputs={'input': input_node_name},
outputs=[output_node_name]) outputs=[output_node_name],
scope_name=scope_name)
return [input_node_name], current_outputs return [input_node_name], current_outputs
def normalize_scope_name(self, node):
""" 对scope的名字进行标准化。
"""
scope_name = node.scopeName()
if scope_name == "":
return scope_name
scope_name_part = scope_name.split("/")
for index in range(len(scope_name_part) - 1):
if scope_name_part[index] in scope_name_part[index + 1]:
continue
last_name_segments = scope_name_part[index].split(".")
name_segments = scope_name_part[index + 1].split(".")
for j, name in enumerate(last_name_segments):
name_segments[j] = name
scope_name_part[index + 1] = ".".join(name_segments)
last_name = scope_name_part[-1]
name_segments = last_name.split(".")
for i, ns in enumerate(name_segments):
if i not in self.scope_name2id:
self.scope_name2id[i] = dict()
if ns not in self.scope_name2id[i]:
self.scope_name2id[i][ns] = 0
real_scope_name = "/".join(name_segments[1:])
real_father_scope_name = "/".join(name_segments[1:-1])
for i, ns in enumerate(name_segments):
if i == 0:
continue
if self.scope_name2id[i][ns] != 0:
name_segments[i] = name_segments[i] + \
"__{}".format(self.scope_name2id[i][ns])
prefix_scope_name = "/".join(name_segments[1 :i + 1])
is_found = False
for j in range(len(self.scope_name_list)):
last_scope_name = self.scope_name_list[-1-j]
if last_scope_name.startswith(prefix_scope_name + "/") \
or last_scope_name == prefix_scope_name:
if j != 0: # and i != len(name_segments) - 1:
is_found = True
origin_name_segment_i = name_segments[i].split("__")[0]
self.scope_name2id[i][origin_name_segment_i] += 1
name_segments[i] = origin_name_segment_i + \
"__" + str(self.scope_name2id[i][origin_name_segment_i])
break
if is_found:
break
real_scope_name = "/".join(name_segments[1:])
self.scope_name_list.append(real_scope_name)
return real_scope_name
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree
\ No newline at end of file
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PamareterNode(object):
def __init__(self, old_name=None, new_name=None):
self.old_name = old_name
self.new_name = new_name
self.childs = list()
def add_child(self, child):
self.childs.append(child)
def has_child(self):
if len(self.childs) == 0:
return False
else:
return True
class PamareterTree(object):
def __init__(self):
self.nodes = list()
self.old2new = dict()
def add_node(self, node):
self.nodes.append(node)
def traverse(self):
tmp = list()
def recurs(node, prefix_name):
for child in node.childs:
child_prefix_name = prefix_name + "." + child.new_name
if child.has_child():
recurs(child, child_prefix_name)
else:
self.old2new[child.old_name] = child_prefix_name[1:]
recurs(self.nodes[-1], "")
def get_node(self, old_name):
for node in self.nodes:
if node.old_name == old_name:
return node
\ No newline at end of file
此差异已折叠。
...@@ -32,3 +32,5 @@ from .reshape_fuser import DygraphReshapeFuser ...@@ -32,3 +32,5 @@ from .reshape_fuser import DygraphReshapeFuser
from .reshape_fuse_pass import DygraphReshapeFusePass from .reshape_fuse_pass import DygraphReshapeFusePass
from .tf_batchnorm_fuser import DygraphTFBatchNormFuser from .tf_batchnorm_fuser import DygraphTFBatchNormFuser
from .tf_batchnorm_fuse_pass import DygraphTFBatchNormFusePass from .tf_batchnorm_fuse_pass import DygraphTFBatchNormFusePass
from .trace_fc_fuser import TraceFcFuser
from .trace_fc_fuse_pass import TraceFcFusePass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.optimizer.pass_ import Pass
from x2paddle.optimizer.fusion.dygraph import TraceFcFuser
from x2paddle.optimizer.pass_manager import pass_register
@pass_register
class TraceFcFusePass(Pass):
name = "trace_fc_fuse_pass"
def __init__(self):
Pass.__init__(self)
def apply(self, graph):
fuser = TraceFcFuser()
fuser.operate(graph, match_kind="topo")
# 用于注册
trace_fc_fuse_pass = TraceFcFusePass()
\ No newline at end of file
此差异已折叠。
...@@ -18,14 +18,20 @@ from x2paddle.optimizer.fusion.static import * ...@@ -18,14 +18,20 @@ from x2paddle.optimizer.fusion.static import *
from x2paddle.optimizer.elimination.dygraph import * from x2paddle.optimizer.elimination.dygraph import *
class GraphOptimizer(object): class GraphOptimizer(object):
def __init__(self, source_frame, paddle_type="dygraph"): def __init__(self, source_frame, paddle_type="dygraph", jit_type="trace"):
if source_frame == "pytorch": if source_frame == "pytorch":
self.passes = [ if jit_type == "trace":
"dygraph_constant_fuse_pass", "dygraph_batchnorm2d_fuse_pass", self.passes = ["trace_fc_fuse_pass"]
"dygraph_interpolate_bilinear_fuse_pass", "dygraph_fc_fuse_pass", else:
"dygraph_adaptive_pool2d_fuse_pass", "dygraph_reshape_fuse_pass", self.passes = [
"dygraph_dropout_fuse_pass" "dygraph_constant_fuse_pass",
] "dygraph_batchnorm2d_fuse_pass",
"dygraph_interpolate_bilinear_fuse_pass",
"dygraph_fc_fuse_pass",
"dygraph_adaptive_pool2d_fuse_pass",
"dygraph_reshape_fuse_pass",
"dygraph_dropout_fuse_pass"
]
elif source_frame == "caffe": elif source_frame == "caffe":
if paddle_type == "dygraph": if paddle_type == "dygraph":
self.passes = ["dygraph_bn_scale_fuse_pass"] self.passes = ["dygraph_bn_scale_fuse_pass"]
...@@ -38,8 +44,7 @@ class GraphOptimizer(object): ...@@ -38,8 +44,7 @@ class GraphOptimizer(object):
"transpose_eliminate_pass" "transpose_eliminate_pass"
] ]
else: else:
# TODO self.passes = []
pass
def optimize(self, graph): def optimize(self, graph):
for pass_name in self.passes: for pass_name in self.passes:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册