Commit 1368050f authored by SunAhong1993

fix the bug

Parent 458b551d
@@ -447,6 +447,9 @@ class PaddleGraph(object):
         if self.source_type == "caffe":
             custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
                             "import caffe_custom_layer as x2paddle_nn"
+        elif self.source_type == "pytorch":
+            custom_import = "from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
+                            "import pytorch_custom_layer as x2paddle_nn"
         else:
             custom_import = ""
         self.head = gen_codes(
@@ -455,6 +458,7 @@ class PaddleGraph(object):
                 "from paddle.fluid.param_attr import ParamAttr",
                 "import paddle",
                 "import paddle.fluid as fluid",
+                "import math",
                 custom_import,
                 "",
                 "class {}(paddle.nn.Layer):".format(self.name),
@@ -590,7 +594,10 @@ class PaddleGraph(object):
                 if isinstance(v, list):
                     line += "{}=[{}], ".format(k, ", ".join(v))
                 else:
-                    line += "{}={}, ".format(k, v)
+                    if k == "args":
+                        line += v
+                    else:
+                        line += "{}={}, ".format(k, v)
             for k, v in layer.attrs.items():
                 line += "{}={}, ".format(k, v)
             line = line.strip(", ")
......
@@ -21,14 +21,14 @@ import numpy as np
 class Decoder(object):
     def _optimize_graph(self, graph):
         torch._C._jit_pass_constant_propagation(graph)
-        torch._C._jit_pass_dce(graph)
-        torch._C._jit_pass_lint(graph)
-        torch._C._jit_pass_peephole(graph)
-        torch._C._jit_pass_lint(graph)
-        torch._C._jit_pass_dce(graph)
-        torch._C._jit_pass_lint(graph)
-        torch._C._jit_pass_canonicalize(graph)
-        torch._C._jit_pass_lint(graph)
+        # torch._C._jit_pass_dce(graph)
+        # torch._C._jit_pass_lint(graph)
+        # torch._C._jit_pass_peephole(graph)
+        # torch._C._jit_pass_lint(graph)
+        # torch._C._jit_pass_dce(graph)
+        # torch._C._jit_pass_lint(graph)
+        # torch._C._jit_pass_canonicalize(graph)
+        # torch._C._jit_pass_lint(graph)
         torch._C._jit_pass_constant_propagation(graph)
         return graph
......
@@ -752,6 +752,56 @@ def aten_chunk(mapper, graph, node):
     return current_inputs, current_outputs
+
+
+def aten_clamp(mapper, graph, node):
+    """ Construct a PaddleLayer that clips elements to a range.
+    TorchScript example:
+        %56 : Tensor = aten::clamp(%input.1, %46, %48, %49)
+    Parameter meaning:
+        %56 (Tensor): output, the clipped result.
+        %input.1 (Tensor): input, the Tensor to be clipped.
+        %46 (float/Tensor): the minimum value.
+        %48 (float/Tensor): the maximum value.
+    """
+    scope_name = mapper.normalize_scope_name(node)
+    output_name = mapper._get_outputs_name(node)[0]
+    layer_outputs = [output_name]
+    layer_inputs = {}
+    layer_attrs = {}
+    inputs_name, inputs_node = mapper._get_inputs_name(node)
+    # get the list of outputs of the current node
+    current_outputs = [output_name]
+    # handle input 0, i.e. %input.1
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    layer_inputs["x"] = inputs_name[0]
+    # get the lists of inputs and outputs of the current node
+    current_inputs = list(layer_inputs.values())
+    # handle input 1, i.e. %46, the min value
+    if inputs_name[1] in mapper.attrs:
+        layer_attrs["min"] = mapper.attrs[inputs_name[1]]
+    else:
+        mapper._check_input(graph, inputs_node[1], inputs_name[1],
+                            current_outputs, scope_name)
+        layer_inputs["min"] = inputs_name[1]
+        current_inputs.append(inputs_name[1])
+    # handle input 2, i.e. %48, the max value
+    if inputs_name[2] in mapper.attrs:
+        layer_attrs["max"] = mapper.attrs[inputs_name[2]]
+    else:
+        mapper._check_input(graph, inputs_node[2], inputs_name[2],
+                            current_outputs, scope_name)
+        layer_inputs["max"] = inputs_name[2]
+        current_inputs.append(inputs_name[2])
+    graph.add_layer(
+        "paddle.clip",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
+    return current_inputs, current_outputs
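For reference, the "paddle.clip" kernel registered above maps directly onto paddle.clip; a minimal sketch of the generated call, assuming both bounds resolved to constant attributes (the values here are hypothetical):

    import paddle

    x = paddle.to_tensor([-2.0, 0.5, 7.0])
    out = paddle.clip(x, min=0.0, max=6.0)  # -> [0.0, 0.5, 6.0]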
 def aten___contains__(mapper, graph, node):
     """ Construct a PaddleLayer for the "in" operator.
@@ -810,7 +860,7 @@ def aten_constant_pad_nd(mapper, graph, node):
     # handle input 1, i.e. %4876
     layer_attrs["padding"] = mapper.attrs[inputs_name[1]]
     # handle input 2, i.e. %42
-    layer_attrs["pad_value"] = mapper.attrs[inputs_name[2]]
+    layer_attrs["value"] = mapper.attrs[inputs_name[2]]
     graph.add_layer(
         "prim.shape",
@@ -856,7 +906,7 @@ def aten_constant_pad_nd(mapper, graph, node):
     block.add_layer(
         kernel,
         inputs={"input": inputs_name[0] + "_var"},
-        outputs=layer_outputs,
+        outputs=copy.deepcopy(layer_outputs),
         scope_name=scope_name,
         **layer_attrs)
     block.add_layer(
@@ -1517,76 +1567,88 @@ def aten_expand(mapper, graph, node):
     output_name = mapper._get_outputs_name(node)[0]
     layer_outputs = [output_name]
     layer_inputs = {}
+    layer_attrs = {}
     inputs_name, inputs_node = mapper._get_inputs_name(node)
     # get the list of outputs of the current node
     current_outputs = [output_name]
     # handle input 0, i.e. %1875
     mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
     layer_inputs["x"] = inputs_name[0]
-    # handle input 1, i.e. %1888
-    mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name)
-    graph.add_layer(
-        "prim.type",
-        inputs={"input": inputs_name[0]},
-        outputs=[inputs_name[0] + "_type"],
-        scope_name=scope_name)
-    graph.add_layer(
-        "prim.str",
-        inputs={"input": inputs_name[0] + "_type"},
-        outputs=[inputs_name[0] + "_type"],
-        scope_name=scope_name)
-    graph.add_layer(
-        "prim.eq",
-        inputs={"x": inputs_name[0] + "_type"},
-        outputs=[inputs_name[0] + "_cond"],
-        scope_name=scope_name,
-        y=string("VarType.BOOL"))
-    graph.add_layer(
-        "prim.if", {'input': inputs_name[0] + "_cond"},
-        outputs=[inputs_name[0] + "_if1", inputs_name[1] + "_var"],
-        scope_name=scope_name)
-    if_layer = graph.layers[list(graph.layers.keys())[-1]]
-    block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
-    block.add_layer(
-        "paddle.cast",
-        inputs={"x": inputs_name[0]},
-        outputs=[inputs_name[0]],
-        scope_name=scope_name,
-        dtype=string("int64"))
-    block.add_layer(
-        "self.create_parameter",
-        inputs={"shape": inputs_name[1]},
-        outputs=[inputs_name[1] + "_var"],
-        scope_name=scope_name,
-        dtype=string("int64"),
-        default_initializer="paddle.nn.initializer.Constant(value=0.0)")
-    if_layer.add_block(block)
-    block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
-    block.add_layer(
-        "prim.type",
-        inputs={"input": inputs_name[0]},
-        outputs=[inputs_name[0] + "_type"],
-        scope_name=scope_name)
-    block.add_layer(
-        "self.create_parameter",
-        inputs={"shape": inputs_name[1]},
-        outputs=[inputs_name[1] + "_var"],
-        scope_name=scope_name,
-        dtype=inputs_name[0] + "_type",
-        default_initializer="paddle.nn.initializer.Constant(value=0.0)")
-    if_layer.add_block(block)
-    if_layer.inputs["input-0"] = inputs_name[0]
-    if_layer.inputs["input-1"] = inputs_name[1]
-    layer_inputs["y"] = inputs_name[1] + "_var"
-    current_outputs.append(inputs_name[1] + "_var")
-    # get the list of inputs of the current node
-    current_inputs = list(layer_inputs.values())
-    current_inputs.append(inputs_name[1])
-    graph.add_layer(
-        "paddle.expand_as", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    current_inputs = list(layer_inputs.values())
+    # handle input 1, i.e. %51
+    if inputs_name[1] in mapper.attrs:
+        layer_attrs["shape"] = mapper.attrs[inputs_name[1]]
+    else:
+        mapper._check_input(graph, inputs_node[1], inputs_name[1],
+                            current_outputs, scope_name)
+        layer_inputs["shape"] = inputs_name[1]
+        current_inputs.append(inputs_name[1])
+    graph.add_layer(
+        "paddle.expand",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
+    # graph.add_layer(
+    #     "prim.type",
+    #     inputs={"input": inputs_name[0]},
+    #     outputs=[inputs_name[0] + "_type"],
+    #     scope_name=scope_name)
+    # graph.add_layer(
+    #     "prim.str",
+    #     inputs={"input": inputs_name[0] + "_type"},
+    #     outputs=[inputs_name[0] + "_type"],
+    #     scope_name=scope_name)
+    # graph.add_layer(
+    #     "prim.eq",
+    #     inputs={"x": inputs_name[0] + "_type"},
+    #     outputs=[inputs_name[0] + "_cond"],
+    #     scope_name=scope_name,
+    #     y=string("VarType.BOOL"))
+    # graph.add_layer(
+    #     "prim.if", {'input': inputs_name[0] + "_cond"},
+    #     outputs=[inputs_name[0] + "_if1", inputs_name[1] + "_var"],
+    #     scope_name=scope_name)
+    # if_layer = graph.layers[list(graph.layers.keys())[-1]]
+    # block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
+    # block.add_layer(
+    #     "paddle.cast",
+    #     inputs={"x": inputs_name[0]},
+    #     outputs=[inputs_name[0]],
+    #     scope_name=scope_name,
+    #     dtype=string("int64"))
+    # block.add_layer(
+    #     "paddle.zeros",
+    #     inputs={"shape": inputs_name[1]},
+    #     outputs=[inputs_name[1] + "_var"],
+    #     scope_name=scope_name,
+    #     dtype=string("int64"))
+    # if_layer.add_block(block)
+    # block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
+    # block.add_layer(
+    #     "prim.type",
+    #     inputs={"input": inputs_name[0]},
+    #     outputs=[inputs_name[0] + "_type"],
+    #     scope_name=scope_name)
+    # block.add_layer(
+    #     "paddle.zeros",
+    #     inputs={"shape": inputs_name[1]},
+    #     outputs=[inputs_name[1] + "_var"],
+    #     scope_name=scope_name,
+    #     dtype=inputs_name[0] + "_type")
+    # if_layer.add_block(block)
+    # if_layer.inputs["input-0"] = inputs_name[0]
+    # if_layer.inputs["input-1"] = inputs_name[1]
+    # layer_inputs["y"] = inputs_name[1] + "_var"
+    # current_outputs.append(inputs_name[1] + "_var")
+    # # get the list of inputs of the current node
+    # current_inputs = list(layer_inputs.values())
+    # current_inputs.append(inputs_name[1])
+    # graph.add_layer(
+    #     "paddle.expand_as", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
     return current_inputs, current_outputs
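The rewritten mapper replaces the old prim.if / paddle.expand_as machinery with a direct paddle.expand call whose target shape comes either from a constant attribute or from a graph input. A minimal sketch of the resulting behavior (hypothetical values):

    import paddle

    x = paddle.to_tensor([[1.0], [2.0]])  # shape [2, 1]
    out = paddle.expand(x, shape=[2, 3])  # broadcast along the last axis to [2, 3]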
@@ -1841,11 +1903,39 @@ def aten_floor(mapper, graph, node):
     current_outputs = [output_name]
     # handle input 0, i.e. %scale.18
     mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
-    layer_inputs["input"] = inputs_name[0]
+    layer_inputs["x"] = inputs_name[0]
     # get the list of inputs of the current node
     current_inputs = list(layer_inputs.values())
-    graph.add_layer("prim.floor", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    graph.add_layer(
+        "prim.type",
+        {'input': inputs_name[0]},
+        outputs=[inputs_name[0] + "_type"],
+        scope_name=scope_name)
+    graph.add_layer(
+        "prim.str",
+        {'input': inputs_name[0] + "_type"},
+        outputs=[inputs_name[0] + "_type"],
+        scope_name=scope_name)
+    graph.add_layer(
+        "prim.startswith",
+        {'input': inputs_name[0] + "_type"},
+        outputs=[inputs_name[0] + "_cond"],
+        scope_name=scope_name,
+        start_str=string("VarType"))
+    graph.add_layer(
+        "prim.if",
+        {'input': inputs_name[0] + "_cond"},
+        outputs=[inputs_name[0] + "_if"],
+        scope_name=scope_name)
+    if_layer = graph.layers[list(graph.layers.keys())[-1]]
+    block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
+    block.add_layer("paddle.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name)
+    if_layer.add_block(block)
+    block = PaddleGraph(parent_layer=if_layer, graph_type="dygraph")
+    block.add_layer("prim.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name)
+    if_layer.add_block(block)
+    if_layer.inputs["input-0"] = inputs_name[0]
+    if_layer.outputs.append(output_name)
     return current_inputs, current_outputs
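The floor mapping now branches on the runtime type: paddle.floor when the input is a tensor (its stringified type starts with "VarType"), prim.floor (which emits math.floor, see prim_floor below) otherwise. A sketch of the equivalent Python logic, with floor_dispatch and x_type as hypothetical names:

    import math
    import paddle

    def floor_dispatch(x, x_type):
        # x_type plays the role of the prim.type/prim.str result, e.g. "VarType.FP32"
        if x_type.startswith("VarType"):
            return paddle.floor(x)  # tensor input
        return math.floor(x)        # plain Python number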
@@ -1957,6 +2047,46 @@ def aten_full_like(mapper, graph, node):
     return current_inputs, current_outputs
+
+
+def aten_gather(mapper, graph, node):
+    """ Construct a PaddleLayer for the gather operation.
+    TorchScript example:
+        %result.3 : Tensor = aten::gather(%input.5, %18, %19, %20, %21)
+    Parameter meaning:
+        %result.3 (Tensor): output, the gathered result.
+        %input.5 (Tensor): the Tensor to gather from.
+        %18 (int): the dimension to gather along.
+        %19 (Tensor): the gather indices.
+    """
+    scope_name = mapper.normalize_scope_name(node)
+    op_name = name_generator("gather", mapper.nn_name2id)
+    output_name = mapper._get_outputs_name(node)[0]
+    layer_outputs = [op_name, output_name]
+    layer_inputs = {}
+    layer_attrs = {}
+    inputs_name, inputs_node = mapper._get_inputs_name(node)
+    # get the list of outputs of the current node
+    current_outputs = [output_name]
+    # handle input 0, i.e. %input.5
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    layer_inputs["x"] = inputs_name[0]
+    # handle input 1, i.e. %18
+    layer_attrs["dim"] = mapper.attrs[inputs_name[1]]
+    # handle input 2, i.e. %19
+    mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name)
+    layer_inputs["index"] = inputs_name[2]
+    # get the list of inputs of the current node
+    current_inputs = list(layer_inputs.values())
+    graph.add_layer(
+        "custom_layer:Gather",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        scope_name=scope_name,
+        **layer_attrs)
+    return current_inputs, current_outputs
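Because the kernel is registered as "custom_layer:Gather", the code generator (see the gen_layer_code changes below) instantiates it from the x2paddle_nn namespace. Roughly, the generated module would contain lines like the following; the names are hypothetical:

    # in __init__:
    self.gather0 = x2paddle_nn.Gather(dim=1)
    # in forward:
    result = self.gather0(x=x_input, index=idx)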
 def aten_gelu(mapper, graph, node):
     """ Construct a PaddleLayer for the GeLU activation.
@@ -2855,6 +2985,33 @@ def aten_mean(mapper, graph, node):
     return current_inputs, current_outputs
+
+
+def aten_meshgrid(mapper, graph, node):
+    """ Construct a PaddleLayer that broadcasts every input tensor into a grid (meshgrid).
+    TorchScript example:
+        %out.39 : int = aten::meshgrid(%input.1)
+    Parameter meaning:
+        %out.39 (Tensor): output, the broadcast result.
+        %input.1 (Tensor): input.
+    """
+    scope_name = mapper.normalize_scope_name(node)
+    output_name = mapper._get_outputs_name(node)[0]
+    layer_outputs = [output_name]
+    layer_inputs = {}
+    inputs_name, inputs_node = mapper._get_inputs_name(node)
+    # get the list of outputs of the current node
+    current_outputs = [output_name]
+    # handle input 0, i.e. %input.1
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
+    layer_inputs["args"] = inputs_name[0]
+    # get the list of inputs of the current node
+    current_inputs = list(layer_inputs.values())
+    current_outputs = layer_outputs
+    graph.add_layer("paddle.meshgrid", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
+    return current_inputs, current_outputs
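The input list is bound to the special "args" key, so the k == "args" branches in the code generators splice it into the call positionally, yielding paddle.meshgrid(x_list) rather than the invalid paddle.meshgrid(args=x_list). A minimal sketch with hypothetical tensors, assuming the list-argument form of paddle.meshgrid:

    import paddle

    x = paddle.to_tensor([1, 2, 3])
    y = paddle.to_tensor([4, 5])
    grid_x, grid_y = paddle.meshgrid([x, y])  # both outputs have shape [3, 2]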
def aten_mul(mapper, graph, node): def aten_mul(mapper, graph, node):
""" 构造数值相乘的PaddleLayer。 """ 构造数值相乘的PaddleLayer。
......
@@ -180,7 +180,7 @@ def prim_float(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
 def prim_floor(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
     line = "{} = math.floor({})".format(layer.outputs[0],
-                                        get_value(layer, "input", different_attrs))
+                                        get_value(layer, "x", different_attrs))
     forward_func.extend(gen_codes([line], indent=indent))
@@ -404,6 +404,13 @@ def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
                           get_value(layer, "end", different_attrs),
                           get_value(layer, "step", different_attrs))
     forward_func.extend(gen_codes([line], indent=indent))
+
+
+def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
+    line = "{} = {}.startswith({})".format(layer.outputs[0],
+                                           get_value(layer, "input", different_attrs),
+                                           get_value(layer, "start_str", different_attrs))
+    forward_func.extend(gen_codes([line], indent=indent))
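prim_startswith emits a plain str.startswith check; for a layer whose input is x_type and whose start_str is "VarType", the generated forward line would read roughly (names hypothetical):

    x_cond = x_type.startswith("VarType")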
 def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
@@ -451,3 +458,4 @@ def prim_warnings(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
         get_value(layer, "input", different_attrs), layer.attrs["stacklevel"])
     lines.append(line)
     forward_func.extend(gen_codes(lines, indent=indent))
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .gather import Gather
\ No newline at end of file
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+from itertools import product
+import numpy as np
+
+
+class Gather(object):
+    def __init__(self, dim):
+        self.dim = dim
+
+    def __call__(self, x, index):
+        out_list = list()
+        dims = list()
+        index_shape = index.shape
+        x_type = x.numpy().dtype
+        # enumerate every coordinate of `index`
+        for s in index_shape:
+            dims.append(list(range(s)))
+        for id in product(*dims):
+            id = list(id)
+            id_tensor = paddle.to_tensor(np.array(id).astype('int32'))
+            # look up the index value stored at this coordinate ...
+            dim_id = paddle.gather_nd(index, id_tensor).numpy()
+            # ... and substitute it along the gather dimension
+            id[self.dim] = dim_id
+            id_tensor = paddle.to_tensor(np.array(id).astype('int32'))
+            data = paddle.gather_nd(x, id_tensor).numpy()
+            out_list.append(data)
+        # the output takes the shape of `index`
+        out = paddle.to_tensor(np.array(out_list).astype(x_type))
+        out = paddle.reshape(out, index_shape)
+        return out
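A usage sketch for the custom layer, assuming dygraph mode and torch.gather semantics (the output takes the shape of index); the values here are hypothetical:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.array([[1, 2], [3, 4]]).astype("float32"))
    index = paddle.to_tensor(np.array([[0, 0], [1, 0]]).astype("int64"))
    out = Gather(dim=1)(x, index)  # [[1., 1.], [4., 3.]]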
@@ -201,7 +201,6 @@ class HierarchicalTree(Tree):
         code_str = gen_layer_code(self.pd_graph, sub_layers, module_name,
                                   different_attrs=diff_attrs_column)
-        # print(code_str)
         self.codes.append(code_str)
         for sub_layers in sub_layers_list:
             inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
@@ -371,7 +370,12 @@ class HierarchicalTree(Tree):
         self.update_parameters()
         import_list = ["import paddle",
                        "import paddle.fluid as fluid",
-                       "",]
+                       "from paddle.fluid.initializer import Constant",
+                       "from paddle.fluid.param_attr import ParamAttr",
+                       "import math",
+                       "from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
+                       "import pytorch_custom_layer as x2paddle_nn",
+                       "\n",]
         import_str = "\n".join(import_list)
         if not osp.exists(save_dir):
             os.makedirs(save_dir)
......
@@ -29,9 +29,9 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
                   "paddle.nn.Tanh": "tanh",
                   "paddle.nn.AvgPool2D": "pool",
                   "paddle.nn.MaxPool2D": "pool",
-                  "paddle.nn.Pad1d": "pad",
-                  "paddle.nn.Pad2d": "pad",
-                  "paddle.nn.Pad3d": "pad",
+                  "paddle.nn.Pad1D": "pad",
+                  "paddle.nn.Pad2D": "pad",
+                  "paddle.nn.Pad3D": "pad",
                   "paddle.nn.Dropout": "dropout",
                   "paddle.nn.GELU": "gelu",
                   "paddle.nn.Hardtanh": "tanh",
@@ -175,9 +175,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
         if layer.kernel.startswith("paddle.nn") and index == 0:
             continue
         if not output_name.startswith("x") or output_name in outputs \
-                or layer.kernel == "prim.assert" or \
-                layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                or layer.kernel == "prim.assert":
             continue
+        elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+            if index != 0:
+                outputs.append(output_name)
         elif output_name not in outputs:
             outputs.append(output_name)
             continue
@@ -187,15 +189,22 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
         if layer.kernel.startswith("paddle.nn") and index == 0 and "functional" not in layer.kernel:
             continue
         if not output_name.startswith("x") or output_name in outputs \
-                or layer.kernel == "prim.assert" or \
-                layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                or layer.kernel == "prim.assert":
             continue
+        elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+            if index != 0:
+                outputs.append(output_name)
         else:
             outputs.append(output_name)
     no_output_count = 0
     for i, (layer_id, layer) in enumerate(sub_layers.items()):
-        if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel):
-            line = "self.{} = {}(".format(layer.outputs[0], layer.kernel)
+        if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
+                layer.kernel.startswith("custom_layer"):
+            line = "self.{}".format(layer.outputs[0])
+            if layer.kernel.startswith("custom_layer"):
+                line += " = x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
+            else:
+                line += " = {}(".format(layer.kernel)
             for k, v in layer.attrs.items():
                 key_name = "{}_{}".format(layer.outputs[0], k)
                 if key_name in different_attrs:
@@ -289,7 +298,10 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
             else:
                 if v not in cur_outputs and v not in inputs:
                     inputs.append(v)
-                line += "{}={}, ".format(k, v)
+                if k == "args":
+                    line += v
+                else:
+                    line += "{}={}, ".format(k, v)
         for k, v in layer.attrs.items():
             key_name = "{}_{}".format(layer.outputs[0], k)
             if key_name in different_attrs:
......
@@ -50,21 +50,25 @@ def get_inputs_outputs(pd_graph, layers):
     for layer_id, layer in layers.items():
         # get the output node names
         if layer_id not in pd_graph.edges_out:
-            for output_name in layer.outputs:
+            for index, output_name in enumerate(layer.outputs):
                 if not output_name.startswith("x") or output_name in outputs \
-                        or layer.kernel == "prim.assert" or \
-                        layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                        or layer.kernel == "prim.assert":
                     continue
+                elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                    if index != 0:
+                        outputs.append(output_name)
                 elif output_name not in outputs:
                     outputs.append(output_name)
         else:
             for out_layer_id in pd_graph.edges_out[layer_id]:
                 if out_layer_id not in layer_ids:
-                    for output_name in layer.outputs:
+                    for index, output_name in enumerate(layer.outputs):
                         if not output_name.startswith("x") or output_name in outputs \
-                                or layer.kernel == "prim.assert" or \
-                                layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                                or layer.kernel == "prim.assert":
                             continue
+                        elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
+                            if index != 0:
+                                outputs.append(output_name)
                         else:
                             outputs.append(output_name)
     # get the input node names
......
@@ -21,7 +21,8 @@ class GraphOptimizer(object):
     def __init__(self, source_frame, paddle_type="dygraph", jit_type="trace"):
         if source_frame == "pytorch":
             if jit_type == "trace":
-                self.passes = ["trace_fc_fuse_pass"]
+                self.passes = ["dygraph_constant_fuse_pass",
+                               "trace_fc_fuse_pass"]
             else:
                 self.passes = [
                     "dygraph_constant_fuse_pass",
......