提交 fcc28cce 编写于 作者: Y Yu Yang

Add comments

上级 7c42aad4
...@@ -10,8 +10,24 @@ def get_numeric_gradient(op, ...@@ -10,8 +10,24 @@ def get_numeric_gradient(op,
input_to_check, input_to_check,
delta=1e-5, delta=1e-5,
local_scope=None): local_scope=None):
"""
Get Numeric Gradient for an operator's input.
:param op: C++ operator instance, could be an network
:param input_values: The input variables. Should be an dictionary, key is
variable name. Value is numpy array.
:param output_name: The final output variable name.
:param input_to_check: The input variable need to get gradient.
:param delta: The perturbation value for numeric gradient method. The
smaller delta is, the more accurate result will get. But if that delta is
too small, it could occur numerical stability problem.
:param local_scope: The local scope used for get_numeric_gradient.
:return: The gradient array in numpy format.
"""
if local_scope is None: if local_scope is None:
local_scope = core.Scope() local_scope = core.Scope()
# Create all input variable in local_scope
for var_name in input_values: for var_name in input_values:
var = local_scope.new_var(var_name) var = local_scope.new_var(var_name)
tensor = var.get_tensor() tensor = var.get_tensor()
...@@ -19,14 +35,18 @@ def get_numeric_gradient(op, ...@@ -19,14 +35,18 @@ def get_numeric_gradient(op,
tensor.alloc_float() tensor.alloc_float()
tensor.set(input_values[var_name]) tensor.set(input_values[var_name])
# Create all output variable in local_scope
for output in op.outputs(): for output in op.outputs():
local_scope.new_var(output).get_tensor() if local_scope.find_var(output) is None:
local_scope.new_var(output).get_tensor()
op.infer_shape(local_scope) op.infer_shape(local_scope)
# allocate output memory
for output in op.outputs(): for output in op.outputs():
local_scope.find_var(output).get_tensor().alloc_float() local_scope.find_var(output).get_tensor().alloc_float()
# TODO(yuyang18): Only CPU is support now.
cpu_ctx = core.DeviceContext.cpu_context() cpu_ctx = core.DeviceContext.cpu_context()
def get_output(): def get_output():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册