From 9023248c6fa82ef38a2b99bb8e4d892067441cc1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 23 Oct 2017 16:52:05 -0700 Subject: [PATCH] Simplize Gradient Check (#5024) --- python/paddle/v2/framework/tests/op_test.py | 29 ++++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 1c6dce9634..0fdc21ef51 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -179,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set): return backward_op -def get_gradient(scope, op, inputs, outputs, grad_name, place, +def get_gradient(scope, + op, + inputs, + outputs, + grad_names, + place, no_grad_set=None): ctx = core.DeviceContext.create(place) @@ -195,8 +200,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, backward_op.run(scope, ctx) - out = np.array(scope.find_var(grad_name).get_tensor()) - return out + return [ + np.array(scope.find_var(grad_name).get_tensor()) + for grad_name in grad_names + ] def append_input_output(block, op_proto, np_list, is_input): @@ -399,11 +406,9 @@ class OpTest(unittest.TestCase): ] cpu_place = core.CPUPlace() - cpu_analytic_grads = [ - get_gradient(self.scope, self.op, self.inputs, self.outputs, - grad_name, cpu_place, no_grad_set) - for grad_name in grad_names - ] + cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs, + self.outputs, grad_names, cpu_place, + no_grad_set) self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names, max_relative_error, @@ -411,11 +416,9 @@ class OpTest(unittest.TestCase): if core.is_compile_gpu() and self.op.support_gpu(): gpu_place = core.GPUPlace(0) - gpu_analytic_grads = [ - get_gradient(self.scope, self.op, self.inputs, self.outputs, - grad_name, gpu_place, no_grad_set) - for grad_name in grad_names - ] + gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs, + self.outputs, grad_names, + gpu_place, no_grad_set) self.__assert_is_close(numeric_grads, gpu_analytic_grads, grad_names, max_relative_error, -- GitLab