diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 67b7192bb74b7440567a27df1394a754027f702f..1a0e64ea7701638b83c3a02637e5d0315cf95f11 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -132,6 +132,13 @@ def __get_used_layers__(output_layers): add_parent(mem.layer_name, mem.boot_layer_name) add_parent(mem.link_name, mem.layer_name) + if sub_model.HasField('generator'): + # according to the implementation of text generation + # in recurrent layer group, the generated word must be + # the first out link + add_parent(sub_model.out_links[0].layer_name, + sub_model.generator.eos_layer_name) + def dfs_travel(layer_name): if layer_name in layer_names: return @@ -175,8 +182,6 @@ def __get_used_submodels__(layer_names): for submodel in cp.g_config.model_config.sub_models: if submodel.name in layer_names: submodel_names.add(submodel.name) - if submodel.is_recurrent_layer_group: - layer_names |= set(submodel.layer_names) return submodel_names