提交 133541ee 编写于 作者: Y Yu Yang

Merge codes

上级 8544bdbb
...@@ -10,7 +10,7 @@ def get_numeric_gradient(op, ...@@ -10,7 +10,7 @@ def get_numeric_gradient(op,
input_values, input_values,
output_name, output_name,
input_to_check, input_to_check,
delta=1e-5, delta=1e-2,
local_scope=None): local_scope=None):
""" """
Get Numeric Gradient for an operator's input. Get Numeric Gradient for an operator's input.
...@@ -34,8 +34,8 @@ def get_numeric_gradient(op, ...@@ -34,8 +34,8 @@ def get_numeric_gradient(op,
var = local_scope.new_var(var_name) var = local_scope.new_var(var_name)
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims(input_values[var_name].shape) tensor.set_dims(input_values[var_name].shape)
tensor.alloc_float() tensor.alloc_float(core.CPUPlace())
tensor.set(input_values[var_name]) tensor.set(input_values[var_name], core.CPUPlace())
# Create all output variable in local_scope # Create all output variable in local_scope
for output in op.outputs(): for output in op.outputs():
...@@ -46,10 +46,10 @@ def get_numeric_gradient(op, ...@@ -46,10 +46,10 @@ def get_numeric_gradient(op,
# allocate output memory # allocate output memory
for output in op.outputs(): for output in op.outputs():
local_scope.find_var(output).get_tensor().alloc_float() local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace())
# TODO(yuyang18): Only CPU is support now. # TODO(yuyang18): Only CPU is support now.
cpu_ctx = core.DeviceContext.cpu_context() cpu_ctx = core.DeviceContext.create(core.CPUPlace())
def get_output(): def get_output():
op.run(local_scope, cpu_ctx) op.run(local_scope, cpu_ctx)
...@@ -85,7 +85,6 @@ if __name__ == '__main__': ...@@ -85,7 +85,6 @@ if __name__ == '__main__':
y = numpy.random.random((10, 1)).astype("float32") y = numpy.random.random((10, 1)).astype("float32")
arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2) self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册