From 7d0355cdd060149ba466920521c2d9aed321bb23 Mon Sep 17 00:00:00 2001 From: xuwei06 Date: Fri, 26 May 2017 02:29:58 -0700 Subject: [PATCH] Fix V2 API --- paddle/parameter/Parameter.h | 1 + python/paddle/trainer/config_parser.py | 31 +- .../config_parser_utils.py | 22 +- .../paddle/trainer_config_helpers/layers.py | 6 + python/paddle/v2/evaluator.py | 14 +- python/paddle/v2/layer.py | 746 +++++------------- python/paddle/v2/networks.py | 15 +- python/paddle/v2/tests/test_layer.py | 18 +- python/paddle/v2/tests/test_rnn_layer.py | 7 + python/paddle/v2/tests/test_topology.py | 12 +- python/paddle/v2/topology.py | 29 +- 11 files changed, 278 insertions(+), 623 deletions(-) diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index d77486ce42..0bac76f068 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -324,6 +324,7 @@ protected: std::vector> updaterHooks_; public: + void setSharedCount(int cnt) { sharedCount_ = cnt; } int getSharedCount() { return sharedCount_; } bool isSparse() { return config_.is_sparse(); } diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 9fe8794691..56e1ba170d 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3371,7 +3371,7 @@ def make_importer(config_dir, config_args): return Import -settings = dict( +default_settings = dict( batch_size=None, mini_batch_size=None, algorithm='async_sgd', @@ -3404,6 +3404,8 @@ settings = dict( adam_beta2=0.999, adam_epsilon=1e-8, ) +settings = copy.deepcopy(default_settings) + settings_deprecated = dict(usage_ratio=1., ) trainer_settings = dict( @@ -3544,10 +3546,8 @@ def update_g_config(): return g_config -def parse_config(trainer_config, config_arg_str): +def begin_parse(config_arg_str=''): ''' - @param trainer_config: can be a string of config file name or a function name - with config logic @param config_arg_str: a string of the form var1=val1,var2=val2. 
It will be passed to config script as a dictionary CONFIG_ARGS ''' @@ -3555,12 +3555,23 @@ def parse_config(trainer_config, config_arg_str): for hook in _parse_config_hooks: hook() - config_args = {} - logger.findCaller = find_caller logger.fatal = my_fatal g_config.model_config.type = "nn" + + global g_current_submodel, g_root_submodel + g_root_submodel = g_config.model_config.sub_models.add() + g_root_submodel.name = 'root' + g_root_submodel.is_recurrent_layer_group = False + g_current_submodel = g_root_submodel + + +def parse_config(trainer_config, config_arg_str): + begin_parse(config_arg_str) + + config_args = {} + if config_arg_str: config_args = dict([f.split('=') for f in config_arg_str.split(',')]) @@ -3573,14 +3584,6 @@ def parse_config(trainer_config, config_arg_str): extension_module = importlib(extension_module_name) g_extended_config_funcs = extension_module.get_config_funcs(g_config) - g_config.model_config.type = 'nn' - - global g_current_submodel, g_root_submodel - g_root_submodel = g_config.model_config.sub_models.add() - g_root_submodel.name = 'root' - g_root_submodel.is_recurrent_layer_group = False - g_current_submodel = g_root_submodel - if hasattr(trainer_config, '__call__'): trainer_config.func_globals.update( make_config_environment("", config_args)) diff --git a/python/paddle/trainer_config_helpers/config_parser_utils.py b/python/paddle/trainer_config_helpers/config_parser_utils.py index 681b177a55..9e56556d0a 100644 --- a/python/paddle/trainer_config_helpers/config_parser_utils.py +++ b/python/paddle/trainer_config_helpers/config_parser_utils.py @@ -12,15 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import paddle.trainer.config_parser as config_parser +from paddle.proto.TrainerConfig_pb2 import OptimizationConfig + ''' -This file is a wrapper of formal config_parser. The main idea of this file is to +This file is a wrapper of formal config_parser. The main idea of this file is to separete different config logic into different function, such as network configuration and optimizer configuration. 
''' __all__ = [ - "parse_trainer_config", "parse_network_config", "parse_optimizer_config" + "parse_trainer_config", "parse_network_config", "parse_optimizer_config", + "reset_parser" ] @@ -34,5 +38,15 @@ def parse_network_config(network_conf, config_arg_str=''): def parse_optimizer_config(optimizer_conf, config_arg_str=''): - config = config_parser.parse_config(optimizer_conf, config_arg_str) - return config.opt_config + config_parser.settings = copy.deepcopy(config_parser.default_settings) + optimizer_conf() + opt_config = OptimizationConfig() + for k, v in config_parser.settings.iteritems(): + if v is None: + continue + opt_config.__setattr__(k, v) + return opt_config + + +def reset_parser(): + config_parser.begin_parse() diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index ec81e1dc3d..08d80b3b52 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -285,6 +285,7 @@ class LayerOutput(object): assert size is not None assert LayerType.is_layer_type(layer_type) self.name = name + self.full_name = MakeLayerNameInSubmodel(name) self.layer_type = layer_type if parents is not None and type(parents) != list: parents = [parents] @@ -3489,6 +3490,11 @@ def recurrent_group(step, RecurrentLayerGroupEnd(name=name) + for layer_out in layer_outs: + # Thee previous full_name is the name is the rnn group + # We need a full_name outside the rnn group + layer_out.full_name = MakeLayerNameInSubmodel(layer_out.name) + if len(layer_outs) == 1: return layer_outs[0] else: diff --git a/python/paddle/v2/evaluator.py b/python/paddle/v2/evaluator.py index 588eefa391..c474f74235 100644 --- a/python/paddle/v2/evaluator.py +++ b/python/paddle/v2/evaluator.py @@ -25,21 +25,9 @@ def initialize(): for __ev_name__ in filter(lambda x: x.endswith('_evaluator'), evs.__all__): __ev__ = getattr(evs, __ev_name__) - if hasattr(__ev__, 'argspec'): - argspec = __ev__.argspec - else: - argspec = inspect.getargspec(__ev__) - parent_names = filter(lambda x: x in ['input', 'label', 'weight'], - argspec.args) - v2_ev = __convert_to_v2__( - __ev_name__, - parent_names=parent_names, - is_default_name='name' in argspec.args, - attach_parent=True) - __new_name__ = convert_to_new_name(__ev_name__) - globals()[__new_name__] = v2_ev + globals()[__new_name__] = __ev__ globals()[__new_name__].__name__ = __new_name__ __all__.append(__new_name__) diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index 919c531d18..ad36364ca8 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -32,392 +32,39 @@ The primary usage shows below. 
""" import collections -import inspect -import re - -import paddle.trainer_config_helpers as conf_helps -from paddle.trainer.config_parser import \ - RecurrentLayerGroupWithoutOutLinksBegin, RecurrentLayerGroupSetOutLink, \ - RecurrentLayerGroupEnd, model_type -from paddle.trainer_config_helpers.config_parser_utils import \ - parse_network_config as __parse__ -from paddle.trainer_config_helpers.default_decorators import wrap_act_default -from paddle.trainer_config_helpers.default_decorators import \ - wrap_bias_attr_default -from paddle.trainer_config_helpers.default_decorators import wrap_name_default -from paddle.trainer_config_helpers.layers import RecurrentLayerGroupSetGenerator, Generator -from paddle.trainer_config_helpers.layers import layer_support - -import activation -import attr -import data_type -from config_base import Layer, __convert_to_v2__ - -__all__ = ['parse_network', 'data'] +import copy +import paddle.trainer_config_helpers.layers as v1_layers +import paddle.trainer.config_parser as cp +from paddle.proto.ModelConfig_pb2 import ModelConfig, SubModelConfig +__all__ = ['data', 'parse_network'] +__layer_map__ = {} -def parse_network(output_layers, extra_layers=None): - """ - Parse all layers in the neural network graph and - then generate a ModelConfig object. - - .. note:: - - This function is used internally in paddle.v2 module. User should never - invoke this method. - - :param output_layers: Output layers. - :type output_layers: Layer - :param extra_layers: Some layers in the neural network graph are not in the - path of output_layers. - :type extra_layers: Layer - :return: A ModelConfig object instance. - :rtype: ModelConfig - """ - if not isinstance(output_layers, collections.Sequence): - output_layers = [output_layers] - if extra_layers is not None and not isinstance(extra_layers, - collections.Sequence): - extra_layers = [extra_layers] +def __wrap__(f): + def wrapped(*args, **xargs): + out = f(*args, **xargs) + outs = out + if not isinstance(out, collections.Sequence): + outs = [out] + for l in outs: + if isinstance(l, v1_layers.LayerOutput): + __layer_map__[l.full_name] = l + return out - def __real_func__(): - """ - __real_func__ is the function that config_parser.parse invoked. It is - the plain old paddle configuration function. - """ - context = dict() - real_output = [each.to_proto(context=context) for each in output_layers] - if extra_layers is not None: - extra_output = [ - each.to_proto(context=context) for each in extra_layers - ] - conf_helps.outputs(real_output) + return wrapped - return __parse__(__real_func__) +def __need_to_keep__(name): + if name in ['StaticInput', 'LayerType', 'layer_support']: + return False + return True -""" -Some layer may need some special config, and can not use __convert_to_v2__ to convert. -So we also need to implement some special LayerV2. 
-""" +def __need_to_wrap__(name): + return name not in ['AggregateLevel', 'ExpandLevel'] -class DataLayerV2(Layer): - METHOD_NAME = 'data_layer' - - def __init__(self, name, type, **kwargs): - assert isinstance(type, data_type.InputType) - - self.type = type - self.__method_name__ = 'data_layer' - self.__kwargs__ = kwargs - - super(DataLayerV2, self).__init__(name=name, parent_layers=dict()) - - def to_proto_impl(self, **kwargs): - args = dict() - args['size'] = self.type.dim - for each in kwargs: - args[each] = kwargs[each] - for each in self.__kwargs__: - args[each] = self.__kwargs__[each] - return getattr(conf_helps, self.__method_name__)(name=self.name, **args) - - def __map_docstr__(doc): - doc = re.sub(r'(data = [^\)]+)\).*', - "data = paddle.layer.data(name=\"input\", " - "type=paddle.data_type.dense_vector(1000))", doc) - - doc = re.sub(r':param size:.*', - ':param type: Data type of this data layer', doc) - doc = re.sub(r':type size:.*', - ":type size: paddle.v2.data_type.InputType", doc) - return doc - - -class MemoryV2(Layer): - def __init__(self, name, extra_input=None, **kwargs): - """ - Init memory object, if memory is inited inside recurrent_group step - function, it may depend on a boot_layer that should be initialized - outside recurrent_group, so we: - 1. add RecurrentLayerInput to extra_parent of self. - 2. add boot_layer to the extra_parent of RecurrentLayerInput. - - :param extra_input: list of RecurrentLayerInput - :type extra_input: [RecurrentLayerInput] - """ - self.name = name - super(MemoryV2, self).__init__(name=name, parent_layers=dict()) - self.__kwargs__ = kwargs - self.__boot_layer_name__ = None - - if 'boot_layer' in kwargs: - begin_of_current_rnn = [] - # TODO(yuyang18): Fix inspect, it could be wrong when user invoke a - # function inside step. 
- st = inspect.stack() - for i in xrange(len(st)): - locs = inspect.stack()[i][0].f_locals - keys = locs.keys() - for key in keys: - val = locs[key] - if isinstance(val, RecurrentLayerInput): - begin_of_current_rnn.append(val) - elif isinstance(val, collections.Sequence): - for v in val: - if isinstance(v, RecurrentLayerInput): - begin_of_current_rnn.append(v) - - if begin_of_current_rnn: - break - assert begin_of_current_rnn is not None - for extra in begin_of_current_rnn: - self.append_extra_parent(extra) - extra.append_extra_parent(kwargs['boot_layer']) - self.__boot_layer_name__ = kwargs['boot_layer'].name - - def to_proto_impl(self, **kwargs): - args = dict() - for each in kwargs: - args[each] = kwargs[each] - for each in self.__kwargs__: - args[each] = self.__kwargs__[each] - - if self.__boot_layer_name__ is not None: - args['boot_layer'] = self.__context__[self.__boot_layer_name__] - - size = args.get('size', None) - if size is not None: - if callable(size): - real_size = size() - else: - real_size = size - args['size'] = real_size - return conf_helps.memory(name=self.name, **args) - - def context_name(self): - return self.name + "#memory" - - def use_context_name(self): - """ - memory layer will have the same name with some layer - :return: - """ - return True - - -class StaticInputV2(object): - def __init__(self, input, is_seq=False, size=None): - assert isinstance(input, LayerV2) - self.name = input.name - self.input = input - self.is_seq = is_seq - self.size = size - # TODO(add size check) - # assert input.size is not None or size is not None - - -class BaseGeneratedInputV2(object): - def __init__(self): - self.bos_id = None - self.eos_id = None - - def before_real_step(self): - raise NotImplementedError() - - def after_real_step(self, *args): - raise NotImplementedError() - - -class GeneratedInputV2(BaseGeneratedInputV2): - def __init__(self, size, embedding_name, embedding_size): - super(GeneratedInputV2, self).__init__() - self.size = size - self.embedding_name = embedding_name - self.embedding_size = embedding_size - - def after_real_step(self, input): - return max_id(input=input, name='__beam_search_predict__') - - def before_real_step(self): - predict_id = memory( - name='__beam_search_predict__', - size=self.size, - boot_with_const_id=self.bos_id) - - trg_emb = embedding( - input=predict_id, - size=self.embedding_size, - param_attr=attr.ParamAttr(name=self.embedding_name)) - return trg_emb - - -class RecurrentLayerGroupSetGeneratorV2(Layer): - def __init__(self, eos_name, max_length, beam_size, num_results_per_sample): - self.eos_name = eos_name - self.max_length = max_length - self.beam_size = beam_size - self.num_results_per_sample = num_results_per_sample - super(RecurrentLayerGroupSetGeneratorV2, self).__init__( - name=eos_name, parent_layers={}) - - def to_proto_impl(self, **kwargs): - RecurrentLayerGroupSetGenerator( - Generator( - eos_layer_name=self.eos_name, - max_num_frames=self.max_length, - beam_size=self.beam_size, - num_results_per_sample=self.num_results_per_sample)) - return self - - def context_name(self): - return self.eos_name + ".fake" - - def use_context_name(self): - return True - - -class MixedLayerV2(Layer): - """ - This class is use to support `with` grammar. If not, the following code - could convert mixed_layer simply. 
- - mixed = __convert_to_v2__( - 'mixed_layer', name_prefix='mixed', parent_names=['input']) - """ - - class AddToSealedMixedLayerExceptionV2(Exception): - pass - - def __init__(self, - size=0, - input=None, - name=None, - act=None, - bias_attr=None, - layer_attr=None): - self.__method_name__ = 'mixed_layer' - self.finalized = False - self.__inputs__ = [] - if input is not None: - self.__inputs__ = input - - other_kwargs = dict() - other_kwargs['name'] = name - other_kwargs['size'] = size - other_kwargs['act'] = act - other_kwargs['bias_attr'] = bias_attr - other_kwargs['layer_attr'] = layer_attr - parent_layers = {"input": self.__inputs__} - super(MixedLayerV2, self).__init__(name, parent_layers) - self.__other_kwargs__ = other_kwargs - - def __iadd__(self, other): - if not self.finalized: - self.__inputs__.append(other) - return self - else: - raise MixedLayerV2.AddToSealedMixedLayerExceptionV2() - - def __enter__(self): - assert len(self.__inputs__) == 0 - return self - - def __exit__(self, *args, **kwargs): - self.finalized = True - - def to_proto_impl(self, **kwargs): - args = dict() - for each in kwargs: - args[each] = kwargs[each] - for each in self.__other_kwargs__: - args[each] = self.__other_kwargs__[each] - size = args.get('size', None) - if size is not None: - if callable(size): - real_size = size() - else: - real_size = size - args['size'] = real_size - return getattr(conf_helps, self.__method_name__)(**args) - - -@wrap_name_default("mixed") -@wrap_act_default(act=activation.Linear()) -@wrap_bias_attr_default(has_bias=False) -@layer_support(conf_helps.layers.ERROR_CLIPPING, conf_helps.layers.DROPOUT) -def mixed(size=0, - name=None, - input=None, - act=None, - bias_attr=False, - layer_attr=None): - return MixedLayerV2(size, input, name, act, bias_attr, layer_attr) - - -mixed.__doc__ = conf_helps.mixed_layer.__doc__ - - -class RecurrentLayerInput(Layer): - def __init__(self, recurrent_name, index, parent_layers, reverse): - parents_len = len(parent_layers) - assert parents_len <= 1 - if parents_len == 0: - self.__parents__ = [] - else: - self.__parents__ = parent_layers.values()[0] - self.__recurrent_name__ = recurrent_name - self.__reverse__ = reverse - name = self.__parents__[ - index].name if index >= 0 else self.context_name() - super(RecurrentLayerInput, self).__init__( - name=name, parent_layers=parent_layers) - - def context_name(self): - return self.__recurrent_name__ + ".begin" - - def to_proto_impl(self, **kwargs): - model_type('recurrent_nn') - RecurrentLayerGroupWithoutOutLinksBegin( - name=self.__recurrent_name__, - in_links=map(lambda x: x.name, self.__parents__), - seq_reversed=self.__reverse__) - return self - - -class RecurrentLayerOutput(Layer): - def __init__(self, recurrent_name, index, parent_layers): - assert len(parent_layers) == 1 - self.__parents__ = parent_layers.values()[0] - super(RecurrentLayerOutput, self).__init__( - name=self.__parents__[index].name, parent_layers=parent_layers) - self.__recurrent_name__ = recurrent_name - - def context_name(self): - return self.__recurrent_name__ + ".end" - - def to_proto_impl(self, **kwargs): - for l in self.__parents__: - RecurrentLayerGroupSetOutLink(l.name) - RecurrentLayerGroupEnd(name=self.__recurrent_name__) - - -LayerV2 = Layer -data = DataLayerV2 -data.__name__ = 'data' -AggregateLevel = conf_helps.AggregateLevel -ExpandLevel = conf_helps.ExpandLevel -memory = MemoryV2 -memory.__name__ = 'memory' -memory.__doc__ = conf_helps.memory.__doc__ - - -def __layer_name_mapping__(inname): - if inname in 
['data_layer', 'memory', 'mixed_layer', 'recurrent_group']: - # Do Not handle these layers - return - elif inname == 'maxid_layer': +def __convert_name__(inname): + if inname == 'maxid_layer': return 'max_id' elif inname.endswith('memory') or inname.endswith( '_seq') or inname.endswith('_sim') or inname == 'hsigmoid': @@ -431,187 +78,202 @@ def __layer_name_mapping__(inname): return inname elif inname.endswith("_layer"): return inname[:-len("_layer")] + else: + return inname -def __layer_name_mapping_parent_names__(inname): - all_args = getattr(conf_helps, inname).argspec.args - return filter( - lambda x: x in ['input1', 'input2', 'label', 'input', 'a', 'b', - 'expand_as', - 'weights', 'vectors', 'weight', 'score', 'left', - 'right', 'output_mem'], - all_args) - - -def __convert_layer__(_new_name_, _old_name_, _parent_names_): - global __all__ - __all__.append(_new_name_) - globals()[new_name] = __convert_to_v2__(_old_name_, _parent_names_) - globals()[new_name].__name__ = new_name - - -for each_layer_name in dir(conf_helps): - new_name = __layer_name_mapping__(each_layer_name) - if new_name is not None: - parent_names = __layer_name_mapping_parent_names__(each_layer_name) - assert len(parent_names) != 0, each_layer_name - __convert_layer__(new_name, each_layer_name, parent_names) - -del parent_names -del new_name -del each_layer_name - - -@wrap_name_default() -def recurrent_group(step, input, reverse=False, name=None): - if not isinstance(input, collections.Sequence): - input = [input] - - non_static_inputs = filter(lambda x: not isinstance(x, StaticInputV2), - input) - actual_input = [ - RecurrentLayerInput( - recurrent_name=name, - index=i, - parent_layers={'recurrent_inputs': non_static_inputs}, - reverse=reverse) for i in xrange(len(non_static_inputs)) - ] - - extra_input = None - if len(non_static_inputs) == 0: - extra_input = RecurrentLayerInput( - recurrent_name=name, index=-1, parent_layers={}, reverse=reverse) - - def __real_step__(*args): - rnn_input = list(args) - static_inputs = filter(lambda x: isinstance(x, StaticInputV2), input) - for static_input in static_inputs: - mem_name = "__%s_memory__" % static_input.input.name - mem = memory( - name=mem_name, - extra_input=extra_input, - is_seq=static_input.is_seq, - size=static_input.input.calculate_size, - boot_layer=static_input.input) - with mixed( - name=mem_name, - size=static_input.input.calculate_size, - act=activation.Identity()) as mix: - mix += identity_projection(input=mem) - rnn_input.insert(input.index(static_input), mix) - return step(*rnn_input) - - actual_output = __real_step__(*actual_input) - - if not isinstance(actual_output, collections.Sequence): - actual_output = [actual_output] - - retv = [ - RecurrentLayerOutput( - recurrent_name=name, - index=i, - parent_layers={'recurrent_outputs': actual_output}) - for i in xrange(len(actual_output)) - ] - if len(retv) == 1: - return retv[0] +for name in v1_layers.__all__: + obj = getattr(v1_layers, name) + if not __need_to_keep__(name): + continue + new_name = __convert_name__(name) + if callable(obj) and __need_to_wrap__(name): + globals()[new_name] = __wrap__(obj) else: - return retv - - -recurrent_group.__doc__ = conf_helps.recurrent_group.__doc__ - - -@wrap_name_default() -def beam_search(step, - input, - bos_id, - eos_id, - beam_size, - max_length=500, - name=None, - num_results_per_sample=None): - if num_results_per_sample is None: - num_results_per_sample = beam_size - assert num_results_per_sample <= beam_size - # logger.warning("num_results_per_sample should be 
less than beam_size") - - if isinstance(input, StaticInputV2) or isinstance(input, - BaseGeneratedInputV2): - input = [input] - - generated_input_index = -1 - - real_input = [] - for i, each_input in enumerate(input): - assert isinstance(each_input, StaticInputV2) or isinstance( - each_input, BaseGeneratedInputV2) - if isinstance(each_input, BaseGeneratedInputV2): - assert generated_input_index == -1 - generated_input_index = i - else: - real_input.append(each_input) + globals()[new_name] = obj + __all__.append(new_name) + + +def __data_layer__(name, type, **kwargs): + l = v1_layers.data_layer(name, type.dim, **kwargs) + l.data_type = type + return l - assert generated_input_index != -1 +data = __wrap__(__data_layer__) - gipt = input[generated_input_index] - assert isinstance(gipt, BaseGeneratedInputV2) +LayerV2 = v1_layers.LayerOutput - gipt.bos_id = bos_id - gipt.eos_id = eos_id - def __real_step__(*args): - eos_name = "__%s_eos_layer__" % name - generator = RecurrentLayerGroupSetGeneratorV2( - eos_name, max_length, beam_size, num_results_per_sample) +def __get_used_layers__(output_layers, extra_layers=None): + layer_names = set() + parents = {} + def add_parent(child, parent): + if child in parents: + parents[child].append(parent) + else: + parents[child] = [parent] + + def add_additional_parents(): + for sub_model in cp.g_config.model_config.sub_models: + if sub_model.name == 'root': + continue + for link in sub_model.in_links: + add_parent(link.link_name, link.layer_name) + add_parent(sub_model.name, link.layer_name) + for link in sub_model.out_links: + add_parent(link.link_name, link.layer_name) + add_parent(link.link_name, sub_model.name) + for mem in sub_model.memories: + if mem.boot_layer_name: + add_parent(mem.layer_name, mem.boot_layer_name) + add_parent(mem.link_name, mem.layer_name) + + def dfs_travel(layer_name): + if layer_name in layer_names: + return + layer_names.add(layer_name) + layer = cp.g_layer_map[layer_name] + + for inp in layer.inputs: + dfs_travel(inp.input_layer_name) + if layer.name in parents: + for p in parents[layer.name]: + dfs_travel(p) + + add_additional_parents() + + for layer in output_layers: + dfs_travel(layer.full_name) + + return layer_names + + +def __get_used_parameters__(layer_names): + parameter_names = set() + for name in layer_names: + l = cp.g_layer_map[name] + for inp in l.inputs: + if inp.input_parameter_name: + parameter_names.add(inp.input_parameter_name) + if l.bias_parameter_name: + parameter_names.add(l.bias_parameter_name) + return parameter_names + + +def __get_used_submodels__(layer_names): + submodel_names = set() + for submodel in cp.g_config.model_config.sub_models: + if submodel.name in layer_names: + submodel_names.add(submodel.name) + return submodel_names + + +def __get_used_evaluators__(layer_names): + evaluator_names = set() + for e in cp.g_config.model_config.evaluators: + used = True + for name in e.input_layers: + if name not in layer_names: + used = False + break + if used: + evaluator_names.add(e.name) + return evaluator_names + + +def __trim_submodel__(old_submodel, + layer_names, + input_layer_names, + output_layer_names, + evaluator_names): + + submodel = SubModelConfig() + submodel.name = old_submodel.name + submodel.layer_names.extend(filter(lambda x: x in layer_names, + old_submodel.layer_names)) + submodel.input_layer_names.extend(filter(lambda x: x in input_layer_names, + submodel.layer_names)) + submodel.output_layer_names.extend(filter(lambda x: x in output_layer_names, + submodel.layer_names)) + 
submodel.evaluator_names.extend(filter(lambda x: x in evaluator_names, + old_submodel.evaluator_names)) + + submodel.is_recurrent_layer_group = old_submodel.is_recurrent_layer_group + submodel.reversed = old_submodel.reversed + + submodel.memories.extend(filter(lambda x: x.link_name in layer_names, + old_submodel.memories)) + target_inlinkid = (old_submodel.target_inlinkid + if old_submodel.HasField('target_inlinkid') else -1) + in_links = [] + for i, link in enumerate(old_submodel.in_links): + if link.link_name in layer_names or i == target_inlinkid: + in_links.append(link) + if i == target_inlinkid: + target_inlinkid = len(in_links) - 1 + submodel.in_links.extend(in_links) + + submodel.out_links.extend(filter(lambda x: x.link_name in layer_names, + old_submodel.out_links)) + if old_submodel.HasField('generator'): + submodel.generator.CopyFrom(old_submodel.generator) + + if old_submodel.HasField('target_inlinkid'): + submodel.target_inlinkid = target_inlinkid + return submodel - args = list(args) - before_step_layer = gipt.before_real_step() - before_step_layer.append_child( - layer=generator, parent_names=[before_step_layer.name]) - args.insert(generated_input_index, before_step_layer) - predict = gipt.after_real_step(step(*args)) +def parse_network(output_layers, extra_layers=None): + if not isinstance(output_layers, collections.Sequence): + output_layers = [output_layers] + if extra_layers is not None and not isinstance(extra_layers, + collections.Sequence): + extra_layers = [extra_layers] + else: + extra_layers = [] - eos_layer = eos(input=predict, eos_id=eos_id, name=eos_name) - predict.append_child(layer=eos_layer, parent_names=[predict.name]) + layer_names = __get_used_layers__(output_layers + extra_layers) + submodel_names = __get_used_submodels__(layer_names) + submodel_names.add('root') + parameter_names = __get_used_parameters__(layer_names) + evaluator_names = __get_used_evaluators__(layer_names) + input_layer_names = set() + output_layer_names = set() - return predict + model_config = ModelConfig() + model_config.type = cp.g_config.model_config.type + for l in cp.g_config.model_config.layers: + if l.name not in layer_names: + continue + model_config.layers.extend([l]) + if l.type == 'data': + model_config.input_layer_names.append(l.name) + input_layer_names.add(l.name) - # tmp = paddle.layer.recurrent_group( - # step=__real_step__, - # input=real_input, - # reverse=False, - # name=name, - # is_generating=True) - tmp = recurrent_group(step=__real_step__, input=real_input, name=name) + for p in cp.g_config.model_config.parameters: + if p.name in parameter_names: + model_config.parameters.extend([p]) - return tmp + for layer in output_layers: + model_config.output_layer_names.append(layer.full_name) + output_layer_names.add(layer.full_name) + for e in cp.g_config.model_config.evaluators: + if e.name in evaluator_names: + model_config.evaluators.extend([e]) -beam_search.__doc__ = conf_helps.beam_search.__doc__ + for s in cp.g_config.model_config.sub_models: + if s.name in submodel_names: + s = __trim_submodel__( + s, layer_names, input_layer_names, output_layer_names, + evaluator_names) + model_config.sub_models.extend([s]) -__projection_names__ = filter(lambda x: x.endswith('_projection'), - dir(conf_helps)) + return model_config -__all__ += __projection_names__ -__operator_names__ = filter(lambda x: x.endswith('_operator'), dir(conf_helps)) -__all__ += __operator_names__ +def get_layer(name): + return __layer_map__.get(name) -# convert projection -for prj in 
__projection_names__: - globals()[prj] = __convert_to_v2__( - prj, parent_names=['input'], is_default_name=False) - globals()[prj].__name__ = prj -# convert operator -operator_list = [ - # [V1_method_name, parent_names], - ['dotmul_operator', ['a', 'b']], - ['conv_operator', ['img', 'filter']] -] -for op in operator_list: - globals()[op[0]] = __convert_to_v2__( - op[0], parent_names=op[1], is_default_name=False) - globals()[op[0]].__name__ = op[0] +cp.begin_parse() diff --git a/python/paddle/v2/networks.py b/python/paddle/v2/networks.py index 9e6644196c..8ae9f3b202 100644 --- a/python/paddle/v2/networks.py +++ b/python/paddle/v2/networks.py @@ -24,20 +24,7 @@ def __initialize__(): if each_subnetwork in ['inputs', 'outputs']: continue func = getattr(conf_nw, each_subnetwork) - if hasattr(func, 'argspec'): - argspec = func.argspec - else: - argspec = inspect.getargspec(func) - if each_subnetwork == 'simple_attention': - parents = ['encoded_sequence', 'encoded_proj', 'decoder_state'] - else: - parents = filter(lambda x: x.startswith('input'), argspec.args) - assert len(parents) != 0, each_subnetwork - v2_subnet = __convert_to_v2__( - each_subnetwork, - parent_names=parents, - is_default_name='name' in argspec.args) - globals()[each_subnetwork] = v2_subnet + globals()[each_subnetwork] = func globals()[each_subnetwork].__name__ = each_subnetwork global __all__ __all__.append(each_subnetwork) diff --git a/python/paddle/v2/tests/test_layer.py b/python/paddle/v2/tests/test_layer.py index c67f3b84d9..341da1c852 100644 --- a/python/paddle/v2/tests/test_layer.py +++ b/python/paddle/v2/tests/test_layer.py @@ -173,9 +173,9 @@ class OtherLayerTest(unittest.TestCase): class ProjOpTest(unittest.TestCase): def test_projection(self): - input = layer.data(name='data', type=data_type.dense_vector(784)) + input = layer.data(name='data2', type=data_type.dense_vector(784)) word = layer.data( - name='word', type=data_type.integer_value_sequence(10000)) + name='word2', type=data_type.integer_value_sequence(10000)) fc0 = layer.fc(input=input, size=100, act=activation.Sigmoid()) fc1 = layer.fc(input=input, size=200, act=activation.Sigmoid()) mixed0 = layer.mixed( @@ -204,8 +204,8 @@ class ProjOpTest(unittest.TestCase): dotmul1 += dotmul context = layer.context_projection(input=fc0, context_len=5) - context0 = layer.mixed(size=100, input=context) - with layer.mixed(size=100) as context1: + context0 = layer.mixed(size=500, input=context) + with layer.mixed(size=500) as context1: context1 += context conv = layer.conv_projection( @@ -231,8 +231,8 @@ class ProjOpTest(unittest.TestCase): print layer.parse_network(conv1) def test_operator(self): - ipt0 = layer.data(name='data', type=data_type.dense_vector(784)) - ipt1 = layer.data(name='word', type=data_type.dense_vector(128)) + ipt0 = layer.data(name='data1', type=data_type.dense_vector(784)) + ipt1 = layer.data(name='word1', type=data_type.dense_vector(128)) fc0 = layer.fc(input=ipt0, size=100, act=activation.Sigmoid()) fc1 = layer.fc(input=ipt0, size=100, act=activation.Sigmoid()) @@ -261,7 +261,7 @@ class ProjOpTest(unittest.TestCase): class NetworkTests(unittest.TestCase): def test_vgg(self): - img = layer.data(name='pixel', type=data_type.dense_vector(784)) + img = layer.data(name='pixel1', type=data_type.dense_vector(784)) vgg_out = networks.small_vgg( input_image=img, num_channels=1, num_classes=2) print layer.parse_network(vgg_out) @@ -269,12 +269,12 @@ class NetworkTests(unittest.TestCase): class EvaluatorTest(unittest.TestCase): def test_evaluator(self): - img = 
layer.data(name='pixel', type=data_type.dense_vector(784)) + img = layer.data(name='pixel2', type=data_type.dense_vector(784)) output = layer.fc(input=img, size=10, act=activation.Softmax(), name='fc_here') - lbl = layer.data(name='label', type=data_type.integer_value(10)) + lbl = layer.data(name='label2', type=data_type.integer_value(10)) cost = layer.cross_entropy_cost(input=output, label=lbl) evaluator.classification_error(input=output, label=lbl) diff --git a/python/paddle/v2/tests/test_rnn_layer.py b/python/paddle/v2/tests/test_rnn_layer.py index 845277c012..b334f3b1ff 100644 --- a/python/paddle/v2/tests/test_rnn_layer.py +++ b/python/paddle/v2/tests/test_rnn_layer.py @@ -20,6 +20,8 @@ import paddle.v2.data_type as data_type import paddle.v2.layer as layer from paddle.trainer_config_helpers.config_parser_utils import \ parse_network_config as parse_network +from paddle.trainer_config_helpers.config_parser_utils import \ + reset_parser class RNNTest(unittest.TestCase): @@ -29,6 +31,7 @@ class RNNTest(unittest.TestCase): hidden_dim = 8 def parse_old_rnn(): + reset_parser() def step(y): mem = conf_helps.memory(name="rnn_state", size=hidden_dim) out = conf_helps.fc_layer( @@ -48,6 +51,7 @@ class RNNTest(unittest.TestCase): return str(parse_network(test)) def parse_new_rnn(): + reset_parser() def new_step(y): mem = layer.memory(name="rnn_state", size=hidden_dim) out = layer.fc(input=[y, mem], @@ -68,6 +72,7 @@ class RNNTest(unittest.TestCase): parse_new_rnn().splitlines(1)) print ''.join(diff) + def test_sequence_rnn_multi_input(self): dict_dim = 10 word_dim = 8 @@ -75,6 +80,7 @@ class RNNTest(unittest.TestCase): label_dim = 3 def parse_old_rnn(): + reset_parser() def test(): data = conf_helps.data_layer(name="word", size=dict_dim) label = conf_helps.data_layer(name="label", size=label_dim) @@ -114,6 +120,7 @@ class RNNTest(unittest.TestCase): return str(parse_network(test)) def parse_new_rnn(): + reset_parser() data = layer.data( name="word", type=data_type.dense_vector(dict_dim)) label = layer.data( diff --git a/python/paddle/v2/tests/test_topology.py b/python/paddle/v2/tests/test_topology.py index 5c6dbcdb4f..7fd2ee82fd 100644 --- a/python/paddle/v2/tests/test_topology.py +++ b/python/paddle/v2/tests/test_topology.py @@ -46,8 +46,8 @@ class TestTopology(unittest.TestCase): self.assertEqual(label_data_type[1].dim, 10) def test_get_layer(self): - pixel = layer.data(name='pixel', type=data_type.dense_vector(784)) - label = layer.data(name='label', type=data_type.integer_value(10)) + pixel = layer.data(name='pixel2', type=data_type.dense_vector(784)) + label = layer.data(name='label2', type=data_type.integer_value(10)) hidden = layer.fc(input=pixel, size=100, act=conf_helps.SigmoidActivation()) @@ -56,14 +56,14 @@ class TestTopology(unittest.TestCase): act=conf_helps.SoftmaxActivation()) cost = layer.classification_cost(input=inference, label=label) topo = topology.Topology(cost) - pixel_layer = topo.get_layer("pixel") - label_layer = topo.get_layer("label") + pixel_layer = topo.get_layer("pixel2") + label_layer = topo.get_layer("label2") self.assertEqual(pixel_layer, pixel) self.assertEqual(label_layer, label) def test_parse(self): - pixel = layer.data(name='pixel', type=data_type.dense_vector(784)) - label = layer.data(name='label', type=data_type.integer_value(10)) + pixel = layer.data(name='pixel3', type=data_type.dense_vector(784)) + label = layer.data(name='label3', type=data_type.integer_value(10)) hidden = layer.fc(input=pixel, size=100, act=conf_helps.SigmoidActivation()) diff 
--git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index 1e46e4973f..962d5ab76d 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -15,7 +15,7 @@ import collections from paddle.proto.ModelConfig_pb2 import ModelConfig - +import paddle.trainer_config_helpers as conf_helps import layer as v2_layer __all__ = ['Topology'] @@ -94,31 +94,18 @@ class Topology(object): :param name: :return: """ - result_layer = [None] - - def __impl__(l): - if l.name == name: - result_layer[0] = l - return True # break - return False - - __bfs_travel__(__impl__, *self.layers) - if result_layer[0] is None: - raise ValueError("No such layer %s" % name) - return result_layer[0] + return v2_layer.get_layer(name) def data_layers(self): """ get all data layer :return: """ - data_layers = dict() - - def __impl__(l): - if isinstance(l, v2_layer.DataLayerV2): - data_layers[l.name] = l - - __bfs_travel__(__impl__, *self.layers) + data_layers = {} + for layer in self.proto().layers: + l = v2_layer.get_layer(layer.name) + if l and l.layer_type == conf_helps.LayerType.DATA: + data_layers[layer.name] = l return data_layers def data_type(self): @@ -127,7 +114,7 @@ class Topology(object): [('image', dense_vector(768)), ('label', integer_value(10))] """ data_layers = self.data_layers() - return [(nm, data_layers[nm].type) + return [(nm, data_layers[nm].data_type) for nm in self.proto().input_layer_names] def get_layer_proto(self, name): -- GitLab
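A minimal end-to-end sketch of how the reworked V2 layer flow is meant to be used, modeled on test_layer.py and test_topology.py above. The layer names ('pixel', 'label') and sizes are illustrative; reset_parser() is the helper added to config_parser_utils in this patch and clears the global parser state before a new graph is built.

    import paddle.v2.activation as activation
    import paddle.v2.data_type as data_type
    import paddle.v2.layer as layer
    import paddle.v2.topology as topology
    from paddle.trainer_config_helpers.config_parser_utils import reset_parser

    reset_parser()  # start from a clean global config
    pixel = layer.data(name='pixel', type=data_type.dense_vector(784))
    label = layer.data(name='label', type=data_type.integer_value(10))
    hidden = layer.fc(input=pixel, size=100, act=activation.Sigmoid())
    inference = layer.fc(input=hidden, size=10, act=activation.Softmax())
    cost = layer.classification_cost(input=inference, label=label)

    # parse_network prunes the globally registered config down to the layers,
    # parameters, sub-models and evaluators reachable from the given outputs
    # and returns a ModelConfig proto.
    print layer.parse_network(cost)

    # Topology.get_layer is now a plain lookup into the v2 layer map.
    topo = topology.Topology(cost)
    assert topo.get_layer('pixel') == pixel

parse_network also accepts extra_layers for nodes that are not on the path to any output, mirroring the previous behaviour.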
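The same style applies inside recurrent groups; the sketch below follows parse_new_rnn in test_rnn_layer.py. Because recurrent_group now rewrites each output's full_name to its name outside the group, parse_network can be called directly on the group's result. hidden_dim and the layer names here are illustrative.

    import paddle.v2.activation as activation
    import paddle.v2.data_type as data_type
    import paddle.v2.layer as layer
    from paddle.trainer_config_helpers.config_parser_utils import reset_parser

    hidden_dim = 8
    reset_parser()

    def step(y):
        # the memory shares its name with the fc output, closing the recurrence
        mem = layer.memory(name="rnn_state", size=hidden_dim)
        return layer.fc(input=[y, mem],
                        size=hidden_dim,
                        act=activation.Tanh(),
                        bias_attr=True,
                        name="rnn_state")

    seq = layer.data(name="word", type=data_type.dense_vector(hidden_dim))
    rnn = layer.recurrent_group(name="rnn", step=step, input=seq)
    print layer.parse_network(rnn)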
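parse_optimizer_config no longer runs the full config parser; it resets config_parser.settings to default_settings, runs the given function, and copies every non-None entry into an OptimizationConfig proto. A sketch, assuming settings() and AdamOptimizer are importable from paddle.trainer_config_helpers as in the v1 helpers; the values are illustrative.

    from paddle.trainer_config_helpers import settings, AdamOptimizer
    from paddle.trainer_config_helpers.config_parser_utils import \
        parse_optimizer_config

    def optimizer_conf():
        # fills paddle.trainer.config_parser.settings in place
        settings(batch_size=128,
                 learning_rate=1e-3,
                 learning_method=AdamOptimizer())

    opt_config = parse_optimizer_config(optimizer_conf)
    print opt_config.learning_rate, opt_config.learning_method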