From 8e0d1d8be1cd1fb9dd7671e27beb60509735c968 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 13 Feb 2017 17:55:10 +0800
Subject: [PATCH] Complete update equation.

---
 demo/mnist/api_train_v2.py  |  2 +-
 python/paddle/v2/trainer.py | 33 +++++++++++++++++++++++++++------
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 6ceff8284d..ab1581325a 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -49,7 +49,7 @@ def main():
         :return:
         """
         mu = 0.09
-        e = 0.0001
+        e = 0.00001
         vel_t = mu * vel_t_1 - e * g
 
 
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 0375caa552..9620e75420 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -25,7 +25,7 @@ class CompleteTrainOneBatch(BaseEvent):
         self.pass_id = pass_id
         self.batch_id = batch_id
         self.cost = cost
-        self.paramters = parameters
+        self.parameters = parameters
 
 
 def default_event_handler(event):
@@ -44,6 +44,17 @@ class ITrainer(object):
 
 class LazyParameterPool(v2_parameters.IParameterPool):
     """
+    LazyParameterPool stores a reference to a GradientMachine. Users can
+    invoke `get_parameter` when needed, but the operation is lazy: the
+    parameter is only fetched from the GPU or the Parameter Server when
+    `get_parameter` is invoked. Setting flag = writable also triggers an
+    extra host2device copy after the parameter is read or modified.
+
+    This class is not exposed to users, who should treat it as a normal
+    IParameterPool.
+
+    See IParameterPool for usage documentation.
+
     :type __gradient_machine__: api.GradientMachine
     """
 
@@ -130,12 +141,22 @@ class CustomizeUpdateEquation(object):
                 shape)
             g = param.getBuf(api.PARAMETER_GRADIENT).toNumpyArrayInplace(
             ).reshape(shape)
-            args = [v, g]
-            for arg in self.local_params[conf.name]:
-                args.append(arg)
-            self.__callback__(*args)
+
         else:
-            raise NotImplementedError()
+            v = param.getBuf(api.PARAMETER_VALUE).copyToNumpyArray().reshape(
+                shape)
+            g = param.getBuf(api.PARAMETER_GRADIENT).copyToNumpyArray().reshape(
+                shape)
+
+        args = [v, g]
+        for arg in self.local_params[conf.name]:
+            args.append(arg)
+        self.__callback__(*args)
+
+        if api.isUsingGpu():
+            param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(v.flatten(
+            ).astype('float32'))
+            # Changes to the gradient buffer are discarded.
 
 
 class SGDTrainer(ITrainer):
-- 
GitLab
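
Note: the demo's custom update callback implements SGD with momentum. Below is
a minimal NumPy sketch of that update, using the constants from the diff
(mu = 0.09, e = 0.00001). The velocity application step (v += vel_t) is an
assumption based on the standard momentum formulation; the diff itself only
shows the velocity update.

    import numpy as np

    def momentum_step(v, g, vel_t_1, mu=0.09, e=0.00001):
        """One momentum update, mirroring the demo's constants.

        v: parameter values, g: gradient, vel_t_1: previous velocity
        (all NumPy arrays of the same shape). Returns the new velocity.
        """
        vel_t = mu * vel_t_1 - e * g  # velocity update; e is the learning rate
        v += vel_t                    # assumed application step (not in diff)
        return vel_t

    # Example: one update on a toy parameter.
    v = np.zeros(3)
    g = np.array([1.0, -2.0, 0.5])
    vel = momentum_step(v, g, np.zeros(3))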
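
Note: the trainer change unifies the CPU and GPU paths by fetching the
parameter as an in-place NumPy view (CPU) or a detached host copy (GPU),
running the callback on it, and then writing the values back only when
running on a GPU. The sketch below illustrates why that write-back is
needed; FakeBuf is a hypothetical stand-in, not the real Paddle buffer API.

    import numpy as np

    class FakeBuf:
        """Hypothetical stand-in for a Paddle parameter buffer."""

        def __init__(self, data):
            self._data = data

        def toNumpyArrayInplace(self):
            return self._data          # CPU path: shared storage

        def copyToNumpyArray(self):
            return self._data.copy()   # GPU path: detached host copy

        def copyFromNumpyArray(self, arr):
            self._data[...] = arr      # explicit write-back to the buffer

    buf = FakeBuf(np.zeros(4))

    # CPU-style update: writes through the in-place view propagate directly.
    v = buf.toNumpyArrayInplace()
    v += 1.0
    assert buf._data[0] == 1.0

    # GPU-style update: the copy does not propagate until written back.
    v = buf.copyToNumpyArray()
    v += 1.0
    assert buf._data[0] == 1.0         # still unchanged
    buf.copyFromNumpyArray(v)
    assert buf._data[0] == 2.0         # visible after the write-back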