提交 9023248c 编写于 作者: Y Yu Yang 提交者: GitHub

Simplize Gradient Check (#5024)

上级 cdb5f292
...@@ -179,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set): ...@@ -179,7 +179,12 @@ def get_backward_op(scope, op, no_grad_set):
return backward_op return backward_op
def get_gradient(scope, op, inputs, outputs, grad_name, place, def get_gradient(scope,
op,
inputs,
outputs,
grad_names,
place,
no_grad_set=None): no_grad_set=None):
ctx = core.DeviceContext.create(place) ctx = core.DeviceContext.create(place)
...@@ -195,8 +200,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, ...@@ -195,8 +200,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
backward_op.run(scope, ctx) backward_op.run(scope, ctx)
out = np.array(scope.find_var(grad_name).get_tensor()) return [
return out np.array(scope.find_var(grad_name).get_tensor())
for grad_name in grad_names
]
def append_input_output(block, op_proto, np_list, is_input): def append_input_output(block, op_proto, np_list, is_input):
...@@ -399,11 +406,9 @@ class OpTest(unittest.TestCase): ...@@ -399,11 +406,9 @@ class OpTest(unittest.TestCase):
] ]
cpu_place = core.CPUPlace() cpu_place = core.CPUPlace()
cpu_analytic_grads = [ cpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
get_gradient(self.scope, self.op, self.inputs, self.outputs, self.outputs, grad_names, cpu_place,
grad_name, cpu_place, no_grad_set) no_grad_set)
for grad_name in grad_names
]
self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names, self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
max_relative_error, max_relative_error,
...@@ -411,11 +416,9 @@ class OpTest(unittest.TestCase): ...@@ -411,11 +416,9 @@ class OpTest(unittest.TestCase):
if core.is_compile_gpu() and self.op.support_gpu(): if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0) gpu_place = core.GPUPlace(0)
gpu_analytic_grads = [ gpu_analytic_grads = get_gradient(self.scope, self.op, self.inputs,
get_gradient(self.scope, self.op, self.inputs, self.outputs, self.outputs, grad_names,
grad_name, gpu_place, no_grad_set) gpu_place, no_grad_set)
for grad_name in grad_names
]
self.__assert_is_close(numeric_grads, gpu_analytic_grads, self.__assert_is_close(numeric_grads, gpu_analytic_grads,
grad_names, max_relative_error, grad_names, max_relative_error,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册