提交 f9ea5864 编写于 作者: Y Yu Yang

Add get/set method

上级 9646f2d3
......@@ -28,16 +28,16 @@ def main():
topology = parse_network_config(network_config)
parameters = paddle.parameters.create(topology)
for param_name in parameters.keys():
array = parameters[param_name]
array = parameters.get(param_name)
array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
parameters[param_name] = array
parameters.set(parameter_name=param_name, value=array)
adam_optimizer = paddle.optimizer.Optimizer(
learning_rate=0.01, learning_method=AdamOptimizer())
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
para = parameters['___fc_layer_2__.w0']
para = parameters.get('___fc_layer_2__.w0')
print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
event.pass_id, event.batch_id, event.cost, para.mean())
......
......@@ -26,6 +26,10 @@ def create(*topologies):
class Parameters(object):
"""
The parameters
"""
def __init__(self):
self.__param_conf__ = dict()
self.__gradient_machines__ = []
......@@ -66,7 +70,8 @@ class Parameters(object):
assert isinstance(param, api.Parameter)
val = param.getBuf(api.PARAMETER_VALUE)
assert isinstance(val, api.Vector)
return val.copyToNumpyArray().reshape(shape=shape)
val = val.copyToNumpyArray()
return val
# else continue
raise RuntimeError("Unexpected branch")
......@@ -96,6 +101,12 @@ class Parameters(object):
__copy_parameter_to_gradient_machine__(each_gradient_machine,
key, value)
def get(self, parameter_name):
return self.__getitem__(key=parameter_name)
def set(self, parameter_name, value):
self.__setitem__(key=parameter_name, value=value)
def append_gradient_machine(self, gradient_machine):
if not isinstance(gradient_machine, api.GradientMachine):
raise ValueError("gradient_machine should be api.GradientMachine")
......@@ -108,6 +119,7 @@ class Parameters(object):
except ValueError:
# If no such parameter in gradient machine, then don't copy
pass
self.__gradient_machines__.append(gradient_machine)
def __get_parameter_in_gradient_machine__(gradient_machine, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册