Commit 3d9d32a1 authored by Yu Yang

Invoke check_grad many times for no_grad_set

上级 44703329
@@ -286,7 +286,7 @@ class GradientChecker(unittest.TestCase):
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
-            if name in inputs_to_check:
+            if no_grad in inputs_to_check:
                 raise ValueError("no_grad should not be in inputs_to_check")
         backward_op = core.Operator.backward(forward_op, no_grad_set)
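The one-line fix in this hunk replaces a stale variable: the old guard tested name, apparently left over from an earlier loop, so an entry of no_grad_set that also appeared in inputs_to_check could slip past the check. A minimal standalone sketch of the intended validation (the helper name and the calls below are illustrative, not Paddle API):

    # Illustrative sketch of the validation the fixed lines perform.
    def validate_no_grad_set(no_grad_set, in_names, inputs_to_check):
        for no_grad in no_grad_set:
            if no_grad not in in_names:
                raise ValueError("no_grad should be in in_names")
            # The fixed guard: test no_grad itself, not a leftover name.
            if no_grad in inputs_to_check:
                raise ValueError("no_grad should not be in inputs_to_check")

    validate_no_grad_set({"X"}, in_names=["X", "Y"], inputs_to_check=["Y"])  # passes
    validate_no_grad_set({"X"}, in_names=["X", "Y"], inputs_to_check=["X"])  # raises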
@@ -304,25 +304,8 @@ class GradientChecker(unittest.TestCase):
         check_names = [grad_var_name(name) for name in inputs_to_check]
         for place in places:
-            # analytic_grads = self.__get_gradient(forward_op, backward_op,
-            #                                      input_vars, check_names, place)
-            # In fact, the above two lines could replace the following code.
-            # But most gradient operators need to handle the case where the
-            # gradient of one or more inputs is not needed. We change the unit
-            # test framework to explicitly test that the operator handles this
-            # correctly through the following code. In addition, if none of the
-            # inputs needs a gradient, core.Operator.backward() returns the NOP
-            # operator; the following code does not test this case.
-            analytic_grads = []
-            for name in inputs_to_check:
-                no_grads = [name for name in no_grad_set]
-                no_grads.extend(filter(lambda x: x != name, inputs_to_check))
-                backward_op = core.Operator.backward(forward_op, set(no_grads))
-                # get analytic gradients on the current device (place)
-                analytic_grads.extend(
-                    self.__get_gradient(forward_op, backward_op, input_vars,
-                                        [grad_var_name(name)], place))
+            analytic_grads = self.__get_gradient(forward_op, backward_op,
+                                                 input_vars, check_names, place)
             self.__assert_is_close(numeric_grads, analytic_grads, check_names,
                                    max_relative_error,
                                    "Gradient Check On %s" % str(place))
@@ -17,16 +17,33 @@ class TestMulOp(unittest.TestCase):
 class TestMulGradOp(GradientChecker):
-    def test_mul(self):
-        op = create_op("mul")
-        inputs = {
+    def setUp(self):
+        self.op = create_op("mul")
+        self.inputs = {
             'X': np.random.random((32, 84)).astype("float32"),
             'Y': np.random.random((84, 100)).astype("float32")
         }
-        self.compare_grad(op, inputs)
+
+    def test_normal(self):
         # mul op will enlarge the relative error
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})
+
 # TODO(dzh,qijun): the mulgrad test case needs the transpose feature of the BLAS library
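For reference, the three tests above exercise the standard matmul gradients: for Out = X·Y, the chain rule gives dX = dOut·Yᵀ and dY = Xᵀ·dOut, which is presumably why the TODO asks for a transpose feature in the BLAS wrapper. A quick numpy sanity check of those formulas on the test's shapes (illustrative only, not Paddle code):

    import numpy as np

    X = np.random.random((32, 84)).astype("float32")
    Y = np.random.random((84, 100)).astype("float32")
    dOut = np.ones((32, 100), dtype="float32")  # gradient of sum(Out)

    dX = dOut.dot(Y.T)  # shape (32, 84), matches X
    dY = X.T.dot(dOut)  # shape (84, 100), matches Y
    assert dX.shape == X.shape and dY.shape == Y.shape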
@@ -17,13 +17,21 @@ class TestRowwiseAddOp(unittest.TestCase):
 class RowwiseAddGradOpTest(GradientChecker):
-    def test_rowwise_add(self):
-        op = create_op("rowwise_add")
-        inputs = {
+    def setUp(self):
+        self.op = create_op("rowwise_add")
+        self.inputs = {
             "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
             "b": np.random.uniform(0.1, 1, [10]).astype("float32")
         }
-        self.check_grad(op, inputs, set(["X", "b"]), "Out")
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+    def test_ignore_b(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+    def test_ignore_x(self):
+        self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+
 if __name__ == '__main__':
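The rowwise_add tests follow the same pattern. Assuming the op computes Out[i, :] = X[i, :] + b (the broadcast that the name and the test shapes suggest), the expected gradients are dX = dOut and db = dOut summed over rows; a short numpy illustration:

    import numpy as np

    X = np.random.uniform(0.1, 1, [5, 10]).astype("float32")
    b = np.random.uniform(0.1, 1, [10]).astype("float32")
    dOut = np.ones((5, 10), dtype="float32")  # gradient of sum(Out)

    dX = dOut              # the add passes the gradient straight through to X
    db = dOut.sum(axis=0)  # every row's gradient accumulates into b
    assert dX.shape == X.shape and db.shape == b.shape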