提交 0cb8a666 编写于 作者: X xuwei06

Fix style

上级 7d0355cd
......@@ -40,6 +40,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig, SubModelConfig
__all__ = ['data', 'parse_network']
__layer_map__ = {}
def __wrap__(f):
def wrapped(*args, **xargs):
out = f(*args, **xargs)
......@@ -53,6 +54,7 @@ def __wrap__(f):
return wrapped
def __need_to_keep__(name):
if name in ['StaticInput', 'LayerType', 'layer_support']:
return False
......@@ -99,6 +101,7 @@ def __data_layer__(name, type, **kwargs):
l.data_type = type
return l
data = __wrap__(__data_layer__)
LayerV2 = v1_layers.LayerOutput
......@@ -107,6 +110,7 @@ LayerV2 = v1_layers.LayerOutput
def __get_used_layers__(output_layers, extra_layers=None):
layer_names = set()
parents = {}
def add_parent(child, parent):
if child in parents:
parents[child].append(parent)
......@@ -181,28 +185,25 @@ def __get_used_evaluators__(layer_names):
return evaluator_names
def __trim_submodel__(old_submodel,
layer_names,
input_layer_names,
output_layer_names,
evaluator_names):
def __trim_submodel__(old_submodel, layer_names, input_layer_names,
output_layer_names, evaluator_names):
submodel = SubModelConfig()
submodel.name = old_submodel.name
submodel.layer_names.extend(filter(lambda x: x in layer_names,
old_submodel.layer_names))
submodel.input_layer_names.extend(filter(lambda x: x in input_layer_names,
submodel.layer_names))
submodel.output_layer_names.extend(filter(lambda x: x in output_layer_names,
submodel.layer_names))
submodel.evaluator_names.extend(filter(lambda x: x in evaluator_names,
old_submodel.evaluator_names))
submodel.layer_names.extend(
filter(lambda x: x in layer_names, old_submodel.layer_names))
submodel.input_layer_names.extend(
filter(lambda x: x in input_layer_names, submodel.layer_names))
submodel.output_layer_names.extend(
filter(lambda x: x in output_layer_names, submodel.layer_names))
submodel.evaluator_names.extend(
filter(lambda x: x in evaluator_names, old_submodel.evaluator_names))
submodel.is_recurrent_layer_group = old_submodel.is_recurrent_layer_group
submodel.reversed = old_submodel.reversed
submodel.memories.extend(filter(lambda x: x.link_name in layer_names,
old_submodel.memories))
submodel.memories.extend(
filter(lambda x: x.link_name in layer_names, old_submodel.memories))
target_inlinkid = (old_submodel.target_inlinkid
if old_submodel.HasField('target_inlinkid') else -1)
in_links = []
......@@ -213,8 +214,8 @@ def __trim_submodel__(old_submodel,
target_inlinkid = len(in_links) - 1
submodel.in_links.extend(in_links)
submodel.out_links.extend(filter(lambda x: x.link_name in layer_names,
old_submodel.out_links))
submodel.out_links.extend(
filter(lambda x: x.link_name in layer_names, old_submodel.out_links))
if old_submodel.HasField('generator'):
submodel.generator.CopyFrom(old_submodel.generator)
......@@ -264,9 +265,8 @@ def parse_network(output_layers, extra_layers=None):
for s in cp.g_config.model_config.sub_models:
if s.name in submodel_names:
s = __trim_submodel__(
s, layer_names, input_layer_names, output_layer_names,
evaluator_names)
s = __trim_submodel__(s, layer_names, input_layer_names,
output_layer_names, evaluator_names)
model_config.sub_models.extend([s])
return model_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册