Commit 92b5f12b authored by SunAhong1993

add setattr

Parent fb4294b1
...
@@ -101,6 +101,11 @@ class PaddleGraph(object):
         self.clear_edges()
         outputs_from_nodes = dict()
         for layer_id, layer in self.layers.items():
+            # if "x5109" in layer.outputs or "x5110" in layer.outputs:
+            #     print(layer.kernel)
+            #     print(layer.inputs)
+            #     print(layer.outputs)
+            #     print(layer.attrs)
             for input_key, input_var in layer.inputs.items():
                 vs = input_var
                 if not isinstance(vs, list):
...
...
@@ -44,7 +44,7 @@ def prim_GetAttr(mapper, graph, node):
         %7 (Tensor): input Tensor.
         %27 (Tensor): input Tensor.
     """
-    output_name = mapper._get_outputs_name(node)[0]
+    current_node = node
    field_name_list = [node.s('name')]
    while True:
        input_node = list(node.inputs())[0].node()
@@ -53,10 +53,8 @@ def prim_GetAttr(mapper, graph, node):
            node = input_node
        except Exception:
            break
-    if ".".join(field_name_list) in mapper.pytorch_params:
-        mapper.pytorch_params[output_name] = mapper.pytorch_params[".".join(
-            field_name_list)]
-    else:
-        part_script = mapper.script
+    attr_name = ".".join(field_name_list)
+    output_name = mapper._get_outputs_name(current_node, attr_name)[0]
+    part_script = mapper.script
    for field_name in field_name_list:
        if hasattr(part_script, field_name):
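Note: this hunk makes prim_GetAttr name its output after the full attribute path rather than a generic x{index}. The surrounding (elided) code walks the chain of prim::GetAttr inputs up to the module root, collecting attribute names; a toy sketch of that idea, with a stand-in node class rather than the real torch._C.Node API:

# Toy stand-in for a chain of prim::GetAttr nodes representing
# self.conv1.weight: GetAttr(weight) <- GetAttr(conv1) <- module root.
class FakeGetAttrNode:
    def __init__(self, name, parent=None):
        self._name, self._parent = name, parent
    def s(self, key):  # mimics node.s('name')
        return self._name
    def parent(self):
        if self._parent is None:
            raise Exception("reached the module root")
        return self._parent

root = FakeGetAttrNode("conv1")
leaf = FakeGetAttrNode("weight", root)

node = leaf
field_name_list = [node.s("name")]
while True:  # walk upward until no GetAttr parent is left
    try:
        node = node.parent()
        field_name_list.insert(0, node.s("name"))
    except Exception:
        break
attr_name = ".".join(field_name_list)
print(attr_name)                     # conv1.weight
print(attr_name.replace(".", "_"))   # conv1_weight, the output name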
...
@@ -295,8 +293,15 @@ def prim_SetAttr(mapper, graph, node):
    field_name_list.append(node.s('name'))
    inputs_name, inputs_node = mapper._get_inputs_name(node)
-    param = {"Tensor": inputs_name[1]}
+    param = {
+        "Tensor": "self." + ".".join(field_name_list).replace(".", "_"),
+        "parent_layer_id": graph.parent_layer.id
+    }
    mapper.pytorch_params[".".join(field_name_list)] = param
+    graph.add_layer(
+        "prim.set_attr",
+        inputs={"input": inputs_name[1]},
+        outputs=["self." + ".".join(field_name_list).replace(".", "_")])
    return [], [output_name]
...
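Note: prim_SetAttr now records both the mangled attribute target and the id of the control-flow block the assignment lives in, and routes the value through a generated prim.set_attr layer. For a hypothetical assignment to self.fc.bias (the names and the block id "3.0" are invented for illustration):

# Illustrative values only: what the new prim_SetAttr would record
# for a hypothetical assignment to self.fc.bias inside block "3.0".
field_name_list = ["fc", "bias"]
target = "self." + ".".join(field_name_list).replace(".", "_")
param = {"Tensor": target, "parent_layer_id": "3.0"}
# mapper.pytorch_params["fc.bias"] = param; the emitted prim.set_attr
# layer then assigns the incoming tensor to "self.fc_bias" in forward().
print(param)  # {'Tensor': 'self.fc_bias', 'parent_layer_id': '3.0'}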
...
@@ -208,6 +208,11 @@ def prim_select(layer, indent=1, init_func=[], forward_func=[]):
    forward_func.extend(gen_codes([line], indent=indent))


+def prim_set_attr(layer, indent=1, init_func=[], forward_func=[]):
+    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
+    forward_func.extend(gen_codes([line], indent=indent))
+
+
 def prim_shape(layer, indent=1, init_func=[], forward_func=[]):
    line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
...
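Note: the new prim_set_attr generator emits a single assignment line into the generated forward code. A self-contained sketch with a stand-in layer object and simplified get_value/gen_codes helpers (all three are hypothetical; the real ones live in this module):

# Stand-ins for PaddleLayer, get_value and gen_codes, simplified for
# illustration; the real helpers handle indentation and value lookup.
class FakeLayer:
    outputs = ["self.fc_bias"]
    inputs = {"input": "x12"}

def get_value(layer, key):
    return layer.inputs[key]

def gen_codes(lines, indent=1):
    return ["    " * indent + l for l in lines]

forward_func = []
line = "{} = {}".format(FakeLayer.outputs[0], get_value(FakeLayer, "input"))
forward_func.extend(gen_codes([line], indent=1))
print(forward_func)  # ['    self.fc_bias = x12']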
...
@@ -113,15 +113,19 @@ class PyTorchOpMapper(OpMapper):
         graph.set_parameters(self.paddle_params)
         return graph, graph_inputs

-    def _get_outputs_name(self, node):
+    def _get_outputs_name(self, node, attr_name=None):
         outputs_name = []
         for output_ivalue in node.outputs():
-            output_name = 'x' + str(self.output_index)
             script_unique_id = output_ivalue.unique()
-            if script_unique_id in self.outputs_info:
-                output_name = self.outputs_info[script_unique_id]
+            if attr_name is None:
+                output_name = 'x' + str(self.output_index)
+                if script_unique_id in self.outputs_info:
+                    output_name = self.outputs_info[script_unique_id]
+            else:
+                output_name = attr_name.replace(".", "_")
             self.outputs_info[script_unique_id] = output_name
             self.output_index += 1
             outputs_name.append(output_name)
         # case where an if or loop node has no outputs
         if len(list(node.outputs())) == 0:
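Note: pulled out of the class, the updated naming rule looks roughly like this; outputs_info and output_index mimic the mapper's state, and the unique ids below are invented:

# Sketch of the updated _get_outputs_name naming rule.
outputs_info = {}   # script_unique_id -> name (mimics self.outputs_info)
output_index = 0    # mimics self.output_index

def name_output(script_unique_id, attr_name=None):
    global output_index
    if attr_name is None:
        name = "x" + str(output_index)
        if script_unique_id in outputs_info:
            name = outputs_info[script_unique_id]  # reuse the cached name
    else:
        name = attr_name.replace(".", "_")         # stable attribute name
    outputs_info[script_unique_id] = name
    output_index += 1
    return name

print(name_output(101))                  # x0
print(name_output(102, "conv1.weight"))  # conv1_weight
print(name_output(101))                  # x0 again, from the cache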
...
@@ -148,7 +152,33 @@ class PyTorchOpMapper(OpMapper):
                     outputs=[output_name],
                     value="params[{}]".format(string(output_name)))
             else:
-                if isinstance(param, dict) and "Tensor" in param:
+                if isinstance(param, dict) and "Tensor" in param and \
+                        "parent_layer_id" in param:
+                    if graph.parent_layer is not None:
+                        # when a param is assigned in two control-flow branches (if-else), the else branch must not reference the value assigned in the if branch
+                        id1 = param["parent_layer_id"]
+                        id2 = graph.parent_layer.id
+                        id1_part = id1.split(".")
+                        id2_part = id2.split(".")
+                        if len(id1_part) >= len(id2_part):
+                            for i in range(len(id1_part)):
+                                if id1_part[i] == id2_part[i]:
+                                    continue
+                                else:
+                                    if id1_part[i] == "0" and id2_part[
+                                            i] == "1":
+                                        if add_dim:
+                                            param = param[np.newaxis, :]
+                                        self.paddle_params[output_name] = param
+                                        graph.add_layer(
+                                            "fluid.dygraph.base.to_variable",
+                                            inputs={},
+                                            outputs=[output_name],
+                                            value="params[{}]".format(
+                                                string(output_name)))
+                                        node_outputs.append(output_name)
+                                        return
+                    # outside the if-else, the value assigned inside the if-else can be referenced directly
                     graph.add_layer(
                         "prim.constant",
                         inputs={},
...
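Note: the part-by-part id comparison encodes one rule. Under the (assumed) convention that an if branch gets sub-id "0" and its else branch sub-id "1", a first mismatch of "0" versus "1" means the value was assigned in the if branch of the very block whose else branch is now reading it, so the parameter is re-materialized via to_variable instead of referenced. A standalone sketch of that check (the ids are invented):

def set_in_sibling_if_branch(assign_id, use_id):
    """True when assign_id lies in the if branch ("0") and use_id in
    the else branch ("1") of the same control-flow block; zip stops
    at the shorter id, so no index can run out of range."""
    a, b = assign_id.split("."), use_id.split(".")
    if len(a) < len(b):
        return False
    for x, y in zip(a, b):
        if x != y:
            return x == "0" and y == "1"
    return False

print(set_in_sibling_if_branch("5.0.3", "5.1.2"))  # True: if vs else
print(set_in_sibling_if_branch("5.0.3", "5.0.7"))  # False: same branch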
...
@@ -20,3 +20,5 @@ from .functional_adaptive_pool2d_fuser import FunctionalAdaptivePool2dFuser
 from .functional_adaptive_pool2d_fuse_pass import FunctionalAdaptivePool2dFusePass
 from .constant_fuser import ConstantFuser
 from .constant_fuse_pass import ConstantFusePass
+from .batchnorm2d_fuser import BatchNorm2dFuser
+from .batchnorm2d_fuse_pass import BatchNorm2dFusePass
...
@@ -20,7 +20,8 @@ class GraphOptimizer(object):
     def __init__(self):
         self.passes = [
             "fc_fuse_pass", "nn_adaptive_pool2d_fuse_pass",
-            "functional_adaptive_pool2d_fuse_pass", "constant_fuse_pass"
+            "functional_adaptive_pool2d_fuse_pass", "batchnorm2d_fuse_pass",
+            "constant_fuse_pass"
         ]

     def optimize(self, graph):
...