提交 01e9e443 编写于 作者: 武毅 提交者: GitHub

able to print gradients in event_handler (#3085)

上级 886e66a5
...@@ -127,16 +127,7 @@ class Parameters(object): ...@@ -127,16 +127,7 @@ class Parameters(object):
""" """
return iter(self.__param_conf__) return iter(self.__param_conf__)
def __getitem__(self, key): def __getter_inner(self, key, param_type):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
shape = self.get_shape(key) shape = self.get_shape(key)
...@@ -152,7 +143,7 @@ class Parameters(object): ...@@ -152,7 +143,7 @@ class Parameters(object):
each_gradient_machine, key) each_gradient_machine, key)
# for simplify implementation now, we always copy from C++ # for simplify implementation now, we always copy from C++
assert isinstance(param, api.Parameter) assert isinstance(param, api.Parameter)
val = param.getBuf(api.PARAMETER_VALUE) val = param.getBuf(param_type)
assert isinstance(val, api.Vector) assert isinstance(val, api.Vector)
val = val.copyToNumpyArray() val = val.copyToNumpyArray()
return val return val
...@@ -160,6 +151,19 @@ class Parameters(object): ...@@ -160,6 +151,19 @@ class Parameters(object):
raise RuntimeError("Unexpected branch") raise RuntimeError("Unexpected branch")
def __getitem__(self, key):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
import py_paddle.swig_paddle as api
return self.__getter_inner(key, api.PARAMETER_VALUE)
def get_shape(self, key): def get_shape(self, key):
""" """
get shape of the parameter. get shape of the parameter.
...@@ -216,6 +220,19 @@ class Parameters(object): ...@@ -216,6 +220,19 @@ class Parameters(object):
""" """
return self.__getitem__(key=parameter_name) return self.__getitem__(key=parameter_name)
def get_grad(self, key):
"""
Get grandient by parameter name.
:note: It will always copy the parameter from C++ side.
:param key: parameter name
:type key: basestring
:return: The grandient matrix.
:rtype: np.ndarray
"""
import py_paddle.swig_paddle as api
return self.__getter_inner(key, api.PARAMETER_GRADIENT)
def set(self, parameter_name, value): def set(self, parameter_name, value):
""" """
Set parameter by parameter name & matrix. Set parameter by parameter name & matrix.
......
...@@ -161,14 +161,14 @@ class SGD(object): ...@@ -161,14 +161,14 @@ class SGD(object):
self.__parameter_updater__.update(each_param) self.__parameter_updater__.update(each_param)
cost_sum = out_args.sum() cost_sum = out_args.sum()
cost = cost_sum / len(data_batch) cost = cost_sum / len(data_batch)
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
event_handler( event_handler(
v2_event.EndIteration( v2_event.EndIteration(
pass_id=pass_id, pass_id=pass_id,
batch_id=batch_id, batch_id=batch_id,
cost=cost, cost=cost,
evaluator=batch_evaluator)) evaluator=batch_evaluator))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
self.__parameter_updater__.finishPass() self.__parameter_updater__.finishPass()
pass_evaluator.finish() pass_evaluator.finish()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册