提交 139554da 编写于 作者: W wjj19950828

deal with comments

上级 573b7b14
...@@ -88,7 +88,7 @@ def arg_parser(): ...@@ -88,7 +88,7 @@ def arg_parser():
default=None, default=None,
help="pretrain model file of pytorch model") help="pretrain model file of pytorch model")
parser.add_argument( parser.add_argument(
"--code_optimizer", "--enable_code_optim",
"-co", "-co",
default=True, default=True,
help="Turn on code optimization") help="Turn on code optimization")
...@@ -225,7 +225,7 @@ def pytorch2paddle(module, ...@@ -225,7 +225,7 @@ def pytorch2paddle(module,
save_dir, save_dir,
jit_type="trace", jit_type="trace",
input_examples=None, input_examples=None,
code_optimizer=True, enable_code_optim=True,
convert_to_lite=False, convert_to_lite=False,
lite_valid_places="arm", lite_valid_places="arm",
lite_model_type="naive_buffer"): lite_model_type="naive_buffer"):
...@@ -260,7 +260,7 @@ def pytorch2paddle(module, ...@@ -260,7 +260,7 @@ def pytorch2paddle(module,
graph_opt.optimize(mapper.paddle_graph) graph_opt.optimize(mapper.paddle_graph)
print("Model optimized.") print("Model optimized.")
mapper.paddle_graph.gen_model( mapper.paddle_graph.gen_model(
save_dir, jit_type=jit_type, code_optimizer=code_optimizer) save_dir, jit_type=jit_type, enable_code_optim=enable_code_optim)
if convert_to_lite: if convert_to_lite:
convert2lite(save_dir, lite_valid_places, lite_model_type) convert2lite(save_dir, lite_valid_places, lite_model_type)
......
...@@ -237,11 +237,11 @@ class PaddleGraph(object): ...@@ -237,11 +237,11 @@ class PaddleGraph(object):
return update(self.layers) return update(self.layers)
def gen_model(self, save_dir, jit_type=None, code_optimizer=True): def gen_model(self, save_dir, jit_type=None, enable_code_optim=True):
if not osp.exists(save_dir): if not osp.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
if jit_type == "trace": if jit_type == "trace":
if not self.has_unpack and code_optimizer: if not self.has_unpack and enable_code_optim:
from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self) hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
...@@ -252,7 +252,7 @@ class PaddleGraph(object): ...@@ -252,7 +252,7 @@ class PaddleGraph(object):
self.gen_code(save_dir) self.gen_code(save_dir)
self.dump_parameter(save_dir) self.dump_parameter(save_dir)
else: else:
if self.source_type == "pytorch" and code_optimizer: if self.source_type == "pytorch" and enable_code_optim:
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self) module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir) module_graph.save_source_files(save_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册