提交 161a15f0 编写于 作者: Z zchen0211

gradient check

上级 ab6b3c48
...@@ -86,6 +86,9 @@ def get_numeric_gradient(op, ...@@ -86,6 +86,9 @@ def get_numeric_gradient(op,
# we only compute gradient of one element each time. # we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element. # we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size): for i in xrange(tensor_size):
for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
# get one input element throw it's index i. # get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i) origin = tensor_to_check.get_float_element(i)
...@@ -95,6 +98,9 @@ def get_numeric_gradient(op, ...@@ -95,6 +98,9 @@ def get_numeric_gradient(op,
y_pos = get_output() y_pos = get_output()
# plus delta to this element, run op and get the sum of the result tensor. # plus delta to this element, run op and get the sum of the result tensor.
for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
x_neg = origin - delta x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg) tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output() y_neg = get_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册