提交 8dc4c053 编写于 作者: Y Yu Yang

Add comments

上级 f9ea5864
"""
All training events.
"""
__all__ = ['EndIteration']
......
......@@ -21,13 +21,29 @@ def create(*topologies):
'create must pass a topologies which type is ModelConfig')
for param in topo.parameters:
pool.append_config(param)
pool.__append_config__(param)
return pool
class Parameters(object):
"""
The parameters
Parameters is a dictionary contains Paddle's parameter. The key of
Parameters is the name of parameter. The value of Parameters is a plain
:code:`numpy.ndarry` .
Basically usage is
.. code-block:: python
data = paddle.layers.data(...)
...
out = paddle.layers.fc(...)
parameters = paddle.parameters.create(out)
parameter_names = parameters.names()
fc_mat = parameters.get('fc')
print fc_mat
"""
def __init__(self):
......@@ -35,7 +51,16 @@ class Parameters(object):
self.__gradient_machines__ = []
self.__tmp_params__ = []
def append_config(self, param_conf):
def __append_config__(self, param_conf):
"""
Append a parameter configuration. It used to initialize Parameters and
should be invoked only in paddle.parameters.create
:param param_conf: The parameter configuration in protobuf
:type param_conf: ParameterConfig
:return: Nothing
"""
if not isinstance(param_conf, ParameterConfig):
raise ValueError("param_conf must be paddle.proto.ParameterConfig")
......@@ -45,18 +70,55 @@ class Parameters(object):
self.__param_conf__[param_conf.name] = param_conf
def keys(self):
"""
keys are the names of each parameter.
:return: list of parameter name
:rtype: list
"""
return self.__param_conf__.keys()
def names(self):
"""
names of each parameter.
:return: list of parameter name
:rtype: list
"""
return self.keys()
def has_key(self, key):
"""
has_key return true if there are such parameter name == key
:param key: Parameter name
:type key: basestring
:return: True if contains such key
"""
return key in self.__param_conf__.keys()
def __iter__(self):
"""
Return an iterator of parameter name. It is used by `for loop`
or `in` operator.
.. code-block:: python
parameters = paddle.parameters.create(...)
if "fc_param" in parameters:
print 'OK'
:return: an iterator of parameter name
:rtype: iterator
"""
return iter(self.__param_conf__)
def __getitem__(self, key):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
shape = self.get_shape(key)
if len(self.__gradient_machines__) == 0:
......@@ -77,20 +139,37 @@ class Parameters(object):
raise RuntimeError("Unexpected branch")
def get_shape(self, key):
"""
get shape of the parameter.
:param key: parameter name
:type key: basestring
:return: parameter's shape
:rtype: tuple
"""
if not isinstance(key, basestring):
raise ValueError("parameter name should be string")
if not self.has_key(key):
raise ValueError("No such parameter %s" % key)
conf = self.__param_conf__[key]
return map(int, conf.dims)
return tuple(map(int, conf.dims))
def __setitem__(self, key, value):
"""
Set parameter by parameter name & value. It use Python dict syntax.
:note: It will always copy the parameter to C++ side.
:param key: Parameter name
:type key: basestring
:param value: Parameter matrix.
:type value: np.ndarray
:return: Nothing
"""
if not isinstance(value, np.ndarray):
raise ValueError("Must return ndarray")
value = value.astype(dtype=np.float32)
shape = self.get_shape(key)
if not reduce(lambda a, b: a and b,
map(lambda x: x[0] == x[1], zip(value.shape, shape))):
if value.shape != shape:
raise ValueError("Value shape mismatch, expect %s, should %s" %
(shape, value.shape))
......@@ -102,12 +181,38 @@ class Parameters(object):
key, value)
def get(self, parameter_name):
"""
Get parameter by parameter name.
:note: It will always copy the parameter from C++ side.
:param parameter_name: parameter name
:type parameter_name: basestring
:return: The parameter matrix.
:rtype: np.ndarray
"""
return self.__getitem__(key=parameter_name)
def set(self, parameter_name, value):
"""
Set parameter by parameter name & matrix.
:param parameter_name: parameter name
:type parameter_name: basestring
:param value: parameter matrix
:type value: np.ndarray
:return: Nothing.
"""
self.__setitem__(key=parameter_name, value=value)
def append_gradient_machine(self, gradient_machine):
"""
append gradient machine to parameters. This method is used internally in
Trainer.train.
:param gradient_machine: Paddle C++ GradientMachine object.
:type gradient_machine: api.GradientMachine
:return:
"""
if not isinstance(gradient_machine, api.GradientMachine):
raise ValueError("gradient_machine should be api.GradientMachine")
......
......@@ -12,16 +12,38 @@ __all__ = ['ITrainer', 'SGD']
def default_event_handler(event):
"""
Default event handler. It will print some log and save mode.
TODO(yuyang18): Complete it!
:param event:
:return:
"""
pass
class ITrainer(object):
"""
The interface of Trainer. The only exposed method is `train`.
"""
def train(self,
train_data_reader,
topology,
parameters,
test_data_reader=None,
event_handler=None):
"""
train method.
:param train_data_reader:
:param topology:
:param parameters:
:param test_data_reader:
:param event_handler:
:return:
"""
raise NotImplementedError()
......@@ -30,7 +52,8 @@ class SGD(ITrainer):
"""
Simple SGD Trainer.
:param update_equation: Maybe we should give a DSL for update equation?
:param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer
"""
if not isinstance(update_equation, v2_optimizer.Optimizer):
raise ValueError("update equation parameter must be "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册