From f9ea586431fe1c1851d3a59790c9e6197f69b60a Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Thu, 16 Feb 2017 15:19:29 +0800
Subject: [PATCH] Add get/set method

---
 demo/mnist/api_train_v2.py     |  6 +++---
 python/paddle/v2/parameters.py | 14 +++++++++++++-
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 2c0394aa0b1..5e46d510ad3 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -28,16 +28,16 @@ def main():
     topology = parse_network_config(network_config)
     parameters = paddle.parameters.create(topology)
     for param_name in parameters.keys():
-        array = parameters[param_name]
+        array = parameters.get(param_name)
         array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape)
-        parameters[param_name] = array
+        parameters.set(parameter_name=param_name, value=array)
 
     adam_optimizer = paddle.optimizer.Optimizer(
         learning_rate=0.01, learning_method=AdamOptimizer())
 
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
-            para = parameters['___fc_layer_2__.w0']
+            para = parameters.get('___fc_layer_2__.w0')
             print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
                 event.pass_id, event.batch_id, event.cost, para.mean())
 
diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py
index c2e74b8fb12..55e732a320b 100644
--- a/python/paddle/v2/parameters.py
+++ b/python/paddle/v2/parameters.py
@@ -26,6 +26,10 @@ def create(*topologies):
 
 
 class Parameters(object):
+    """
+    The parameters
+    """
+
     def __init__(self):
         self.__param_conf__ = dict()
         self.__gradient_machines__ = []
@@ -66,7 +70,8 @@ class Parameters(object):
                 assert isinstance(param, api.Parameter)
                 val = param.getBuf(api.PARAMETER_VALUE)
                 assert isinstance(val, api.Vector)
-                return val.copyToNumpyArray().reshape(shape=shape)
+                val = val.copyToNumpyArray()
+                return val
                 # else continue
 
             raise RuntimeError("Unexpected branch")
@@ -96,6 +101,12 @@ class Parameters(object):
                 __copy_parameter_to_gradient_machine__(each_gradient_machine,
                                                        key, value)
 
+    def get(self, parameter_name):
+        return self.__getitem__(key=parameter_name)
+
+    def set(self, parameter_name, value):
+        self.__setitem__(key=parameter_name, value=value)
+
     def append_gradient_machine(self, gradient_machine):
         if not isinstance(gradient_machine, api.GradientMachine):
             raise ValueError("gradient_machine should be api.GradientMachine")
@@ -108,6 +119,7 @@ class Parameters(object):
             except ValueError:
                 # If no such parameter in gradient machine, then don't copy
                 pass
+        self.__gradient_machines__.append(gradient_machine)
 
 
 def __get_parameter_in_gradient_machine__(gradient_machine, name):
-- 
GitLab
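
A minimal usage sketch of the new get/set accessors follows (hypothetical snippet, not part of the patch; it assumes a `topology` produced by `parse_network_config` as in the demo above and the Python 2-era paddle.v2 API):

    import numpy
    import paddle.v2 as paddle

    parameters = paddle.parameters.create(topology)
    for name in parameters.keys():
        # get() returns a numpy copy of the parameter's current value
        value = parameters.get(name)
        value[:] = numpy.random.uniform(low=-1.0, high=1.0, size=value.shape)
        # set() writes the array back through __setitem__, copying it
        # into every attached gradient machine
        parameters.set(parameter_name=name, value=value)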