From 965c4c9a4404534c118348c816cd5dd4e3bf02e9 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 1 Dec 2021 22:05:30 +0800 Subject: [PATCH] update rm to_tensor --- .../layer_code_generator.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py b/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py index 4c302d7..7ee9559 100644 --- a/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py +++ b/x2paddle/optimizer/pytorch_code_optimizer/layer_code_generator.py @@ -257,13 +257,6 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): if is_set_item: outputs.append(layer.outputs[0]) no_output_count = 0 - # remove to_tensor layer - invalid_list = list() - for layer_id, layer in sub_layers.items(): - if layer.kernel == "paddle.to_tensor": - invalid_list.append(layer_id) - for layer_id in invalid_list: - sub_layers.pop(layer_id) for i, (layer_id, layer) in enumerate(sub_layers.items()): _update_attrs(layer, different_attrs) if ("paddle.nn" in layer.kernel and "functional" not in layer.kernel) or \ @@ -401,12 +394,16 @@ def gen_layer_code(graph, sub_layers, sub_layers_name, different_attrs=dict()): forward_func.extend(gen_codes([line], indent=2)) cur_outputs.extend(layer.outputs) - if inputs is not None: - inputs.sort() head, init_func_head, forward_func_head = gen_head(inputs, different_attrs) output_data_name = ", ".join(outputs) + # remove to_tensor op + forward_func_new = list() + for line in forward_func: + if "paddle.to_tensor" in line: + continue + forward_func_new.append(line) code_list = head + init_func_head + init_func + \ - forward_func_head + forward_func + \ + forward_func_head + forward_func_new + \ gen_codes(["return {}".format(output_data_name)], indent=2) code_str = "".join(code_list) return code_str -- GitLab