From 54892c079735aaffafc7388486482e06ff139439 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 28 Sep 2017 11:59:17 -0700 Subject: [PATCH] Simplify op_test --- python/paddle/v2/framework/tests/op_test.py | 42 +++++++++------------ 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 70ae50d401c..23794151bdb 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -12,17 +12,19 @@ def grad_var_name(var_name): def create_op(scope, op_type, inputs, outputs, attrs): kwargs = dict() + def __create_var__(name, var_name): + scope.new_var(var_name) + kwargs[name].append(var_name) + for in_name, in_dup in Operator.get_op_inputs(op_type): if in_name in inputs: kwargs[in_name] = [] if in_dup: sub_in = inputs[in_name] for sub_in_name, _ in sub_in: - var = scope.new_var(sub_in_name) - kwargs[in_name].append(sub_in_name) + __create_var__(in_name, sub_in_name) else: - var = scope.new_var(in_name) - kwargs[in_name].append(in_name) + __create_var__(in_name, in_name) for out_name, out_dup in Operator.get_op_outputs(op_type): if out_name in outputs: @@ -30,11 +32,9 @@ def create_op(scope, op_type, inputs, outputs, attrs): if out_dup: sub_out = outputs[out_name] for sub_out_name, _ in sub_out: - var = scope.new_var(sub_out_name) - kwargs[out_name].append(sub_out_name) + __create_var__(out_name, sub_out_name) else: - var = scope.new_var(out_name) - kwargs[out_name].append(out_name) + __create_var__(out_name, out_name) for attr_name in Operator.get_op_attr_names(op_type): if attr_name in attrs: @@ -44,28 +44,22 @@ def create_op(scope, op_type, inputs, outputs, attrs): def set_input(scope, op, inputs, place): + def __set_input__(var_name, var): + tensor = scope.find_var(var_name).get_tensor() + if isinstance(var, tuple): + tensor.set_lod(var[1]) + var = var[0] + tensor.set_dims(var.shape) + tensor.set(var, place) + for in_name, in_dup in Operator.get_op_inputs(op.type()): if in_name in inputs: if in_dup: sub_in = inputs[in_name] for sub_in_name, sub_in_val in sub_in: - var = scope.find_var(sub_in_name) - tensor = var.get_tensor() - sub_in_array = sub_in_val[0] \ - if isinstance(sub_in_val, tuple) else sub_in_val - tensor.set_dims(sub_in_array.shape) - tensor.set(sub_in_array, place) - if isinstance(sub_in_val, tuple): - tensor.set_lod(sub_in_val[1]) + __set_input__(sub_in_name, sub_in_val) else: - var = scope.find_var(in_name) - tensor = var.get_tensor() - in_val = inputs[in_name] - in_array = in_val[0] if isinstance(in_val, tuple) else in_val - tensor.set_dims(in_array.shape) - tensor.set(in_array, place) - if isinstance(in_val, tuple): - tensor.set_lod(in_val[1]) + __set_input__(in_name, inputs[in_name]) def set_output_grad(scope, op, outputs, place): -- GitLab