From 8dc4c0538efa126e771ed807a62be50dc123d7a2 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 16 Feb 2017 16:37:16 +0800
Subject: [PATCH] Add comments

---
 python/paddle/v2/event.py      |   3 +
 python/paddle/v2/parameters.py | 117 +++++++++++++++++++++++++++++++--
 python/paddle/v2/trainer.py    |  25 ++++++-
 3 files changed, 138 insertions(+), 7 deletions(-)

diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py
index 04158f47652..f6c69574386 100644
--- a/python/paddle/v2/event.py
+++ b/python/paddle/v2/event.py
@@ -1,3 +1,6 @@
+"""
+All training events.
+"""
 __all__ = ['EndIteration']
 
diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py
index 55e732a320b..e5b7dabcb8e 100644
--- a/python/paddle/v2/parameters.py
+++ b/python/paddle/v2/parameters.py
@@ -21,13 +21,29 @@ def create(*topologies):
             'create must pass a topologies which type is ModelConfig')
         for param in topo.parameters:
-            pool.append_config(param)
+            pool.__append_config__(param)
     return pool
 
 
 class Parameters(object):
     """
-    The parameters
+    Parameters is a dictionary that contains Paddle's parameters. Its keys
+    are parameter names and its values are plain :code:`numpy.ndarray`
+    objects.
+
+    Basic usage:
+
+    .. code-block:: python
+
+        data = paddle.layers.data(...)
+        ...
+        out = paddle.layers.fc(...)
+
+        parameters = paddle.parameters.create(out)
+
+        parameter_names = parameters.names()
+        fc_mat = parameters.get('fc')
+        print fc_mat
     """
 
     def __init__(self):
@@ -35,7 +51,16 @@ class Parameters(object):
         self.__gradient_machines__ = []
         self.__tmp_params__ = []
 
-    def append_config(self, param_conf):
+    def __append_config__(self, param_conf):
+        """
+        Append a parameter configuration. It is used to initialize Parameters
+        and should only be invoked by paddle.parameters.create.
+
+        :param param_conf: the parameter configuration in protobuf
+        :type param_conf: ParameterConfig
+        :return: Nothing
+        """
+
         if not isinstance(param_conf, ParameterConfig):
             raise ValueError("param_conf must be paddle.proto.ParameterConfig")
 
@@ -45,18 +70,55 @@ class Parameters(object):
         self.__param_conf__[param_conf.name] = param_conf
 
     def keys(self):
+        """
+        Return the names of all parameters.
+
+        :return: list of parameter names
+        :rtype: list
+        """
         return self.__param_conf__.keys()
 
     def names(self):
+        """
+        Return the names of all parameters; an alias of keys().
+
+        :return: list of parameter names
+        :rtype: list
+        """
         return self.keys()
 
     def has_key(self, key):
+        """
+        Return True if a parameter named `key` exists.
+
+        :param key: parameter name
+        :type key: basestring
+        :return: True if such a parameter exists
+        """
         return key in self.__param_conf__.keys()
 
     def __iter__(self):
+        """
+        Return an iterator over parameter names. It is used by the `for`
+        loop and the `in` operator.
+
+        .. code-block:: python
+
+            parameters = paddle.parameters.create(...)
+            if "fc_param" in parameters:
+                print 'OK'
+
+        :return: an iterator over parameter names
+        :rtype: iterator
+        """
         return iter(self.__param_conf__)
 
     def __getitem__(self, key):
+        """
+        Get a parameter by name, using Python dict syntax.
+
+        :note: It always copies the parameter from the C++ side.
+        :param key: parameter name
+        :type key: basestring
+        :return: parameter value
+        :rtype: np.ndarray
+        """
         shape = self.get_shape(key)
 
         if len(self.__gradient_machines__) == 0:
@@ -77,20 +139,37 @@ class Parameters(object):
         raise RuntimeError("Unexpected branch")
 
     def get_shape(self, key):
+        """
+        Get the shape of the parameter.
+
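+        For example, a minimal sketch (the parameter name 'fc.w0' is an
+        illustrative assumption; real names depend on the topology):
+
+        .. code-block:: python
+
+            shape = parameters.get_shape('fc.w0')  # a tuple, e.g. (784, 10)
+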
+        :param key: parameter name
+        :type key: basestring
+        :return: the shape of the parameter
+        :rtype: tuple
+        """
         if not isinstance(key, basestring):
             raise ValueError("parameter name should be string")
         if not self.has_key(key):
             raise ValueError("No such parameter %s" % key)
         conf = self.__param_conf__[key]
-        return map(int, conf.dims)
+        return tuple(map(int, conf.dims))
 
     def __setitem__(self, key, value):
+        """
+        Set a parameter by name and value, using Python dict syntax.
+
+        :note: It always copies the parameter to the C++ side.
+        :param key: parameter name
+        :type key: basestring
+        :param value: parameter matrix
+        :type value: np.ndarray
+        :return: Nothing
+        """
+
         if not isinstance(value, np.ndarray):
             raise ValueError("Must return ndarray")
         value = value.astype(dtype=np.float32)
         shape = self.get_shape(key)
-        if not reduce(lambda a, b: a and b,
-                      map(lambda x: x[0] == x[1], zip(value.shape, shape))):
+        if value.shape != shape:
             raise ValueError("Value shape mismatch, expect %s, should %s" %
                              (shape, value.shape))
 
@@ -102,12 +181,38 @@ class Parameters(object):
                                                        key, value)
 
     def get(self, parameter_name):
+        """
+        Get a parameter by name.
+
+        :note: It always copies the parameter from the C++ side.
+        :param parameter_name: parameter name
+        :type parameter_name: basestring
+        :return: the parameter matrix
+        :rtype: np.ndarray
+        """
         return self.__getitem__(key=parameter_name)
 
     def set(self, parameter_name, value):
+        """
+        Set a parameter by name and matrix.
+
+        :param parameter_name: parameter name
+        :type parameter_name: basestring
+        :param value: parameter matrix
+        :type value: np.ndarray
+        :return: Nothing
+        """
         self.__setitem__(key=parameter_name, value=value)
 
     def append_gradient_machine(self, gradient_machine):
+        """
+        Append a gradient machine to Parameters. This method is used
+        internally in Trainer.train.
+
+        :param gradient_machine: Paddle C++ GradientMachine object
+        :type gradient_machine: api.GradientMachine
+        :return: Nothing
+        """
+
         if not isinstance(gradient_machine, api.GradientMachine):
             raise ValueError("gradient_machine should be api.GradientMachine")
 
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index a29c3a05f85..9ba13dc5c8a 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -12,16 +12,38 @@ __all__ = ['ITrainer', 'SGD']
 
 def default_event_handler(event):
+    """
+    Default event handler. It will print some logs and save the model.
+
+    TODO(yuyang18): Complete it!
+    :param event: the training event
+    :return:
+    """
     pass
 
 
 class ITrainer(object):
+    """
+    The interface of Trainer. The only exposed method is `train`.
+    """
+
     def train(self,
               train_data_reader,
               topology,
               parameters,
               test_data_reader=None,
               event_handler=None):
+        """
+        Train the topology on the given data.
+
+        :param train_data_reader: a reader that yields training data
+        :param topology: the network topology to train
+        :param parameters: the parameters of the model
+        :param test_data_reader: an optional reader that yields test data
+        :param event_handler: an optional callback invoked on training events
+        :return:
+        """
+
         raise NotImplementedError()
 
 
@@ -30,7 +52,8 @@ class SGD(ITrainer):
     """
     Simple SGD Trainer.
 
-    :param update_equation: Maybe we should give a DSL for update equation?
+    :param update_equation: The optimizer object.
+    :type update_equation: v2_optimizer.Optimizer
     """
         if not isinstance(update_equation, v2_optimizer.Optimizer):
             raise ValueError("update equation parameter must be "
--
GitLab
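
For context, the docstrings added above describe a dict-like workflow. The
sketch below strings the documented calls together in the same style as the
class-level example in the patch. It is illustrative only: the layer
definitions are elided just as in the original example, and the parameter
name 'fc.w0' is an assumption, since real parameter names depend on the
topology.

.. code-block:: python

    data = paddle.layers.data(...)
    ...
    out = paddle.layers.fc(...)

    parameters = paddle.parameters.create(out)

    # Dict-style queries documented in this patch.
    print parameters.names()              # list of parameter names
    print 'fc.w0' in parameters           # membership test via __iter__
    print parameters.get_shape('fc.w0')   # a tuple such as (784, 10)

    # get/__getitem__ copy from the C++ side; set/__setitem__ copy back.
    fc_mat = parameters['fc.w0']
    parameters['fc.w0'] = fc_mat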