提交 f2156e4b 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #4012 from QiJune/refine_op_test

fix attr bug in op_test and ensure order in duplicate inputs/outputs
......@@ -9,7 +9,7 @@ def grad_var_name(var_name):
return var_name + "@GRAD"
def create_op(scope, op_type, inputs, outputs, attrs=None):
def create_op(scope, op_type, inputs, outputs, attrs):
kwargs = dict()
for in_name, in_dup in Operator.get_op_inputs(op_type):
......@@ -17,7 +17,7 @@ def create_op(scope, op_type, inputs, outputs, attrs=None):
kwargs[in_name] = []
if in_dup:
sub_in = inputs[in_name]
for sub_in_name in sub_in:
for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name)
kwargs[in_name].append(sub_in_name)
else:
......@@ -29,15 +29,16 @@ def create_op(scope, op_type, inputs, outputs, attrs=None):
kwargs[out_name] = []
if out_dup:
sub_in = outputs[out_name]
for sun_in_name in sub_in:
var = scope.new_var(sun_in_name)
kwargs[out_name].append(sun_in_name)
for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name)
kwargs[out_name].append(sub_in_name)
else:
var = scope.new_var(out_name)
kwargs[out_name].append(out_name)
for attr_name in Operator.get_op_attr_names(op_type):
kwargs[attr_name] = attrs[attr_name]
if attr_name in attrs:
kwargs[attr_name] = attrs[attr_name]
return Operator(op_type, **kwargs)
......@@ -46,12 +47,11 @@ def set_input(scope, op, inputs, place):
if in_name in inputs:
if in_dup:
sub_in = inputs[in_name]
for sub_in_name in sub_in:
for sub_in_name, sub_in_array in sub_in:
var = scope.find_var(sub_in_name)
tensor = var.get_tensor()
arr = sub_in[sub_in_name]
tensor.set_dims(arr.shape)
tensor.set(arr, place)
tensor.set_dims(sub_in_array.shape)
tensor.set(sub_in_array, place)
else:
var = scope.find_var(in_name)
tensor = var.get_tensor()
......@@ -65,7 +65,7 @@ def set_output_grad(scope, op, outputs, place):
if out_name in outputs:
if out_dup:
sub_out = outputs[out_name]
for sub_out_name in sub_out:
for sub_out_name, sub_out_grad in sub_out:
out_tensor = scope.find_var(sub_out_name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(
sub_out_name)).get_tensor()
......@@ -110,7 +110,7 @@ def get_numeric_gradient(scope,
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
if in_place:
set_input(op, inputs, core.CPUPlace())
set_input(scope, op, inputs, core.CPUPlace())
# get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i)
......@@ -120,7 +120,7 @@ def get_numeric_gradient(scope,
y_pos = get_output()
if in_place:
set_input(op, inputs, core.CPUPlace())
set_input(scope, op, inputs, core.CPUPlace())
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
......@@ -168,7 +168,10 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
class OpTest(unittest.TestCase):
def check_output_with_place(self, place):
self.scope = core.Scope()
self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
op_attrs)
if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
return
set_input(self.scope, self.op, self.inputs, place)
......@@ -227,7 +230,10 @@ class OpTest(unittest.TestCase):
in_place=False,
max_relative_error=0.005):
self.scope = core.Scope()
self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, self.outputs,
op_attrs)
if no_grad_set is None:
no_grad_set = set()
......
......@@ -9,7 +9,7 @@ class TestSumOp(OpTest):
x0 = np.random.random((3, 4)).astype('float32')
x1 = np.random.random((3, 4)).astype('float32')
x2 = np.random.random((3, 4)).astype('float32')
self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}}
self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
y = x0 + x1 + x2
self.outputs = {'Out': y}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册