From 477b23c3f5c123b446cec48321105e1a471c1212 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 11 Sep 2017 18:37:19 +0800 Subject: [PATCH] follow comments --- python/paddle/v2/framework/tests/op_test.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 1daa6fa277..489358ba85 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -17,7 +17,7 @@ def create_op(scope, op_type, inputs, outputs, attrs): kwargs[in_name] = [] if in_dup: sub_in = inputs[in_name] - for sub_in_name, arr in sub_in: + for sub_in_name, _ in sub_in: var = scope.new_var(sub_in_name) kwargs[in_name].append(sub_in_name) else: @@ -29,7 +29,7 @@ def create_op(scope, op_type, inputs, outputs, attrs): kwargs[out_name] = [] if out_dup: sub_in = outputs[out_name] - for sub_in_name, arr in sub_in: + for sub_in_name, _ in sub_in: var = scope.new_var(sub_in_name) kwargs[out_name].append(sub_in_name) else: @@ -47,11 +47,11 @@ def set_input(scope, op, inputs, place): if in_name in inputs: if in_dup: sub_in = inputs[in_name] - for sub_in_name, arr in sub_in: + for sub_in_name, sub_in_array in sub_in: var = scope.find_var(sub_in_name) tensor = var.get_tensor() - tensor.set_dims(arr.shape) - tensor.set(arr, place) + tensor.set_dims(sub_in_array.shape) + tensor.set(sub_in_array, place) else: var = scope.find_var(in_name) tensor = var.get_tensor() @@ -65,7 +65,7 @@ def set_output_grad(scope, op, outputs, place): if out_name in outputs: if out_dup: sub_out = outputs[out_name] - for sub_out_name, arr in sub_out: + for sub_out_name, sub_out_grad in sub_out: out_tensor = scope.find_var(sub_out_name).get_tensor() grad_tensor = scope.new_var(grad_var_name( sub_out_name)).get_tensor() @@ -169,9 +169,8 @@ class OpTest(unittest.TestCase): def check_output_with_place(self, place): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() - op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict() - self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, + self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs, op_attrs) if isinstance(place, core.GPUPlace) and not self.op.support_gpu(): return @@ -232,9 +231,8 @@ class OpTest(unittest.TestCase): max_relative_error=0.005): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() - op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_attrs = self.attrs if hasattr(self, "attrs") else dict() - self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs, + self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs, op_attrs) if no_grad_set is None: no_grad_set = set() -- GitLab