提交 0cb8a666 编写于 作者: X xuwei06

Fix style

上级 7d0355cd
...@@ -40,6 +40,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig, SubModelConfig ...@@ -40,6 +40,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig, SubModelConfig
__all__ = ['data', 'parse_network'] __all__ = ['data', 'parse_network']
__layer_map__ = {} __layer_map__ = {}
def __wrap__(f): def __wrap__(f):
def wrapped(*args, **xargs): def wrapped(*args, **xargs):
out = f(*args, **xargs) out = f(*args, **xargs)
...@@ -53,6 +54,7 @@ def __wrap__(f): ...@@ -53,6 +54,7 @@ def __wrap__(f):
return wrapped return wrapped
def __need_to_keep__(name): def __need_to_keep__(name):
if name in ['StaticInput', 'LayerType', 'layer_support']: if name in ['StaticInput', 'LayerType', 'layer_support']:
return False return False
...@@ -99,6 +101,7 @@ def __data_layer__(name, type, **kwargs): ...@@ -99,6 +101,7 @@ def __data_layer__(name, type, **kwargs):
l.data_type = type l.data_type = type
return l return l
data = __wrap__(__data_layer__) data = __wrap__(__data_layer__)
LayerV2 = v1_layers.LayerOutput LayerV2 = v1_layers.LayerOutput
...@@ -107,6 +110,7 @@ LayerV2 = v1_layers.LayerOutput ...@@ -107,6 +110,7 @@ LayerV2 = v1_layers.LayerOutput
def __get_used_layers__(output_layers, extra_layers=None): def __get_used_layers__(output_layers, extra_layers=None):
layer_names = set() layer_names = set()
parents = {} parents = {}
def add_parent(child, parent): def add_parent(child, parent):
if child in parents: if child in parents:
parents[child].append(parent) parents[child].append(parent)
...@@ -181,28 +185,25 @@ def __get_used_evaluators__(layer_names): ...@@ -181,28 +185,25 @@ def __get_used_evaluators__(layer_names):
return evaluator_names return evaluator_names
def __trim_submodel__(old_submodel, def __trim_submodel__(old_submodel, layer_names, input_layer_names,
layer_names, output_layer_names, evaluator_names):
input_layer_names,
output_layer_names,
evaluator_names):
submodel = SubModelConfig() submodel = SubModelConfig()
submodel.name = old_submodel.name submodel.name = old_submodel.name
submodel.layer_names.extend(filter(lambda x: x in layer_names, submodel.layer_names.extend(
old_submodel.layer_names)) filter(lambda x: x in layer_names, old_submodel.layer_names))
submodel.input_layer_names.extend(filter(lambda x: x in input_layer_names, submodel.input_layer_names.extend(
submodel.layer_names)) filter(lambda x: x in input_layer_names, submodel.layer_names))
submodel.output_layer_names.extend(filter(lambda x: x in output_layer_names, submodel.output_layer_names.extend(
submodel.layer_names)) filter(lambda x: x in output_layer_names, submodel.layer_names))
submodel.evaluator_names.extend(filter(lambda x: x in evaluator_names, submodel.evaluator_names.extend(
old_submodel.evaluator_names)) filter(lambda x: x in evaluator_names, old_submodel.evaluator_names))
submodel.is_recurrent_layer_group = old_submodel.is_recurrent_layer_group submodel.is_recurrent_layer_group = old_submodel.is_recurrent_layer_group
submodel.reversed = old_submodel.reversed submodel.reversed = old_submodel.reversed
submodel.memories.extend(filter(lambda x: x.link_name in layer_names, submodel.memories.extend(
old_submodel.memories)) filter(lambda x: x.link_name in layer_names, old_submodel.memories))
target_inlinkid = (old_submodel.target_inlinkid target_inlinkid = (old_submodel.target_inlinkid
if old_submodel.HasField('target_inlinkid') else -1) if old_submodel.HasField('target_inlinkid') else -1)
in_links = [] in_links = []
...@@ -213,8 +214,8 @@ def __trim_submodel__(old_submodel, ...@@ -213,8 +214,8 @@ def __trim_submodel__(old_submodel,
target_inlinkid = len(in_links) - 1 target_inlinkid = len(in_links) - 1
submodel.in_links.extend(in_links) submodel.in_links.extend(in_links)
submodel.out_links.extend(filter(lambda x: x.link_name in layer_names, submodel.out_links.extend(
old_submodel.out_links)) filter(lambda x: x.link_name in layer_names, old_submodel.out_links))
if old_submodel.HasField('generator'): if old_submodel.HasField('generator'):
submodel.generator.CopyFrom(old_submodel.generator) submodel.generator.CopyFrom(old_submodel.generator)
...@@ -264,9 +265,8 @@ def parse_network(output_layers, extra_layers=None): ...@@ -264,9 +265,8 @@ def parse_network(output_layers, extra_layers=None):
for s in cp.g_config.model_config.sub_models: for s in cp.g_config.model_config.sub_models:
if s.name in submodel_names: if s.name in submodel_names:
s = __trim_submodel__( s = __trim_submodel__(s, layer_names, input_layer_names,
s, layer_names, input_layer_names, output_layer_names, output_layer_names, evaluator_names)
evaluator_names)
model_config.sub_models.extend([s]) model_config.sub_models.extend([s])
return model_config return model_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册