diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index d251f14b9d15804b0e64fd72ed9f780c321bc816..2c92dfa43e7ee3abf4822268a8204d385c14bfc6 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -110,7 +110,24 @@ def get_numeric_gradient(op, class GradientChecker(unittest.TestCase): - def get_grad(self, forward_op, backward_op, input_vars, grad_names, place): + def __get_gradient(self, forward_op, backward_op, input_value, grad_names, + place): + """Get the input gradients after running forward and backward operators + on the given places. + + :param forward_op: forward operator + :type forward_op: Operator + :param backward_op: backward operator + :type backward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :param grad_names: the names of returned input gradients. + :type input_value: a list of string + :param place: the device type. + :type place: CPUPlace or GPUPlace + :return: the input grdients of given grad_names. + :rtype: a list of numpy.array + """ scope = core.Scope() ctx = core.DeviceContext.create(place) @@ -120,7 +137,7 @@ class GradientChecker(unittest.TestCase): out_names = [item for k in outputs for item in outputs[k]] # create input var and set value - for name, value in input_vars.iteritems(): + for name, value in input_value.iteritems(): if name not in in_names: raise ValueError(name + "does not exist in Op's inputs.") var = scope.new_var(name).get_tensor() @@ -154,7 +171,16 @@ class GradientChecker(unittest.TestCase): ] return outs - def compare_grad(self, forward_op, inputs): + def compare_grad(self, forward_op, input_value): + """ Compare the input gradients between CPU and GPU for the given forward + operator. + + :param forward_op: forward operator + :type forward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :raises: AssertionError, there is different gradient value. + """ backward_op = core.Operator.backward(forward_op, set()) # return if not compile with GPU or not implementing GPU kernel if not (core.is_compile_gpu() and backward_op.support_gpu()): @@ -162,19 +188,31 @@ class GradientChecker(unittest.TestCase): outputs = backward_op.outputs() out_names = [item for k in outputs for item in outputs[k]] - cpu_grads = self.get_grad(forward_op, backward_op, inputs, out_names, - core.CPUPlace()) - gpu_grads = self.get_grad(forward_op, backward_op, inputs, out_names, - core.GPUPlace(0)) + cpu_grads = self.get_grad(forward_op, backward_op, input_value, + out_names, core.CPUPlace()) + gpu_grads = self.get_grad(forward_op, backward_op, input_value, + out_names, core.GPUPlace(0)) for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads, out_names): self.assertTrue( - numpy.allclose(c_grad, g_grad), + numpy.allclose( + c_grad, g_grad, atol=1e-4), "output name: " + name + " has diff") - def assert_is_close(self, numeric_grads, analytic_grads, names, - max_relative_error, msg_prefix): + def __assert_is_close(self, numeric_grads, analytic_grads, names, + max_relative_error, msg_prefix): + """Use relative error for the comparison. + + :param numeric_grads: the numerical graidents. + :type numeric_grads: a list of numpy.array + :param analytic_grads: the analytical graidents. + :type analytic_grads: a list of numpy.array + :param name: the names of gradients, used to print for debug. + :type names: a list of string + :param msg_prefix: string info, used to print for debug. + :type msf_prefix: string + """ for a, b, name in itertools.izip(numeric_grads, analytic_grads, names): abs_a = numpy.abs(a) # if abs_a is nearly zero, then use abs error for a, not relative @@ -241,6 +279,6 @@ class GradientChecker(unittest.TestCase): # get analytical gradients according to different device analytic_grads = self.get_grad(forward_op, backward_op, input_vars, check_names, place) - self.assert_is_close(numeric_grads, analytic_grads, check_names, - max_relative_error, - "Gradient Check On %s" % str(place)) + self.__assert_is_close(numeric_grads, analytic_grads, check_names, + max_relative_error, + "Gradient Check On %s" % str(place))