From fcc28ccea220ab2be166ea824dca3504dd3fc2c6 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 2 Aug 2017 16:18:59 +0800 Subject: [PATCH] Add comments --- .../v2/framework/tests/gradient_checker.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index d7e5de82523..e7fca05d6f7 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -10,8 +10,24 @@ def get_numeric_gradient(op, input_to_check, delta=1e-5, local_scope=None): + """ + Get Numeric Gradient for an operator's input. + + :param op: C++ operator instance, could be an network + :param input_values: The input variables. Should be an dictionary, key is + variable name. Value is numpy array. + :param output_name: The final output variable name. + :param input_to_check: The input variable need to get gradient. + :param delta: The perturbation value for numeric gradient method. The + smaller delta is, the more accurate result will get. But if that delta is + too small, it could occur numerical stability problem. + :param local_scope: The local scope used for get_numeric_gradient. + :return: The gradient array in numpy format. + """ if local_scope is None: local_scope = core.Scope() + + # Create all input variable in local_scope for var_name in input_values: var = local_scope.new_var(var_name) tensor = var.get_tensor() @@ -19,14 +35,18 @@ def get_numeric_gradient(op, tensor.alloc_float() tensor.set(input_values[var_name]) + # Create all output variable in local_scope for output in op.outputs(): - local_scope.new_var(output).get_tensor() + if local_scope.find_var(output) is None: + local_scope.new_var(output).get_tensor() op.infer_shape(local_scope) + # allocate output memory for output in op.outputs(): local_scope.find_var(output).get_tensor().alloc_float() + # TODO(yuyang18): Only CPU is support now. cpu_ctx = core.DeviceContext.cpu_context() def get_output(): -- GitLab