提交 9d46f443 编写于 作者: Q qijun

fix attr bug in op_test and ensure order in duplicate inputs/outputs

上级 faf827ba
...@@ -9,7 +9,7 @@ def grad_var_name(var_name): ...@@ -9,7 +9,7 @@ def grad_var_name(var_name):
return var_name + "@GRAD" return var_name + "@GRAD"
def create_op(scope, op_type, inputs, outputs, attrs=None): def create_op(scope, op_type, inputs, outputs, attrs):
kwargs = dict() kwargs = dict()
for in_name, in_dup in Operator.get_op_inputs(op_type): for in_name, in_dup in Operator.get_op_inputs(op_type):
...@@ -17,7 +17,7 @@ def create_op(scope, op_type, inputs, outputs, attrs=None): ...@@ -17,7 +17,7 @@ def create_op(scope, op_type, inputs, outputs, attrs=None):
kwargs[in_name] = [] kwargs[in_name] = []
if in_dup: if in_dup:
sub_in = inputs[in_name] sub_in = inputs[in_name]
for sub_in_name in sub_in: for sub_in_name, arr in sub_in:
var = scope.new_var(sub_in_name) var = scope.new_var(sub_in_name)
kwargs[in_name].append(sub_in_name) kwargs[in_name].append(sub_in_name)
else: else:
...@@ -29,15 +29,16 @@ def create_op(scope, op_type, inputs, outputs, attrs=None): ...@@ -29,15 +29,16 @@ def create_op(scope, op_type, inputs, outputs, attrs=None):
kwargs[out_name] = [] kwargs[out_name] = []
if out_dup: if out_dup:
sub_in = outputs[out_name] sub_in = outputs[out_name]
for sun_in_name in sub_in: for sub_in_name, arr in sub_in:
var = scope.new_var(sun_in_name) var = scope.new_var(sub_in_name)
kwargs[out_name].append(sun_in_name) kwargs[out_name].append(sub_in_name)
else: else:
var = scope.new_var(out_name) var = scope.new_var(out_name)
kwargs[out_name].append(out_name) kwargs[out_name].append(out_name)
for attr_name in Operator.get_op_attr_names(op_type): for attr_name in Operator.get_op_attr_names(op_type):
kwargs[attr_name] = attrs[attr_name] if attr_name in attrs:
kwargs[attr_name] = attrs[attr_name]
return Operator(op_type, **kwargs) return Operator(op_type, **kwargs)
...@@ -46,10 +47,9 @@ def set_input(scope, op, inputs, place): ...@@ -46,10 +47,9 @@ def set_input(scope, op, inputs, place):
if in_name in inputs: if in_name in inputs:
if in_dup: if in_dup:
sub_in = inputs[in_name] sub_in = inputs[in_name]
for sub_in_name in sub_in: for sub_in_name, arr in sub_in:
var = scope.find_var(sub_in_name) var = scope.find_var(sub_in_name)
tensor = var.get_tensor() tensor = var.get_tensor()
arr = sub_in[sub_in_name]
tensor.set_dims(arr.shape) tensor.set_dims(arr.shape)
tensor.set(arr, place) tensor.set(arr, place)
else: else:
...@@ -65,7 +65,7 @@ def set_output_grad(scope, op, outputs, place): ...@@ -65,7 +65,7 @@ def set_output_grad(scope, op, outputs, place):
if out_name in outputs: if out_name in outputs:
if out_dup: if out_dup:
sub_out = outputs[out_name] sub_out = outputs[out_name]
for sub_out_name in sub_out: for sub_out_name, arr in sub_out:
out_tensor = scope.find_var(sub_out_name).get_tensor() out_tensor = scope.find_var(sub_out_name).get_tensor()
grad_tensor = scope.new_var(grad_var_name( grad_tensor = scope.new_var(grad_var_name(
sub_out_name)).get_tensor() sub_out_name)).get_tensor()
...@@ -110,7 +110,7 @@ def get_numeric_gradient(scope, ...@@ -110,7 +110,7 @@ def get_numeric_gradient(scope,
# we use a for loop to compute the gradient of every element. # we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size): for i in xrange(tensor_size):
if in_place: if in_place:
set_input(op, inputs, core.CPUPlace()) set_input(scope, op, inputs, core.CPUPlace())
# get one input element throw it's index i. # get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i) origin = tensor_to_check.get_float_element(i)
...@@ -120,7 +120,7 @@ def get_numeric_gradient(scope, ...@@ -120,7 +120,7 @@ def get_numeric_gradient(scope,
y_pos = get_output() y_pos = get_output()
if in_place: if in_place:
set_input(op, inputs, core.CPUPlace()) set_input(scope, op, inputs, core.CPUPlace())
x_neg = origin - delta x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg) tensor_to_check.set_float_element(i, x_neg)
...@@ -168,7 +168,11 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, ...@@ -168,7 +168,11 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
class OpTest(unittest.TestCase): class OpTest(unittest.TestCase):
def check_output_with_place(self, place): def check_output_with_place(self, place):
self.scope = core.Scope() self.scope = core.Scope()
self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs) op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
op_attrs)
if isinstance(place, core.GPUPlace) and not self.op.support_gpu(): if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
return return
set_input(self.scope, self.op, self.inputs, place) set_input(self.scope, self.op, self.inputs, place)
...@@ -227,7 +231,11 @@ class OpTest(unittest.TestCase): ...@@ -227,7 +231,11 @@ class OpTest(unittest.TestCase):
in_place=False, in_place=False,
max_relative_error=0.005): max_relative_error=0.005):
self.scope = core.Scope() self.scope = core.Scope()
self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs) op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
op_attrs = self.attrs if hasattr(self, "attrs") else dict()
self.op = create_op(self.scope, self.op_type, op_inputs, op_outputs,
op_attrs)
if no_grad_set is None: if no_grad_set is None:
no_grad_set = set() no_grad_set = set()
......
...@@ -9,7 +9,7 @@ class TestSumOp(OpTest): ...@@ -9,7 +9,7 @@ class TestSumOp(OpTest):
x0 = np.random.random((3, 4)).astype('float32') x0 = np.random.random((3, 4)).astype('float32')
x1 = np.random.random((3, 4)).astype('float32') x1 = np.random.random((3, 4)).astype('float32')
x2 = np.random.random((3, 4)).astype('float32') x2 = np.random.random((3, 4)).astype('float32')
self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}} self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
y = x0 + x1 + x2 y = x0 + x1 + x2
self.outputs = {'Out': y} self.outputs = {'Out': y}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册