提交 4a59b171 编写于 作者: S SunAhong1993

fix3

上级 3c4537f1
...@@ -285,14 +285,14 @@ class PaddleGraph(object): ...@@ -285,14 +285,14 @@ class PaddleGraph(object):
hierarchical_tree.save_source_files(save_dir) hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
else: else:
if self.source_type == "pytorch": # if self.source_type == "pytorch":
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph # from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self) # module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir) # module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir) # self.dump_dygraph_parameter(save_dir)
else: # else:
self.gen_dygraph_code(save_dir) self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
# 动转静 # 动转静
code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py") code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py")
print("Exporting inference model from python code ('{}')... \n".format(code_path)) print("Exporting inference model from python code ('{}')... \n".format(code_path))
......
...@@ -182,7 +182,7 @@ def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di ...@@ -182,7 +182,7 @@ def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di
def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "raise RaiseException({})".format(get_value(layer, "input", different_attrs)) line = "raise Exception({})".format(get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -458,10 +458,12 @@ def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di ...@@ -458,10 +458,12 @@ def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False):
line = "{} = {}.startswith({})".format(layer.outputs[0], line = "{} = {}.startswith({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs), get_value(layer, "input", different_attrs),
get_value(layer, "start_str", different_attrs)) get_value(layer, "start_str", different_attrs))
if is_return_line:
return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
......
...@@ -41,13 +41,21 @@ class DygraphIfFuser(FuseBase): ...@@ -41,13 +41,21 @@ class DygraphIfFuser(FuseBase):
return return
for id in graph.edges_in[layer_id]: for id in graph.edges_in[layer_id]:
input_layer = graph.layers[id] input_layer = graph.layers[id]
input_layer_id = id
if input_layer.outputs == [layer.inputs["input"]]: if input_layer.outputs == [layer.inputs["input"]]:
if input_layer.kernel == "prim.if": if input_layer.kernel == "prim.if":
matches.pop(layer_id) matches.pop(layer_id)
return return
input_id = id input_id = id
break break
if list(layer.inputs.values()).count(input_layer.outputs[0]) > 1 or \
(input_layer_id in graph.edges_out and len(graph.edges_out[input_layer_id]) > 1):
matches.pop(layer_id)
return
func_name = input_layer.kernel.replace(".", "_") func_name = input_layer.kernel.replace(".", "_")
if func_name in ["prim_if", "prim_loop"]:
matches.pop(layer_id)
return
from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code
func = getattr(prim2code, func_name) func = getattr(prim2code, func_name)
line = func(input_layer, is_return_line=True) line = func(input_layer, is_return_line=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册