You need to sign in or sign up before continuing.
提交 c2a437fd 编写于 作者: S Sylwester Fraczek

review fixes

上级 dcbc4284
......@@ -103,37 +103,20 @@ class TestConv2dOp(OpTest):
return core.is_compiled_with_cuda() and self.use_cudnn
def test_check_output(self):
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
place = core.CPUPlace()
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self):
if self.dtype == np.float16:
return
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.02)
place = core.CPUPlace()
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_grad_with_place(
place, set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self):
if self.dtype == np.float16:
return
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
place = core.CPUPlace()
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_grad_with_place(
place, ['Input'],
'Output',
......@@ -143,14 +126,7 @@ class TestConv2dOp(OpTest):
def test_check_grad_no_input(self):
if self.dtype == np.float16:
return
if self.testcudnn():
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
place = core.CPUPlace()
place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
self.check_grad_with_place(
place, ['Filter'],
'Output',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册