diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 364306d6741e21c1c2724f873d2f3e3c3f92ec72..8d8012e5d5a02f0388b2f88c2eb86a6ef78ccf7b 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -127,16 +127,7 @@ class Parameters(object): """ return iter(self.__param_conf__) - def __getitem__(self, key): - """ - Get parameter by parameter name. It uses Python dict syntax. - - :note: It will always copy the parameter from C++ side. - :param key: Parameter name - :type key: basestring - :return: parameter value - :rtype: np.ndarray - """ + def __getter_inner(self, key, param_type): import py_paddle.swig_paddle as api shape = self.get_shape(key) @@ -152,7 +143,7 @@ class Parameters(object): each_gradient_machine, key) # for simplify implementation now, we always copy from C++ assert isinstance(param, api.Parameter) - val = param.getBuf(api.PARAMETER_VALUE) + val = param.getBuf(param_type) assert isinstance(val, api.Vector) val = val.copyToNumpyArray() return val @@ -160,6 +151,19 @@ class Parameters(object): raise RuntimeError("Unexpected branch") + def __getitem__(self, key): + """ + Get parameter by parameter name. It uses Python dict syntax. + + :note: It will always copy the parameter from C++ side. + :param key: Parameter name + :type key: basestring + :return: parameter value + :rtype: np.ndarray + """ + import py_paddle.swig_paddle as api + return self.__getter_inner(key, api.PARAMETER_VALUE) + def get_shape(self, key): """ get shape of the parameter. @@ -216,6 +220,19 @@ class Parameters(object): """ return self.__getitem__(key=parameter_name) + def get_grad(self, key): + """ + Get grandient by parameter name. + + :note: It will always copy the parameter from C++ side. + :param key: parameter name + :type key: basestring + :return: The grandient matrix. + :rtype: np.ndarray + """ + import py_paddle.swig_paddle as api + return self.__getter_inner(key, api.PARAMETER_GRADIENT) + def set(self, parameter_name, value): """ Set parameter by parameter name & matrix. diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 76bae0bb12b6c33f88530386f9cc19ae9b59f457..9c4dd5f25083d210bbd218a85d8dbb3cce2c3d0e 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -161,14 +161,14 @@ class SGD(object): self.__parameter_updater__.update(each_param) cost_sum = out_args.sum() cost = cost_sum / len(data_batch) - self.__parameter_updater__.finishBatch(cost) - batch_evaluator.finish() event_handler( v2_event.EndIteration( pass_id=pass_id, batch_id=batch_id, cost=cost, evaluator=batch_evaluator)) + self.__parameter_updater__.finishBatch(cost) + batch_evaluator.finish() self.__parameter_updater__.finishPass() pass_evaluator.finish()