提交 2edb69a7 编写于 作者: Z zchen0211

with in-place option

上级 eb8b84b5
......@@ -32,7 +32,8 @@ def get_numeric_gradient(op,
output_name,
input_to_check,
delta=0.005,
local_scope=None):
local_scope=None,
in_place=False):
"""
Get Numeric Gradient for an operator's input.
......@@ -90,9 +91,10 @@ def get_numeric_gradient(op,
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
if in_place:
for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
# get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i)
......@@ -102,9 +104,10 @@ def get_numeric_gradient(op,
y_pos = get_output()
# plus delta to this element, run op and get the sum of the result tensor.
for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
if in_place:
for var_name in input_values:
tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
......@@ -257,6 +260,7 @@ class GradientChecker(unittest.TestCase):
output_name,
no_grad_set=None,
only_cpu=False,
in_place=False,
max_relative_error=0.005):
"""
:param forward_op: used to create backward_op
......@@ -289,7 +293,8 @@ class GradientChecker(unittest.TestCase):
# get numerical gradients
numeric_grads = [
get_numeric_gradient(forward_op, input_vars, output_name, name)
get_numeric_gradient(
forward_op, input_vars, output_name, name, in_place=in_place)
for name in inputs_to_check
]
......
......@@ -31,7 +31,8 @@ class TestScatterGradOp(GradientChecker):
output_np[index_np] += updates_np
inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
# check gradient
self.check_grad(op, inputs, set(["Updates", "Ref"]), "Out")
self.check_grad(
op, inputs, set(["Updates", "Ref"]), "Out", in_place=True)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册