提交 965c4c9a 编写于 作者: W wjj19950828

update rm to_tensor

上级 22726057
...@@ -257,13 +257,6 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -257,13 +257,6 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
if is_set_item: if is_set_item:
outputs.append(layer.outputs[0]) outputs.append(layer.outputs[0])
no_output_count = 0 no_output_count = 0
# remove to_tensor layer
invalid_list = list()
for layer_id, layer in sub_layers.items():
if layer.kernel == "paddle.to_tensor":
invalid_list.append(layer_id)
for layer_id in invalid_list:
sub_layers.pop(layer_id)
for i, (layer_id, layer) in enumerate(sub_layers.items()): for i, (layer_id, layer) in enumerate(sub_layers.items()):
_update_attrs(layer, different_attrs) _update_attrs(layer, different_attrs)
if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \ if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \
...@@ -401,12 +394,16 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): ...@@ -401,12 +394,16 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()):
forward_func.extend(gen_codes([line], indent=2)) forward_func.extend(gen_codes([line], indent=2))
cur_outputs.extend(layer.outputs) cur_outputs.extend(layer.outputs)
if inputs is not None:
inputs.sort()
head, init_func_head, forward_func_head = gen_head(inputs, different_attrs) head, init_func_head, forward_func_head = gen_head(inputs, different_attrs)
output_data_name = ", ".join(outputs) output_data_name = ", ".join(outputs)
# remove to_tensor op
forward_func_new = list()
for line in forward_func:
if "paddle.to_tensor" in line:
continue
forward_func_new.append(line)
code_list = head + init_func_head + init_func + \ code_list = head + init_func_head + init_func + \
forward_func_head + forward_func + \ forward_func_head + forward_func_new + \
gen_codes(["return {}".format(output_data_name)], indent=2) gen_codes(["return {}".format(output_data_name)], indent=2)
code_str = "".join(code_list) code_str = "".join(code_list)
return code_str return code_str
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册