提交 e748c25e 编写于 作者: S SunAhong1993

fix

上级 bf803137
...@@ -89,7 +89,6 @@ class CaffeOpMapper(OpMapper): ...@@ -89,7 +89,6 @@ class CaffeOpMapper(OpMapper):
input_shape.append(last_node.output_shape[idx]) input_shape.append(last_node.output_shape[idx])
node.input_shape = input_shape node.input_shape = input_shape
func_name = 'shape_' + node.layer_type.lower() func_name = 'shape_' + node.layer_type.lower()
if is_fluid_op: if is_fluid_op:
node.output_shape = getattr(caffe_shape, func_name)(node.layer, node.output_shape = getattr(caffe_shape, func_name)(node.layer,
...@@ -458,7 +457,6 @@ class CaffeOpMapper(OpMapper): ...@@ -458,7 +457,6 @@ class CaffeOpMapper(OpMapper):
def ReLU(self, node): def ReLU(self, node):
""" """
:param node: :param node:
:return: :return:
""" """
...@@ -974,5 +972,4 @@ class CaffeOpMapper(OpMapper): ...@@ -974,5 +972,4 @@ class CaffeOpMapper(OpMapper):
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel=op_info, kernel=op_info,
inputs={"x": self.get_input_name(input)}, inputs={"x": self.get_input_name(input)},
outputs=[node.layer_name]) outputs=[node.layer_name])
\ No newline at end of file
\ No newline at end of file
...@@ -201,7 +201,6 @@ class HierarchicalTree(Tree): ...@@ -201,7 +201,6 @@ class HierarchicalTree(Tree):
code_str = gen_layer_code(self.pd_graph, sub_layers, module_name, code_str = gen_layer_code(self.pd_graph, sub_layers, module_name,
different_attrs=diff_attrs_column) different_attrs=diff_attrs_column)
# print(code_str)
self.codes.append(code_str) self.codes.append(code_str)
for sub_layers in sub_layers_list: for sub_layers in sub_layers_list:
inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers) inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
...@@ -359,7 +358,7 @@ class HierarchicalTree(Tree): ...@@ -359,7 +358,7 @@ class HierarchicalTree(Tree):
run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1])) run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1]))
run_func_list.extend( run_func_list.extend(
[" paddle.disable_static()", [" paddle.disable_static()",
" params, _ = fluid.load_dygraph('{}/model')".format(save_dir), " params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)),
" model = {}()".format(self.pd_graph.name), " model = {}()".format(self.pd_graph.name),
" model.set_dict(params)", " model.set_dict(params)",
" model.eval()", " model.eval()",
...@@ -371,7 +370,12 @@ class HierarchicalTree(Tree): ...@@ -371,7 +370,12 @@ class HierarchicalTree(Tree):
self.update_parameters() self.update_parameters()
import_list = ["import paddle", import_list = ["import paddle",
"import paddle.fluid as fluid", "import paddle.fluid as fluid",
"",] "from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"import math",
"from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn"
"\n",]
import_str = "\n".join(import_list) import_str = "\n".join(import_list)
if not osp.exists(save_dir): if not osp.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
......
...@@ -29,9 +29,9 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", ...@@ -29,9 +29,9 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.Tanh": "tanh", "paddle.nn.Tanh": "tanh",
"paddle.nn.AvgPool2D": "pool", "paddle.nn.AvgPool2D": "pool",
"paddle.nn.MaxPool2D": "pool", "paddle.nn.MaxPool2D": "pool",
"paddle.nn.Pad1d": "pad", "paddle.nn.Pad1D": "pad",
"paddle.nn.Pad2d": "pad", "paddle.nn.Pad2D": "pad",
"paddle.nn.Pad3d": "pad", "paddle.nn.Pad3D": "pad",
"paddle.nn.Dropout": "dropout", "paddle.nn.Dropout": "dropout",
"paddle.nn.GELU": "gelu", "paddle.nn.GELU": "gelu",
"paddle.nn.Hardtanh": "tanh", "paddle.nn.Hardtanh": "tanh",
...@@ -175,9 +175,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -175,9 +175,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if layer.kernel.startswith("paddle.nn") and index == 0: if layer.kernel.startswith("paddle.nn") and index == 0:
continue continue
if not output_name.startswith("x") or output_name in outputs \ if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \ or layer.kernel == "prim.assert":
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
continue continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
elif output_name not in outputs: elif output_name not in outputs:
outputs.append(output_name) outputs.append(output_name)
continue continue
...@@ -187,15 +189,22 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -187,15 +189,22 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if layer.kernel.startswith("paddle.nn") and index == 0 and "functional" not in layer.kernel: if layer.kernel.startswith("paddle.nn") and index == 0 and "functional" not in layer.kernel:
continue continue
if not output_name.startswith("x") or output_name in outputs \ if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \ or layer.kernel == "prim.assert":
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
continue continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
else: else:
outputs.append(output_name) outputs.append(output_name)
no_output_count = 0 no_output_count = 0
for i, (layer_id, layer) in enumerate(sub_layers.items()): for i, (layer_id, layer) in enumerate(sub_layers.items()):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel): if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
line = "self.{} = {}(".format(layer.outputs[0], layer.kernel) layer.kernel.startswith("custom_layer"):
line = "self.{}".format(layer.outputs[0])
if layer.kernel.startswith("custom_layer"):
line += "= x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
else:
line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
key_name = "{}_{}".format(layer.outputs[0], k) key_name = "{}_{}".format(layer.outputs[0], k)
if key_name in different_attrs: if key_name in different_attrs:
...@@ -289,7 +298,10 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()): ...@@ -289,7 +298,10 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
else: else:
if v not in cur_outputs and v not in inputs: if v not in cur_outputs and v not in inputs:
inputs.append(v) inputs.append(v)
line += "{}={}, ".format(k, v) if k == "args":
line += v
else:
line += "{}={}, ".format(k, v)
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
key_name = "{}_{}".format(layer.outputs[0], k) key_name = "{}_{}".format(layer.outputs[0], k)
if key_name in different_attrs: if key_name in different_attrs:
......
...@@ -50,21 +50,25 @@ def get_inputs_outputs(pd_graph, layers): ...@@ -50,21 +50,25 @@ def get_inputs_outputs(pd_graph, layers):
for layer_id, layer in layers.items(): for layer_id, layer in layers.items():
# 获取输出节点名字 # 获取输出节点名字
if layer_id not in pd_graph.edges_out: if layer_id not in pd_graph.edges_out:
for output_name in layer.outputs: for index, output_name in enumerate(layer.outputs):
if not output_name.startswith("x") or output_name in outputs \ if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \ or layer.kernel == "prim.assert":
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
continue continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
elif output_name not in outputs: elif output_name not in outputs:
outputs.append(output_name) outputs.append(output_name)
else: else:
for out_layer_id in pd_graph.edges_out[layer_id]: for out_layer_id in pd_graph.edges_out[layer_id]:
if out_layer_id not in layer_ids: if out_layer_id not in layer_ids:
for output_name in layer.outputs: for index, output_name in enumerate(layer.outputs):
if not output_name.startswith("x") or output_name in outputs \ if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \ or layer.kernel == "prim.assert":
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
continue continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
else: else:
outputs.append(output_name) outputs.append(output_name)
# 获取输入节点名字 # 获取输入节点名字
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册