Commit d0ba89e9 authored by SunAhong1993

modify util

Parent c71b77ec
@@ -275,7 +275,7 @@ class PaddleGraph(object):
def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace":
-from x2paddle.optimizer.code_optimizer import HierarchicalTree
+from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items():
hierarchical_tree.insert(layer)
@@ -283,7 +283,7 @@ class PaddleGraph(object):
self.dump_dygraph_parameter(save_dir)
else:
if self.source_type == "pytorch":
-from x2paddle.optimizer.code_optimizer import ModuleGraph
+from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
@@ -347,7 +347,8 @@ class PaddleGraph(object):
],
indent=1)
for layer_id, layer in self.layers.items():
-remove_default_attrs(layer)
+if layer.kernel.startswith("paddle"):
+    remove_default_attrs(layer.kernel, layer.attrs)
edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, [])
if len(edges_in) == 0 and len(edges_out) == 0:
@@ -546,7 +547,8 @@ class PaddleGraph(object):
gen_head()
for layer_id, layer in self.layers.items():
-remove_default_attrs(layer)
+if layer.kernel.startswith("paddle"):
+    remove_default_attrs(layer.kernel, layer.attrs)
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
) or layer.kernel == "paddle.to_tensor" or \
layer.kernel.startswith("custom_layer") or \
......
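Both hunks above apply the same new guard: `prim`/module kernels have no Paddle API signature behind them, so only `paddle.*` kernels are pruned against their declared defaults. A minimal sketch of the call pattern, assuming a `layers` dict like `self.layers` and the `remove_default_attrs` utility refactored later in this commit:

```python
def prune_graph_defaults(layers):
    # Hedged sketch, not the PaddleGraph method itself: prim::* and
    # custom module kernels carry no inspectable Paddle signature,
    # so they are skipped at the call site.
    for layer_id, layer in layers.items():
        if layer.kernel.startswith("paddle"):
            remove_default_attrs(layer.kernel, layer.attrs)
```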
@@ -28,7 +28,7 @@ def name_generator(nn_name, nn_name2id):
real_nn_name = nn_name + str(nn_name2id[nn_name])
return real_nn_name
-def remove_default_attrs(layer, diff_attrs=None):
+def remove_default_attrs(kernel, attrs):
def get_default_args(func):
signature = inspect.signature(func)
return {
@@ -36,10 +36,6 @@ def remove_default_attrs(layer, diff_attrs=None):
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty
}
-kernel = layer.kernel
-attrs = layer.attrs
if ":" in kernel or "prim" in kernel or "module" in kernel:
return
is_func = True
if "paddle.nn" in kernel and "functional"not in kernel:
is_func = False
@@ -61,9 +57,4 @@
if len(set(attrs[default_k])) == 1:
attrs[default_k] = attrs[default_k][0]
if default_v == attrs[default_k]:
-if diff_attrs is None:
-    attrs.pop(default_k)
-else:
-    key_name = "{}_{}".format(layer.outputs[0], default_k)
-    if key_name not in diff_attrs:
-        attrs.pop(default_k)
+attrs.pop(default_k)
\ No newline at end of file
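The refactored utility now takes the kernel string and the attrs dict directly instead of a whole layer object, so callers control what gets pruned. Its core mechanism is `inspect.signature`: collect the callable's declared defaults and drop any attr that merely restates one. A self-contained sketch of that idea, using a stand-in function rather than x2paddle's dotted-path resolution of real `paddle.*` callables:

```python
import inspect

def strip_default_attrs(func, attrs):
    # Parameters that declare a default value.
    defaults = {
        k: v.default
        for k, v in inspect.signature(func).parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    # Drop attrs equal to the declared default.
    for k, v in defaults.items():
        if k in attrs and attrs[k] == v:
            attrs.pop(k)
    return attrs

def fake_conv2d(x, stride=1, padding=0, dilation=1):  # stand-in, not a Paddle API
    pass

print(strip_default_attrs(fake_conv2d, {"stride": 1, "padding": 2}))
# {'padding': 2} -- stride matched its default and was removed
```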
@@ -13,5 +13,5 @@
# limitations under the License.
-from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree
-from x2paddle.optimizer.code_optimizer.module_graph import ModuleGraph
\ No newline at end of file
+from x2paddle.optimizer.pytorch_code_optimizer.hierachical_tree import HierarchicalTree
+from x2paddle.optimizer.pytorch_code_optimizer.module_graph import ModuleGraph
\ No newline at end of file
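With the re-exports in place, downstream code (including the two call sites in program.py above) only needs the new package path:

```python
from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree, ModuleGraph
```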
@@ -18,10 +18,10 @@ import copy
import os.path as osp
from treelib import Tree
from queue import Queue
-from x2paddle.optimizer.code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME
-from x2paddle.optimizer.code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs
+from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME
+from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs
from x2paddle.core.program import PaddleLayer
-from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode, PamareterTree
+from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
SEPARATOR_IN_SCOPE = "/"
......
@@ -16,7 +16,7 @@
import copy
import os.path as osp
import x2paddle
-from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode
+from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode
from x2paddle.core.util import *
@@ -128,7 +128,23 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
return count
rename_sub_layers(layers_cp, count)
return layers_cp, nn_param_nodes, new_names
+def _update_attrs(layer, different_attrs):
+    if "module" in layer.kernel or "prim" in layer.kernel:
+        return
+    common_attrs = copy.deepcopy(layer.attrs)
+    special_attrs = dict()
+    for k, v in layer.attrs.items():
+        if len(layer.outputs) < 1:
+            break
+        key_name = "{}_{}".format(layer.outputs[0], k)
+        if key_name in different_attrs:
+            common_attrs.pop(k)
+            special_attrs[k] = v
+    remove_default_attrs(layer.kernel, common_attrs)
+    common_attrs.update(special_attrs)
+    layer.attrs = common_attrs
def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
""" 根据sub_layers生成对应的Module代码。
@@ -224,7 +240,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
outputs.append(layer.outputs[0])
no_output_count = 0
for i, (layer_id, layer) in enumerate(sub_layers.items()):
-remove_default_attrs(layer, different_attrs)
+_update_attrs(layer, different_attrs)
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
layer.kernel.startswith("custom_layer"):
line = "self.{}".format(layer.outputs[0])
......
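The new `_update_attrs` splits a layer's attrs before pruning: any attr whose `"{first_output}_{attr}"` key appears in `different_attrs` varies across instances of the generated Module and must survive verbatim, while the remaining common attrs go through `remove_default_attrs`. A toy walk-through with made-up values:

```python
import copy

# Hypothetical layer: first output "conv0"; kernel_size differs across
# the subgraph instances that share this Module, the rest are common.
layer_attrs = {"kernel_size": 3, "stride": 1, "padding": 0}
different_attrs = ["conv0_kernel_size"]

common_attrs = copy.deepcopy(layer_attrs)
special_attrs = dict()
for k, v in layer_attrs.items():
    if "{}_{}".format("conv0", k) in different_attrs:
        common_attrs.pop(k)
        special_attrs[k] = v

# common_attrs would now be pruned against Paddle defaults;
# the per-instance attrs are then merged back untouched.
common_attrs.update(special_attrs)
print(common_attrs)  # {'stride': 1, 'padding': 0, 'kernel_size': 3}
```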
@@ -17,9 +17,9 @@ import copy
import os
import os.path as osp
from x2paddle.core.program import PaddleLayer
-from x2paddle.optimizer.code_optimizer.subgraphs_union import construct_attrs_table, get_inputs_outputs
-from x2paddle.optimizer.code_optimizer.layer_code_generator import gen_layer_code, rename_layers
-from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode, PamareterTree
+from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_attrs_table, get_inputs_outputs
+from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers
+from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
NoModuleStart = ["paddle.nn.ReLU"]
......
@@ -16,7 +16,7 @@
import copy
import pandas as pd
-from x2paddle.optimizer.code_optimizer.layer_code_generator import rename_layers
+from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import rename_layers
def construct_attrs_table(sub_layers_list, node_name2sub_layers=None, module_name=None):
......
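`construct_attrs_table` is where pandas comes in. A hedged illustration (made-up rows, not the library's code) of why a table is a natural fit: with one row per repeated subgraph and one `"{output}_{attr}"` column per attr, the columns whose values vary across rows are exactly the per-instance attrs that `_update_attrs` must preserve:

```python
import pandas as pd

table = pd.DataFrame([
    {"conv0_kernel_size": 3, "conv0_stride": 1},
    {"conv0_kernel_size": 5, "conv0_stride": 1},
])
# Columns with more than one distinct value cannot be folded
# into the shared Module definition.
varying = [c for c in table.columns if table[c].nunique() > 1]
print(varying)  # ['conv0_kernel_size']
```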