提交 8dc4c053 编写于 作者: Y Yu Yang

Add comments

上级 f9ea5864
"""
All training events.
"""
__all__ = ['EndIteration'] __all__ = ['EndIteration']
......
...@@ -21,13 +21,29 @@ def create(*topologies): ...@@ -21,13 +21,29 @@ def create(*topologies):
'create must pass a topologies which type is ModelConfig') 'create must pass a topologies which type is ModelConfig')
for param in topo.parameters: for param in topo.parameters:
pool.append_config(param) pool.__append_config__(param)
return pool return pool
class Parameters(object): class Parameters(object):
""" """
The parameters Parameters is a dictionary contains Paddle's parameter. The key of
Parameters is the name of parameter. The value of Parameters is a plain
:code:`numpy.ndarry` .
Basically usage is
.. code-block:: python
data = paddle.layers.data(...)
...
out = paddle.layers.fc(...)
parameters = paddle.parameters.create(out)
parameter_names = parameters.names()
fc_mat = parameters.get('fc')
print fc_mat
""" """
def __init__(self): def __init__(self):
...@@ -35,7 +51,16 @@ class Parameters(object): ...@@ -35,7 +51,16 @@ class Parameters(object):
self.__gradient_machines__ = [] self.__gradient_machines__ = []
self.__tmp_params__ = [] self.__tmp_params__ = []
def append_config(self, param_conf): def __append_config__(self, param_conf):
"""
Append a parameter configuration. It used to initialize Parameters and
should be invoked only in paddle.parameters.create
:param param_conf: The parameter configuration in protobuf
:type param_conf: ParameterConfig
:return: Nothing
"""
if not isinstance(param_conf, ParameterConfig): if not isinstance(param_conf, ParameterConfig):
raise ValueError("param_conf must be paddle.proto.ParameterConfig") raise ValueError("param_conf must be paddle.proto.ParameterConfig")
...@@ -45,18 +70,55 @@ class Parameters(object): ...@@ -45,18 +70,55 @@ class Parameters(object):
self.__param_conf__[param_conf.name] = param_conf self.__param_conf__[param_conf.name] = param_conf
def keys(self): def keys(self):
"""
keys are the names of each parameter.
:return: list of parameter name
:rtype: list
"""
return self.__param_conf__.keys() return self.__param_conf__.keys()
def names(self): def names(self):
"""
names of each parameter.
:return: list of parameter name
:rtype: list
"""
return self.keys() return self.keys()
def has_key(self, key): def has_key(self, key):
"""
has_key return true if there are such parameter name == key
:param key: Parameter name
:type key: basestring
:return: True if contains such key
"""
return key in self.__param_conf__.keys() return key in self.__param_conf__.keys()
def __iter__(self): def __iter__(self):
"""
Return an iterator of parameter name. It is used by `for loop`
or `in` operator.
.. code-block:: python
parameters = paddle.parameters.create(...)
if "fc_param" in parameters:
print 'OK'
:return: an iterator of parameter name
:rtype: iterator
"""
return iter(self.__param_conf__) return iter(self.__param_conf__)
def __getitem__(self, key): def __getitem__(self, key):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
shape = self.get_shape(key) shape = self.get_shape(key)
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
...@@ -77,20 +139,37 @@ class Parameters(object): ...@@ -77,20 +139,37 @@ class Parameters(object):
raise RuntimeError("Unexpected branch") raise RuntimeError("Unexpected branch")
def get_shape(self, key): def get_shape(self, key):
"""
get shape of the parameter.
:param key: parameter name
:type key: basestring
:return: parameter's shape
:rtype: tuple
"""
if not isinstance(key, basestring): if not isinstance(key, basestring):
raise ValueError("parameter name should be string") raise ValueError("parameter name should be string")
if not self.has_key(key): if not self.has_key(key):
raise ValueError("No such parameter %s" % key) raise ValueError("No such parameter %s" % key)
conf = self.__param_conf__[key] conf = self.__param_conf__[key]
return map(int, conf.dims) return tuple(map(int, conf.dims))
def __setitem__(self, key, value): def __setitem__(self, key, value):
"""
Set parameter by parameter name & value. It use Python dict syntax.
:note: It will always copy the parameter to C++ side.
:param key: Parameter name
:type key: basestring
:param value: Parameter matrix.
:type value: np.ndarray
:return: Nothing
"""
if not isinstance(value, np.ndarray): if not isinstance(value, np.ndarray):
raise ValueError("Must return ndarray") raise ValueError("Must return ndarray")
value = value.astype(dtype=np.float32) value = value.astype(dtype=np.float32)
shape = self.get_shape(key) shape = self.get_shape(key)
if not reduce(lambda a, b: a and b, if value.shape != shape:
map(lambda x: x[0] == x[1], zip(value.shape, shape))):
raise ValueError("Value shape mismatch, expect %s, should %s" % raise ValueError("Value shape mismatch, expect %s, should %s" %
(shape, value.shape)) (shape, value.shape))
...@@ -102,12 +181,38 @@ class Parameters(object): ...@@ -102,12 +181,38 @@ class Parameters(object):
key, value) key, value)
def get(self, parameter_name): def get(self, parameter_name):
"""
Get parameter by parameter name.
:note: It will always copy the parameter from C++ side.
:param parameter_name: parameter name
:type parameter_name: basestring
:return: The parameter matrix.
:rtype: np.ndarray
"""
return self.__getitem__(key=parameter_name) return self.__getitem__(key=parameter_name)
def set(self, parameter_name, value): def set(self, parameter_name, value):
"""
Set parameter by parameter name & matrix.
:param parameter_name: parameter name
:type parameter_name: basestring
:param value: parameter matrix
:type value: np.ndarray
:return: Nothing.
"""
self.__setitem__(key=parameter_name, value=value) self.__setitem__(key=parameter_name, value=value)
def append_gradient_machine(self, gradient_machine): def append_gradient_machine(self, gradient_machine):
"""
append gradient machine to parameters. This method is used internally in
Trainer.train.
:param gradient_machine: Paddle C++ GradientMachine object.
:type gradient_machine: api.GradientMachine
:return:
"""
if not isinstance(gradient_machine, api.GradientMachine): if not isinstance(gradient_machine, api.GradientMachine):
raise ValueError("gradient_machine should be api.GradientMachine") raise ValueError("gradient_machine should be api.GradientMachine")
......
...@@ -12,16 +12,38 @@ __all__ = ['ITrainer', 'SGD'] ...@@ -12,16 +12,38 @@ __all__ = ['ITrainer', 'SGD']
def default_event_handler(event): def default_event_handler(event):
"""
Default event handler. It will print some log and save mode.
TODO(yuyang18): Complete it!
:param event:
:return:
"""
pass pass
class ITrainer(object): class ITrainer(object):
"""
The interface of Trainer. The only exposed method is `train`.
"""
def train(self, def train(self,
train_data_reader, train_data_reader,
topology, topology,
parameters, parameters,
test_data_reader=None, test_data_reader=None,
event_handler=None): event_handler=None):
"""
train method.
:param train_data_reader:
:param topology:
:param parameters:
:param test_data_reader:
:param event_handler:
:return:
"""
raise NotImplementedError() raise NotImplementedError()
...@@ -30,7 +52,8 @@ class SGD(ITrainer): ...@@ -30,7 +52,8 @@ class SGD(ITrainer):
""" """
Simple SGD Trainer. Simple SGD Trainer.
:param update_equation: Maybe we should give a DSL for update equation? :param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer
""" """
if not isinstance(update_equation, v2_optimizer.Optimizer): if not isinstance(update_equation, v2_optimizer.Optimizer):
raise ValueError("update equation parameter must be " raise ValueError("update equation parameter must be "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册