diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 926f0df7f96670621e86ddf3decd6289a81293d6..e905ef0e9c511385c2120ea272dc9cebb40ed909 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -88,7 +88,7 @@ def arg_parser(): default=None, help="pretrain model file of pytorch model") parser.add_argument( - "--code_optimizer", + "--enable_code_optim", "-co", default=True, help="Turn on code optimization") @@ -225,7 +225,7 @@ def pytorch2paddle(module, save_dir, jit_type="trace", input_examples=None, - code_optimizer=True, + enable_code_optim=True, convert_to_lite=False, lite_valid_places="arm", lite_model_type="naive_buffer"): @@ -260,7 +260,7 @@ def pytorch2paddle(module, graph_opt.optimize(mapper.paddle_graph) print("Model optimized.") mapper.paddle_graph.gen_model( - save_dir, jit_type=jit_type, code_optimizer=code_optimizer) + save_dir, jit_type=jit_type, enable_code_optim=enable_code_optim) if convert_to_lite: convert2lite(save_dir, lite_valid_places, lite_model_type) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index d8b32a48316737e8b6ac2d07a03e7dbb0c83403e..daf552762c36ca70a96a40f6e46adfc9a28fdf73 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -237,11 +237,11 @@ class PaddleGraph(object): return update(self.layers) - def gen_model(self, save_dir, jit_type=None, code_optimizer=True): + def gen_model(self, save_dir, jit_type=None, enable_code_optim=True): if not osp.exists(save_dir): os.makedirs(save_dir) if jit_type == "trace": - if not self.has_unpack and code_optimizer: + if not self.has_unpack and enable_code_optim: from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree hierarchical_tree = HierarchicalTree(self) for layer_id, layer in self.layers.items(): @@ -252,7 +252,7 @@ class PaddleGraph(object): self.gen_code(save_dir) self.dump_parameter(save_dir) else: - if self.source_type == "pytorch" and code_optimizer: + if self.source_type == "pytorch" and enable_code_optim: from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph module_graph = ModuleGraph(self) module_graph.save_source_files(save_dir)