提交 d0ba89e9 编写于 作者: S SunAhong1993

modify util

上级 c71b77ec
...@@ -275,7 +275,7 @@ class PaddleGraph(object): ...@@ -275,7 +275,7 @@ class PaddleGraph(object):
def gen_dygraph_model(self, save_dir, jit_type=None): def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace": if jit_type == "trace":
from x2paddle.optimizer.code_optimizer import HierarchicalTree from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self) hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
hierarchical_tree.insert(layer) hierarchical_tree.insert(layer)
...@@ -283,7 +283,7 @@ class PaddleGraph(object): ...@@ -283,7 +283,7 @@ class PaddleGraph(object):
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
else: else:
if self.source_type == "pytorch": if self.source_type == "pytorch":
from x2paddle.optimizer.code_optimizer import ModuleGraph from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self) module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir) module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
...@@ -347,7 +347,8 @@ class PaddleGraph(object): ...@@ -347,7 +347,8 @@ class PaddleGraph(object):
], ],
indent=1) indent=1)
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
remove_default_attrs(layer) if layer.kernel.startswith("paddle"):
remove_default_attrs(layer.kernel, layer.attrs)
edges_in = self.edges_in.get(layer_id, []) edges_in = self.edges_in.get(layer_id, [])
edges_out = self.edges_out.get(layer_id, []) edges_out = self.edges_out.get(layer_id, [])
if len(edges_in) == 0 and len(edges_out) == 0: if len(edges_in) == 0 and len(edges_out) == 0:
...@@ -546,7 +547,8 @@ class PaddleGraph(object): ...@@ -546,7 +547,8 @@ class PaddleGraph(object):
gen_head() gen_head()
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
remove_default_attrs(layer) if layer.kernel.startswith("paddle"):
remove_default_attrs(layer.kernel, layer.attrs)
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel
) or layer.kernel == "paddle.to_tensor" or \ ) or layer.kernel == "paddle.to_tensor" or \
layer.kernel.startswith("custom_layer") or \ layer.kernel.startswith("custom_layer") or \
......
...@@ -28,7 +28,7 @@ def name_generator(nn_name, nn_name2id): ...@@ -28,7 +28,7 @@ def name_generator(nn_name, nn_name2id):
real_nn_name = nn_name + str(nn_name2id[nn_name]) real_nn_name = nn_name + str(nn_name2id[nn_name])
return real_nn_name return real_nn_name
def remove_default_attrs(layer, diff_attrs=None): def remove_default_attrs(kernel, attrs):
def get_default_args(func): def get_default_args(func):
signature = inspect.signature(func) signature = inspect.signature(func)
return { return {
...@@ -36,10 +36,6 @@ def remove_default_attrs(layer, diff_attrs=None): ...@@ -36,10 +36,6 @@ def remove_default_attrs(layer, diff_attrs=None):
for k, v in signature.parameters.items() for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty if v.default is not inspect.Parameter.empty
} }
kernel = layer.kernel
attrs = layer.attrs
if ":" in kernel or "prim" in kernel or "module" in kernel:
return
is_func = True is_func = True
if "paddle.nn" in kernel and "functional"not in kernel: if "paddle.nn" in kernel and "functional"not in kernel:
is_func = False is_func = False
...@@ -61,9 +57,4 @@ def remove_default_attrs(layer, diff_attrs=None): ...@@ -61,9 +57,4 @@ def remove_default_attrs(layer, diff_attrs=None):
if len(set(attrs[default_k])) == 1: if len(set(attrs[default_k])) == 1:
attrs[default_k] = attrs[default_k][0] attrs[default_k] = attrs[default_k][0]
if default_v == attrs[default_k]: if default_v == attrs[default_k]:
if diff_attrs is None: attrs.pop(default_k)
attrs.pop(default_k) \ No newline at end of file
else:
key_name = "{}_{}".format(layer.outputs[0], default_k)
if key_name not in diff_attrs:
attrs.pop(default_k)
...@@ -13,5 +13,5 @@ ...@@ -13,5 +13,5 @@
# limitations under the License. # limitations under the License.
from x2paddle.optimizer.code_optimizer.hierachical_tree import HierarchicalTree from x2paddle.optimizer.pytorch_code_optimizer.hierachical_tree import HierarchicalTree
from x2paddle.optimizer.code_optimizer.module_graph import ModuleGraph from x2paddle.optimizer.pytorch_code_optimizer.module_graph import ModuleGraph
\ No newline at end of file \ No newline at end of file
...@@ -18,10 +18,10 @@ import copy ...@@ -18,10 +18,10 @@ import copy
import os.path as osp import os.path as osp
from treelib import Tree from treelib import Tree
from queue import Queue from queue import Queue
from x2paddle.optimizer.code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers, NN_KERNEL_WITH_PARAMS, NN_KERNEL_NAME
from x2paddle.optimizer.code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import distinguish_sequential, get_inputs_outputs
from x2paddle.core.program import PaddleLayer from x2paddle.core.program import PaddleLayer
from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode, PamareterTree from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
SEPARATOR_IN_SCOPE = "/" SEPARATOR_IN_SCOPE = "/"
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import copy import copy
import os.path as osp import os.path as osp
import x2paddle import x2paddle
from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode
from x2paddle.core.util import * from x2paddle.core.util import *
...@@ -128,7 +128,23 @@ def rename_layers(layers, param_tree=None, is_rename_module=False): ...@@ -128,7 +128,23 @@ def rename_layers(layers, param_tree=None, is_rename_module=False):
return count return count
rename_sub_layers(layers_cp, count) rename_sub_layers(layers_cp, count)
return layers_cp, nn_param_nodes, new_names return layers_cp, nn_param_nodes, new_names
def _update_attrs(layer, different_attrs):
if "module" in layer.kernel or "prim" in layer.kernel:
return
common_attrs = copy.deepcopy(layer.attrs)
special_attrs = dict()
for k, v in layer.attrs.items():
if len(layer.outputs) < 1:
break
key_name = "{}_{}".format(layer.outputs[0], k)
if key_name in different_attrs:
common_attrs.pop(k)
special_attrs[k] = v
remove_default_attrs(layer.kernel, common_attrs)
common_attrs.update(special_attrs)
layer.attrs = common_attrs
def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
""" 根据sub_layers生成对应的Module代码。 """ 根据sub_layers生成对应的Module代码。
...@@ -224,7 +240,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -224,7 +240,7 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
outputs.append(layer.outputs[0]) outputs.append(layer.outputs[0])
no_output_count = 0 no_output_count = 0
for i, (layer_id, layer) in enumerate(sub_layers.items()): for i, (layer_id, layer) in enumerate(sub_layers.items()):
remove_default_attrs(layer, different_attrs) _update_attrs(layer, different_attrs)
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \ if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
layer.kernel.startswith("custom_layer"): layer.kernel.startswith("custom_layer"):
line = "self.{}".format(layer.outputs[0]) line = "self.{}".format(layer.outputs[0])
......
...@@ -17,9 +17,9 @@ import copy ...@@ -17,9 +17,9 @@ import copy
import os import os
import os.path as osp import os.path as osp
from x2paddle.core.program import PaddleLayer from x2paddle.core.program import PaddleLayer
from x2paddle.optimizer.code_optimizer.subgraphs_union import construct_attrs_table, get_inputs_outputs from x2paddle.optimizer.pytorch_code_optimizer.subgraphs_union import construct_attrs_table, get_inputs_outputs
from x2paddle.optimizer.code_optimizer.layer_code_generator import gen_layer_code, rename_layers from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import gen_layer_code, rename_layers
from x2paddle.optimizer.code_optimizer.parameter_tree import PamareterNode, PamareterTree from x2paddle.optimizer.pytorch_code_optimizer.parameter_tree import PamareterNode, PamareterTree
NoModuleStart = ["paddle.nn.ReLU"] NoModuleStart = ["paddle.nn.ReLU"]
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import copy import copy
import pandas as pd import pandas as pd
from x2paddle.optimizer.code_optimizer.layer_code_generator import rename_layers from x2paddle.optimizer.pytorch_code_optimizer.layer_code_generator import rename_layers
def construct_attrs_table(sub_layers_list, node_name2sub_layers=None, module_name=None): def construct_attrs_table(sub_layers_list, node_name2sub_layers=None, module_name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册