diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 51a98284bdd17ec2caeed53c231654b309514c8b..ed838b5979d3bd5a3caa954074a692bbd24a5a07 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -11,9 +11,9 @@ __all__ = ['get_numeric_gradient'] def create_op(op_type): # TODO need to set attrs kwargs = dict() - for in_name, _ in Operator.get_op_input_names(op_type): + for in_name in Operator.get_op_input_names(op_type): kwargs[in_name] = in_name - for out_name, _ in Operator.get_op_output_names(op_type): + for out_name in Operator.get_op_output_names(op_type): kwargs[out_name] = out_name return Operator(op_type, **kwargs) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index 54fe5da4401655ee50af08c18eaee7ad90c2fd8d..99a114e45f1228551375982308b6f2cdc7aabb44 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -27,9 +27,7 @@ class OpTestMeta(type): places.append(core.GPUPlace(0)) for place in places: - for ins in Operator.get_op_input_names(self.type): - in_name = ins[0] - in_dup = ins[1] + for in_name, in_dup in Operator.get_op_inputs(self.type): if hasattr(self, 'inputs') and in_name in self.inputs: kwargs[in_name] = [] if in_dup: @@ -49,8 +47,7 @@ class OpTestMeta(type): else: kwargs[in_name] = "@EMPTY@" - for out_name, out_dup in Operator.get_op_output_names( - self.type): + for out_name, out_dup in Operator.get_op_outputs(self.type): if not hasattr(self, "outputs"): raise ValueError( "The test op must set self.outputs dict.") @@ -73,8 +70,7 @@ class OpTestMeta(type): ctx = core.DeviceContext.create(place) op.run(scope, ctx) - for out_name, out_dup in Operator.get_op_output_names( - self.type): + for out_name, out_dup in Operator.get_op_outputs(self.type): actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] self.assertTrue(