提交 e748c25e 编写于 作者: S SunAhong1993

fix

上级 bf803137
......@@ -89,7 +89,6 @@ class CaffeOpMapper(OpMapper):
input_shape.append(last_node.output_shape[idx])
node.input_shape = input_shape
func_name = 'shape_' + node.layer_type.lower()
if is_fluid_op:
node.output_shape = getattr(caffe_shape, func_name)(node.layer,
......@@ -458,7 +457,6 @@ class CaffeOpMapper(OpMapper):
def ReLU(self, node):
"""
:param node:
:return:
"""
......@@ -974,5 +972,4 @@ class CaffeOpMapper(OpMapper):
self.paddle_graph.add_layer(
kernel=op_info,
inputs={"x": self.get_input_name(input)},
outputs=[node.layer_name])
\ No newline at end of file
outputs=[node.layer_name])
\ No newline at end of file
......@@ -201,7 +201,6 @@ class HierarchicalTree(Tree):
code_str = gen_layer_code(self.pd_graph, sub_layers, module_name,
different_attrs=diff_attrs_column)
# print(code_str)
self.codes.append(code_str)
for sub_layers in sub_layers_list:
inputs, outputs = get_inputs_outputs(self.pd_graph, sub_layers)
......@@ -359,7 +358,7 @@ class HierarchicalTree(Tree):
run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1]))
run_func_list.extend(
[" paddle.disable_static()",
" params, _ = fluid.load_dygraph('{}/model')".format(save_dir),
" params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)),
" model = {}()".format(self.pd_graph.name),
" model.set_dict(params)",
" model.eval()",
......@@ -371,7 +370,12 @@ class HierarchicalTree(Tree):
self.update_parameters()
import_list = ["import paddle",
"import paddle.fluid as fluid",
"",]
"from paddle.fluid.initializer import Constant",
"from paddle.fluid.param_attr import ParamAttr",
"import math",
"from x2paddle.op_mapper.dygraph.pytorch2paddle " + \
"import pytorch_custom_layer as x2paddle_nn"
"\n",]
import_str = "\n".join(import_list)
if not osp.exists(save_dir):
os.makedirs(save_dir)
......
......@@ -29,9 +29,9 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.Tanh": "tanh",
"paddle.nn.AvgPool2D": "pool",
"paddle.nn.MaxPool2D": "pool",
"paddle.nn.Pad1d": "pad",
"paddle.nn.Pad2d": "pad",
"paddle.nn.Pad3d": "pad",
"paddle.nn.Pad1D": "pad",
"paddle.nn.Pad2D": "pad",
"paddle.nn.Pad3D": "pad",
"paddle.nn.Dropout": "dropout",
"paddle.nn.GELU": "gelu",
"paddle.nn.Hardtanh": "tanh",
......@@ -175,9 +175,11 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if layer.kernel.startswith("paddle.nn") and index == 0:
continue
if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
or layer.kernel == "prim.assert":
continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
elif output_name not in outputs:
outputs.append(output_name)
continue
......@@ -187,15 +189,22 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
if layer.kernel.startswith("paddle.nn") and index == 0 and "functional" not in layer.kernel:
continue
if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
or layer.kernel == "prim.assert":
continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
else:
outputs.append(output_name)
no_output_count = 0
for i, (layer_id, layer) in enumerate(sub_layers.items()):
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel):
line = "self.{} = {}(".format(layer.outputs[0], layer.kernel)
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
layer.kernel.startswith("custom_layer"):
line = "self.{}".format(layer.outputs[0])
if layer.kernel.startswith("custom_layer"):
line += "= x2paddle_nn.{}(".format(layer.kernel.split(":")[-1])
else:
line += " = {}(".format(layer.kernel)
for k, v in layer.attrs.items():
key_name = "{}_{}".format(layer.outputs[0], k)
if key_name in different_attrs:
......@@ -289,7 +298,10 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=list()):
else:
if v not in cur_outputs and v not in inputs:
inputs.append(v)
line += "{}={}, ".format(k, v)
if k == "args":
line += v
else:
line += "{}={}, ".format(k, v)
for k, v in layer.attrs.items():
key_name = "{}_{}".format(layer.outputs[0], k)
if key_name in different_attrs:
......
......@@ -50,21 +50,25 @@ def get_inputs_outputs(pd_graph, layers):
for layer_id, layer in layers.items():
# 获取输出节点名字
if layer_id not in pd_graph.edges_out:
for output_name in layer.outputs:
for index, output_name in enumerate(layer.outputs):
if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
or layer.kernel == "prim.assert":
continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
elif output_name not in outputs:
outputs.append(output_name)
else:
for out_layer_id in pd_graph.edges_out[layer_id]:
if out_layer_id not in layer_ids:
for output_name in layer.outputs:
for index, output_name in enumerate(layer.outputs):
if not output_name.startswith("x") or output_name in outputs \
or layer.kernel == "prim.assert" or \
layer.kernel == "prim.if" or layer.kernel == "prim.loop":
or layer.kernel == "prim.assert":
continue
elif layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if index != 0:
outputs.append(output_name)
else:
outputs.append(output_name)
# 获取输入节点名字
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册