提交 9a0eedf5 编写于 作者: D dangqingqing

fix bug.

上级 01d91340
...@@ -156,6 +156,7 @@ class GradientChecker(unittest.TestCase): ...@@ -156,6 +156,7 @@ class GradientChecker(unittest.TestCase):
def compare_grad(self, forward_op, inputs): def compare_grad(self, forward_op, inputs):
backward_op = core.Operator.backward(forward_op, set()) backward_op = core.Operator.backward(forward_op, set())
# return if not compile with GPU or not implementing GPU kernel
if not (core.is_compile_gpu() and backward_op.support_gpu()): if not (core.is_compile_gpu() and backward_op.support_gpu()):
return return
...@@ -239,7 +240,7 @@ class GradientChecker(unittest.TestCase): ...@@ -239,7 +240,7 @@ class GradientChecker(unittest.TestCase):
for place in places: for place in places:
# get analytical gradients according to different device # get analytical gradients according to different device
analytic_grads = self.get_grad(forward_op, backward_op, input_vars, analytic_grads = self.get_grad(forward_op, backward_op, input_vars,
check_grad_names, place) check_names, place)
self.assert_is_close(numeric_grads, analytic_grads, check_names, self.assert_is_close(numeric_grads, analytic_grads, check_names,
max_relative_error, max_relative_error,
"Gradient Check On %s" % str(place)) "Gradient Check On %s" % str(place))
...@@ -17,14 +17,9 @@ class TestSigmoidGradOp(GradientChecker): ...@@ -17,14 +17,9 @@ class TestSigmoidGradOp(GradientChecker):
def test_compare_grad(self): def test_compare_grad(self):
op = create_op("sigmoid") op = create_op("sigmoid")
inputs = {"X": np.random.random((11, 17)).astype("float32")} inputs = {"X": np.random.random((11, 17)).astype("float32")}
# compare gpu and cpu results for backward op.
# compare gpu and cpu results for backward op # skip this test if only compiling CPU version.
self.compare_grad(op, inputs) self.compare_grad(op, inputs)
def test_check_grad(self):
op = create_op("sigmoid")
inputs = {"X": np.random.random((11, 17)).astype("float32")}
# check gradients # check gradients
self.check_grad(op, inputs, set("X"), "Y") self.check_grad(op, inputs, set("X"), "Y")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册