提交 f9ea5864 编写于 作者: Y Yu Yang

Add get/set method

上级 9646f2d3
...@@ -28,16 +28,16 @@ def main(): ...@@ -28,16 +28,16 @@ def main():
topology = parse_network_config(network_config) topology = parse_network_config(network_config)
parameters = paddle.parameters.create(topology) parameters = paddle.parameters.create(topology)
for param_name in parameters.keys(): for param_name in parameters.keys():
array = parameters[param_name] array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
parameters[param_name] = array parameters.set(parameter_name=param_name, value=array)
adam_optimizer = paddle.optimizer.Optimizer( adam_optimizer = paddle.optimizer.Optimizer(
learning_rate=0.01, learning_method=AdamOptimizer()) learning_rate=0.01, learning_method=AdamOptimizer())
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
para = parameters['___fc_layer_2__.w0'] para = parameters.get('___fc_layer_2__.w0')
print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
event.pass_id, event.batch_id, event.cost, para.mean()) event.pass_id, event.batch_id, event.cost, para.mean())
......
...@@ -26,6 +26,10 @@ def create(*topologies): ...@@ -26,6 +26,10 @@ def create(*topologies):
class Parameters(object): class Parameters(object):
"""
The parameters
"""
def __init__(self): def __init__(self):
self.__param_conf__ = dict() self.__param_conf__ = dict()
self.__gradient_machines__ = [] self.__gradient_machines__ = []
...@@ -66,7 +70,8 @@ class Parameters(object): ...@@ -66,7 +70,8 @@ class Parameters(object):
assert isinstance(param, api.Parameter) assert isinstance(param, api.Parameter)
val = param.getBuf(api.PARAMETER_VALUE) val = param.getBuf(api.PARAMETER_VALUE)
assert isinstance(val, api.Vector) assert isinstance(val, api.Vector)
return val.copyToNumpyArray().reshape(shape=shape) val = val.copyToNumpyArray()
return val
# else continue # else continue
raise RuntimeError("Unexpected branch") raise RuntimeError("Unexpected branch")
...@@ -96,6 +101,12 @@ class Parameters(object): ...@@ -96,6 +101,12 @@ class Parameters(object):
__copy_parameter_to_gradient_machine__(each_gradient_machine, __copy_parameter_to_gradient_machine__(each_gradient_machine,
key, value) key, value)
def get(self, parameter_name):
return self.__getitem__(key=parameter_name)
def set(self, parameter_name, value):
self.__setitem__(key=parameter_name, value=value)
def append_gradient_machine(self, gradient_machine): def append_gradient_machine(self, gradient_machine):
if not isinstance(gradient_machine, api.GradientMachine): if not isinstance(gradient_machine, api.GradientMachine):
raise ValueError("gradient_machine should be api.GradientMachine") raise ValueError("gradient_machine should be api.GradientMachine")
...@@ -108,6 +119,7 @@ class Parameters(object): ...@@ -108,6 +119,7 @@ class Parameters(object):
except ValueError: except ValueError:
# If no such parameter in gradient machine, then don't copy # If no such parameter in gradient machine, then don't copy
pass pass
self.__gradient_machines__.append(gradient_machine)
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册