提交 2edb69a7 编写于 作者: Z zchen0211

with in-place option

上级 eb8b84b5
...@@ -32,7 +32,8 @@ def get_numeric_gradient(op, ...@@ -32,7 +32,8 @@ def get_numeric_gradient(op,
output_name, output_name,
input_to_check, input_to_check,
delta=0.005, delta=0.005,
local_scope=None): local_scope=None,
in_place=False):
""" """
Get Numeric Gradient for an operator's input. Get Numeric Gradient for an operator's input.
...@@ -90,9 +91,10 @@ def get_numeric_gradient(op, ...@@ -90,9 +91,10 @@ def get_numeric_gradient(op,
# we only compute gradient of one element each time. # we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element. # we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size): for i in xrange(tensor_size):
for var_name in input_values: if in_place:
tensor_ = local_scope.find_var(var_name).get_tensor() for var_name in input_values:
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace()) tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
# get one input element throw it's index i. # get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i) origin = tensor_to_check.get_float_element(i)
...@@ -102,9 +104,10 @@ def get_numeric_gradient(op, ...@@ -102,9 +104,10 @@ def get_numeric_gradient(op,
y_pos = get_output() y_pos = get_output()
# plus delta to this element, run op and get the sum of the result tensor. # plus delta to this element, run op and get the sum of the result tensor.
for var_name in input_values: if in_place:
tensor_ = local_scope.find_var(var_name).get_tensor() for var_name in input_values:
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace()) tensor_ = local_scope.find_var(var_name).get_tensor()
tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
x_neg = origin - delta x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg) tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output() y_neg = get_output()
...@@ -257,6 +260,7 @@ class GradientChecker(unittest.TestCase): ...@@ -257,6 +260,7 @@ class GradientChecker(unittest.TestCase):
output_name, output_name,
no_grad_set=None, no_grad_set=None,
only_cpu=False, only_cpu=False,
in_place=False,
max_relative_error=0.005): max_relative_error=0.005):
""" """
:param forward_op: used to create backward_op :param forward_op: used to create backward_op
...@@ -289,7 +293,8 @@ class GradientChecker(unittest.TestCase): ...@@ -289,7 +293,8 @@ class GradientChecker(unittest.TestCase):
# get numerical gradients # get numerical gradients
numeric_grads = [ numeric_grads = [
get_numeric_gradient(forward_op, input_vars, output_name, name) get_numeric_gradient(
forward_op, input_vars, output_name, name, in_place=in_place)
for name in inputs_to_check for name in inputs_to_check
] ]
......
...@@ -31,7 +31,8 @@ class TestScatterGradOp(GradientChecker): ...@@ -31,7 +31,8 @@ class TestScatterGradOp(GradientChecker):
output_np[index_np] += updates_np output_np[index_np] += updates_np
inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
# check gradient # check gradient
self.check_grad(op, inputs, set(["Updates", "Ref"]), "Out") self.check_grad(
op, inputs, set(["Updates", "Ref"]), "Out", in_place=True)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册