提交 133541ee 编写于 作者: Y Yu Yang

Merge codes

上级 8544bdbb
......@@ -10,7 +10,7 @@ def get_numeric_gradient(op,
input_values,
output_name,
input_to_check,
delta=1e-5,
delta=1e-2,
local_scope=None):
"""
Get Numeric Gradient for an operator's input.
......@@ -34,8 +34,8 @@ def get_numeric_gradient(op,
var = local_scope.new_var(var_name)
tensor = var.get_tensor()
tensor.set_dims(input_values[var_name].shape)
tensor.alloc_float()
tensor.set(input_values[var_name])
tensor.alloc_float(core.CPUPlace())
tensor.set(input_values[var_name], core.CPUPlace())
# Create all output variable in local_scope
for output in op.outputs():
......@@ -46,10 +46,10 @@ def get_numeric_gradient(op,
# allocate output memory
for output in op.outputs():
local_scope.find_var(output).get_tensor().alloc_float()
local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace())
# TODO(yuyang18): Only CPU is support now.
cpu_ctx = core.DeviceContext.cpu_context()
cpu_ctx = core.DeviceContext.create(core.CPUPlace())
def get_output():
op.run(local_scope, cpu_ctx)
......@@ -85,7 +85,6 @@ if __name__ == '__main__':
y = numpy.random.random((10, 1)).astype("float32")
arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册