提交 477b23c3 编写于 作者: Q qijun

follow comments

上级 436fbb0d
......@@ -17,7 +17,7 @@ def create_op(scope, op_type, inputs, outputs, attrs):
kwargs[in_name] = []
if in_dup:
sub_in = inputs[in_name]
for sub_in_name, arr in sub_in:
for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name)
kwargs[in_name].append(sub_in_name)
else:
......@@ -29,7 +29,7 @@ def create_op(scope, op_type, inputs, outputs, attrs):
kwargs[out_name] = []
if out_dup:
sub_in = outputs[out_name]
for sub_in_name, arr in sub_in:
for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name)
kwargs[out_name].append(sub_in_name)
else:
......@@ -47,11 +47,11 @@ def set_input(scope, op, inputs, place):
if in_name in inputs:
if in_dup:
sub_in = inputs[in_name]
for sub_in_name, arr in sub_in:
for sub_in_name, sub_in_array in sub_in:
var = scope.find_var(sub_in_name)
tensor = var.get_tensor()
tensor.set_dims(arr.shape)
tensor.set(arr, place)
tensor.set_dims(sub_in_array.shape)
tensor.set(sub_in_array, place)
else:
var = scope.find_var(in_name)
tensor = var.get_tensor()
......@@ -65,7 +65,7 @@ def set_output_grad(scope, op, outputs, place):
if out_name in outputs:
if out_dup:
sub_out = outputs[out_name]
for sub_out_name, arr in sub_out:
for sub_out_name, sub_out_grad in sub_out:
out_tensor = scope.find_var(sub_out_name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(
sub_out_name)).get_tensor()
......@@ -169,9 +169,8 @@ class OpTest(unittest.TestCase):
def check_output_with_place(self, place):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
op_attrs)
if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
return
......@@ -232,9 +231,8 @@ class OpTest(unittest.TestCase):
max_relative_error=0.005):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
op_attrs)
if no_grad_set is None:
no_grad_set = set()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册