Commit a28a5564 authored by dangqingqing

add more comments and fix code style.

Parent 9a0eedf5
@@ -110,7 +110,24 @@ def get_numeric_gradient(op,


 class GradientChecker(unittest.TestCase):
-    def get_grad(self, forward_op, backward_op, input_vars, grad_names, place):
+    def __get_gradient(self, forward_op, backward_op, input_value, grad_names,
+                       place):
+        """Get the input gradients after running forward and backward operators
+        on the given place.
+
+        :param forward_op: forward operator
+        :type forward_op: Operator
+        :param backward_op: backward operator
+        :type backward_op: Operator
+        :param input_value: input values.
+        :type input_value: dict{string:numpy.array}
+        :param grad_names: the names of returned input gradients.
+        :type grad_names: a list of string
+        :param place: the device type.
+        :type place: CPUPlace or GPUPlace
+        :return: the input gradients of the given grad_names.
+        :rtype: a list of numpy.array
+        """
         scope = core.Scope()
         ctx = core.DeviceContext.create(place)
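For orientation, the numeric side of the check comes from get_numeric_gradient, whose body is outside this diff. Below is a self-contained sketch of the textbook central-difference estimate such a helper computes; it is the standard formula, not the repository's actual implementation:

import numpy

def numeric_gradient_sketch(f, x, delta=1e-2):
    # Central-difference estimate of df/dx for a scalar-valued f(x).
    grad = numpy.zeros_like(x)
    flat_x = x.reshape(-1)   # view: writes propagate back into x
    flat_g = grad.reshape(-1)
    for i in range(flat_x.size):
        orig = flat_x[i]
        flat_x[i] = orig + delta
        f_plus = f(x)
        flat_x[i] = orig - delta
        f_minus = f(x)
        flat_x[i] = orig     # restore the perturbed entry
        flat_g[i] = (f_plus - f_minus) / (2.0 * delta)
    return grad

Running it on f(x) = (x ** 2).sum() at x = numpy.ones((2, 2)) returns entries close to 2, matching the analytic gradient 2x.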
@@ -120,7 +137,7 @@ class GradientChecker(unittest.TestCase):
         out_names = [item for k in outputs for item in outputs[k]]

         # create input var and set value
-        for name, value in input_vars.iteritems():
+        for name, value in input_value.iteritems():
             if name not in in_names:
                 raise ValueError(name + " does not exist in Op's inputs.")
             var = scope.new_var(name).get_tensor()
@@ -154,7 +171,16 @@ class GradientChecker(unittest.TestCase):
         ]
         return outs

-    def compare_grad(self, forward_op, inputs):
+    def compare_grad(self, forward_op, input_value):
+        """Compare the input gradients between CPU and GPU for the given
+        forward operator.
+
+        :param forward_op: forward operator
+        :type forward_op: Operator
+        :param input_value: input values.
+        :type input_value: dict{string:numpy.array}
+        :raises: AssertionError, if the CPU and GPU gradients are different.
+        """
         backward_op = core.Operator.backward(forward_op, set())
         # return if not compile with GPU or not implementing GPU kernel
         if not (core.is_compile_gpu() and backward_op.support_gpu()):
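For context, a hedged sketch of how a test case might invoke compare_grad; the mul operator, the Operator constructor arguments, and the array shapes here are illustrative assumptions, not taken from this commit:

# Hypothetical test case; operator construction details are assumed.
class MulGradCompareTest(GradientChecker):
    def test_compare_cpu_gpu_grad(self):
        forward_op = Operator("mul", X="X", Y="Y", Out="Out")
        input_value = {
            "X": numpy.random.random((32, 84)).astype("float32"),
            "Y": numpy.random.random((84, 100)).astype("float32"),
        }
        # Raises AssertionError if any CPU/GPU gradient pair fails
        # numpy.allclose(..., atol=1e-4).
        self.compare_grad(forward_op, input_value)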
@@ -162,19 +188,31 @@ class GradientChecker(unittest.TestCase):
         outputs = backward_op.outputs()
         out_names = [item for k in outputs for item in outputs[k]]

-        cpu_grads = self.get_grad(forward_op, backward_op, inputs, out_names,
-                                  core.CPUPlace())
-        gpu_grads = self.get_grad(forward_op, backward_op, inputs, out_names,
-                                  core.GPUPlace(0))
+        cpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
+                                        out_names, core.CPUPlace())
+        gpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
+                                        out_names, core.GPUPlace(0))

         for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads,
                                                    out_names):
             self.assertTrue(
-                numpy.allclose(c_grad, g_grad),
+                numpy.allclose(
+                    c_grad, g_grad, atol=1e-4),
                 "output name: " + name + " has diff")

-    def assert_is_close(self, numeric_grads, analytic_grads, names,
-                        max_relative_error, msg_prefix):
+    def __assert_is_close(self, numeric_grads, analytic_grads, names,
+                          max_relative_error, msg_prefix):
+        """Use relative error for the comparison.
+
+        :param numeric_grads: the numerical gradients.
+        :type numeric_grads: a list of numpy.array
+        :param analytic_grads: the analytical gradients.
+        :type analytic_grads: a list of numpy.array
+        :param names: the names of gradients, used to print for debug.
+        :type names: a list of string
+        :param max_relative_error: the max relative error allowed.
+        :type max_relative_error: float
+        :param msg_prefix: string info, used to print for debug.
+        :type msg_prefix: string
+        """
         for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
             abs_a = numpy.abs(a)
             # if abs_a is nearly zero, then use abs error for a, not relative
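Two numeric details in this hunk are worth spelling out. numpy.allclose(c_grad, g_grad, atol=1e-4) passes elementwise when |c - g| <= atol + rtol * |g| (numpy's default rtol is 1e-5, atol is 1e-8), so the added atol=1e-4 tolerates small absolute CPU/GPU discrepancies near zero. The comment after abs_a = numpy.abs(a) hints at the same concern for the relative-error check; below is a self-contained sketch of that rule, where the near-zero cutoff value is an assumption for illustration, not visible in this hunk:

import numpy

def assert_is_close_sketch(a, b, max_relative_error):
    # Sketch of the relative-error rule described by the comment above.
    abs_a = numpy.abs(a)
    # Near-zero reference values would explode the ratio |a - b| / |a|,
    # so force the denominator to 1 there, i.e. fall back to abs error.
    abs_a[abs_a < 1e-3] = 1.0  # cutoff is an assumed value
    max_diff = numpy.max(numpy.abs(a - b) / abs_a)
    assert max_diff <= max_relative_error, "max diff %e" % max_diff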
@@ -241,6 +279,6 @@ class GradientChecker(unittest.TestCase):
         # get analytical gradients according to different device
-        analytic_grads = self.get_grad(forward_op, backward_op, input_vars,
-                                       check_names, place)
-        self.assert_is_close(numeric_grads, analytic_grads, check_names,
-                             max_relative_error,
-                             "Gradient Check On %s" % str(place))
+        analytic_grads = self.__get_gradient(forward_op, backward_op,
+                                             input_vars, check_names, place)
+        self.__assert_is_close(numeric_grads, analytic_grads, check_names,
+                               max_relative_error,
+                               "Gradient Check On %s" % str(place))