Commit 54892c07 authored by Yu Yang, committed by zchen0211

Simplify op_test

Parent 61cc3ae4
@@ -12,17 +12,19 @@ def grad_var_name(var_name):
 def create_op(scope, op_type, inputs, outputs, attrs):
     kwargs = dict()
 
+    def __create_var__(name, var_name):
+        scope.new_var(var_name)
+        kwargs[name].append(var_name)
+
     for in_name, in_dup in Operator.get_op_inputs(op_type):
         if in_name in inputs:
             kwargs[in_name] = []
             if in_dup:
                 sub_in = inputs[in_name]
                 for sub_in_name, _ in sub_in:
-                    var = scope.new_var(sub_in_name)
-                    kwargs[in_name].append(sub_in_name)
+                    __create_var__(in_name, sub_in_name)
             else:
-                var = scope.new_var(in_name)
-                kwargs[in_name].append(in_name)
 
     for out_name, out_dup in Operator.get_op_outputs(op_type):
         if out_name in outputs:
@@ -30,11 +32,9 @@ def create_op(scope, op_type, inputs, outputs, attrs):
             if out_dup:
                 sub_out = outputs[out_name]
                 for sub_out_name, _ in sub_out:
-                    var = scope.new_var(sub_out_name)
-                    kwargs[out_name].append(sub_out_name)
+                    __create_var__(out_name, sub_out_name)
             else:
-                var = scope.new_var(out_name)
-                kwargs[out_name].append(out_name)
+                __create_var__(out_name, out_name)
 
     for attr_name in Operator.get_op_attr_names(op_type):
         if attr_name in attrs:
@@ -44,28 +44,22 @@ def create_op(scope, op_type, inputs, outputs, attrs):
 def set_input(scope, op, inputs, place):
+    def __set_input__(var_name, var):
+        tensor = scope.find_var(var_name).get_tensor()
+        if isinstance(var, tuple):
+            tensor.set_lod(var[1])
+            var = var[0]
+        tensor.set_dims(var.shape)
+        tensor.set(var, place)
+
     for in_name, in_dup in Operator.get_op_inputs(op.type()):
         if in_name in inputs:
             if in_dup:
                 sub_in = inputs[in_name]
                 for sub_in_name, sub_in_val in sub_in:
-                    var = scope.find_var(sub_in_name)
-                    tensor = var.get_tensor()
-                    sub_in_array = sub_in_val[0] \
-                        if isinstance(sub_in_val, tuple) else sub_in_val
-                    tensor.set_dims(sub_in_array.shape)
-                    tensor.set(sub_in_array, place)
-                    if isinstance(sub_in_val, tuple):
-                        tensor.set_lod(sub_in_val[1])
+                    __set_input__(sub_in_name, sub_in_val)
             else:
-                var = scope.find_var(in_name)
-                tensor = var.get_tensor()
-                in_val = inputs[in_name]
-                in_array = in_val[0] if isinstance(in_val, tuple) else in_val
-                tensor.set_dims(in_array.shape)
-                tensor.set(in_array, place)
-                if isinstance(in_val, tuple):
-                    tensor.set_lod(in_val[1])
+                __set_input__(in_name, inputs[in_name])
 
 
 def set_output_grad(scope, op, outputs, place):
...
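
The refactor above deduplicates the per-variable bookkeeping: create_op now funnels every input and output name through the nested __create_var__ helper, and set_input feeds every numpy value (plain array or (array, lod) tuple) through __set_input__, without changing what either function does. As a rough, hypothetical illustration of how a test might call these helpers after this change, the sketch below builds a scope, creates an operator, and feeds its inputs. The core.Scope() and core.CPUPlace() constructors, the "mul" op type, and the X/Y/Out names are assumptions about the surrounding v2 test framework, not something introduced by this commit.

    import numpy as np
    import paddle.v2.framework.core as core

    # Hypothetical usage sketch, not part of this commit: Scope/CPUPlace and
    # the "mul" operator with inputs X/Y and output Out are assumed here.
    scope = core.Scope()
    place = core.CPUPlace()

    inputs = {
        "X": np.random.random((2, 3)).astype("float32"),
        "Y": np.random.random((3, 4)).astype("float32")
    }
    outputs = {"Out": np.zeros((2, 4)).astype("float32")}

    # create_op registers one scope variable per input/output name via the new
    # __create_var__ helper (the rest of create_op, outside this diff, builds
    # and returns the operator).
    op = create_op(scope, "mul", inputs, outputs, attrs=dict())

    # set_input copies each array (or (array, lod) tuple) into the matching
    # tensor on the given place via the new __set_input__ helper.
    set_input(scope, op, inputs, place)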