From 0cb8a6669ed04afc526c818bb3907645a16c7a34 Mon Sep 17 00:00:00 2001
From: xuwei06
Date: Fri, 26 May 2017 12:14:07 -0700
Subject: [PATCH] Fix style

---
 python/paddle/v2/layer.py | 40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py
index ad36364ca8e..5500b8b342a 100644
--- a/python/paddle/v2/layer.py
+++ b/python/paddle/v2/layer.py
@@ -40,6 +40,7 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig, SubModelConfig
 __all__ = ['data', 'parse_network']
 __layer_map__ = {}
 
+
 def __wrap__(f):
     def wrapped(*args, **xargs):
         out = f(*args, **xargs)
@@ -53,6 +54,7 @@ def __wrap__(f):
 
     return wrapped
 
+
 def __need_to_keep__(name):
     if name in ['StaticInput', 'LayerType', 'layer_support']:
         return False
@@ -99,6 +101,7 @@ def __data_layer__(name, type, **kwargs):
     l.data_type = type
     return l
 
+
 data = __wrap__(__data_layer__)
 
 LayerV2 = v1_layers.LayerOutput
@@ -107,6 +110,7 @@ LayerV2 = v1_layers.LayerOutput
 def __get_used_layers__(output_layers, extra_layers=None):
     layer_names = set()
     parents = {}
+
     def add_parent(child, parent):
         if child in parents:
             parents[child].append(parent)
@@ -181,28 +185,25 @@ def __get_used_evaluators__(layer_names):
     return evaluator_names
 
 
-def __trim_submodel__(old_submodel,
-                      layer_names,
-                      input_layer_names,
-                      output_layer_names,
-                      evaluator_names):
+def __trim_submodel__(old_submodel, layer_names, input_layer_names,
+                      output_layer_names, evaluator_names):
 
     submodel = SubModelConfig()
     submodel.name = old_submodel.name
-    submodel.layer_names.extend(filter(lambda x: x in layer_names,
-                                       old_submodel.layer_names))
-    submodel.input_layer_names.extend(filter(lambda x: x in input_layer_names,
-                                             submodel.layer_names))
-    submodel.output_layer_names.extend(filter(lambda x: x in output_layer_names,
-                                              submodel.layer_names))
-    submodel.evaluator_names.extend(filter(lambda x: x in evaluator_names,
-                                           old_submodel.evaluator_names))
+    submodel.layer_names.extend(
+        filter(lambda x: x in layer_names, old_submodel.layer_names))
+    submodel.input_layer_names.extend(
+        filter(lambda x: x in input_layer_names, submodel.layer_names))
+    submodel.output_layer_names.extend(
+        filter(lambda x: x in output_layer_names, submodel.layer_names))
+    submodel.evaluator_names.extend(
+        filter(lambda x: x in evaluator_names, old_submodel.evaluator_names))
 
     submodel.is_recurrent_layer_group = old_submodel.is_recurrent_layer_group
     submodel.reversed = old_submodel.reversed
-    submodel.memories.extend(filter(lambda x: x.link_name in layer_names,
-                                    old_submodel.memories))
+    submodel.memories.extend(
+        filter(lambda x: x.link_name in layer_names, old_submodel.memories))
 
     target_inlinkid = (old_submodel.target_inlinkid
                        if old_submodel.HasField('target_inlinkid') else -1)
     in_links = []
@@ -213,8 +214,8 @@ def __trim_submodel__(old_submodel,
                 target_inlinkid = len(in_links) - 1
     submodel.in_links.extend(in_links)
 
-    submodel.out_links.extend(filter(lambda x: x.link_name in layer_names,
-                                     old_submodel.out_links))
+    submodel.out_links.extend(
+        filter(lambda x: x.link_name in layer_names, old_submodel.out_links))
 
     if old_submodel.HasField('generator'):
         submodel.generator.CopyFrom(old_submodel.generator)
@@ -264,9 +265,8 @@ def parse_network(output_layers, extra_layers=None):
 
     for s in cp.g_config.model_config.sub_models:
         if s.name in submodel_names:
-            s = __trim_submodel__(
-                s, layer_names, input_layer_names, output_layer_names,
-                evaluator_names)
+            s = __trim_submodel__(s, layer_names, input_layer_names,
+                                  output_layer_names, evaluator_names)
             model_config.sub_models.extend([s])
 
     return model_config
--
GitLab
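
For reference, every re-wrapped call above uses the same filter-by-membership idiom; the patch only changes line wrapping, not behavior. A minimal standalone sketch follows; the names keep and candidates are illustrative placeholders, not identifiers from layer.py:

# Illustrative sketch only -- mirrors the filter(...) calls re-wrapped in the
# patch. In layer.py the sequences are protobuf repeated fields such as
# old_submodel.layer_names; plain Python lists stand in for them here.
keep = set(['fc1', 'cost'])                      # e.g. layer_names
candidates = ['data', 'fc1', 'dropout', 'cost']  # e.g. old_submodel.layer_names

# filter() preserves input order and drops anything not in `keep`; breaking
# the call after extend( -- as the patch does -- is purely cosmetic.
trimmed = list(filter(lambda x: x in keep, candidates))
assert trimmed == ['fc1', 'cost']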