提交 47975870 编写于 作者: Y Yancey 提交者: GitHub

Fix check grad with multioutput (#4067)

Fix check grad with multi outputs
上级 e4bab9a4
...@@ -85,7 +85,7 @@ def get_numeric_gradient(scope, ...@@ -85,7 +85,7 @@ def get_numeric_gradient(scope,
op, op,
inputs, inputs,
input_to_check, input_to_check,
output_name, output_names,
delta=0.005, delta=0.005,
in_place=False): in_place=False):
...@@ -100,8 +100,11 @@ def get_numeric_gradient(scope, ...@@ -100,8 +100,11 @@ def get_numeric_gradient(scope,
ctx = core.DeviceContext.create(core.CPUPlace()) ctx = core.DeviceContext.create(core.CPUPlace())
def get_output(): def get_output():
sum = 0.0
for output_name in output_names:
op.run(scope, ctx) op.run(scope, ctx)
return np.array(scope.find_var(output_name).get_tensor()).sum() sum += np.array(scope.find_var(output_name).get_tensor()).sum()
return sum
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims()) tensor_size = product(tensor_to_check.get_dims())
...@@ -225,7 +228,7 @@ class OpTest(unittest.TestCase): ...@@ -225,7 +228,7 @@ class OpTest(unittest.TestCase):
def check_grad(self, def check_grad(self,
inputs_to_check, inputs_to_check,
output_name, output_names,
no_grad_set=None, no_grad_set=None,
in_place=False, in_place=False,
max_relative_error=0.005): max_relative_error=0.005):
...@@ -237,13 +240,16 @@ class OpTest(unittest.TestCase): ...@@ -237,13 +240,16 @@ class OpTest(unittest.TestCase):
if no_grad_set is None: if no_grad_set is None:
no_grad_set = set() no_grad_set = set()
if not type(output_names) is list:
output_names = [output_names]
numeric_grads = [ numeric_grads = [
get_numeric_gradient( get_numeric_gradient(
self.scope, self.scope,
self.op, self.op,
self.inputs, self.inputs,
input_to_check, input_to_check,
output_name, output_names,
in_place=in_place) for input_to_check in inputs_to_check in_place=in_place) for input_to_check in inputs_to_check
] ]
grad_names = [ grad_names = [
......
...@@ -12,7 +12,8 @@ class GetNumericGradientTest(unittest.TestCase): ...@@ -12,7 +12,8 @@ class GetNumericGradientTest(unittest.TestCase):
z = x + y z = x + y
scope = core.Scope() scope = core.Scope()
add_op = create_op(scope, "add", {'X': x, 'Y': y}, {'Out': z}, dict()) add_op = create_op(scope, "add", {'X': x, 'Y': y}, {'Out': z}, dict())
arr = get_numeric_gradient(scope, add_op, {'X': x, 'Y': y}, 'X', 'Out') arr = get_numeric_gradient(scope, add_op, {'X': x,
'Y': y}, 'X', ['Out'])
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4) self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)
def test_softmax_op(self): def test_softmax_op(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册