提交 3c4537f1 编写于 作者: S SunAhong1993

fix2

上级 bee47161
......@@ -285,14 +285,14 @@ class PaddleGraph(object):
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
# if self.source_type == "pytorch":
# from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
# module_graph = ModuleGraph(self)
# module_graph.save_source_files(save_dir)
# self.dump_dygraph_parameter(save_dir)
# else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
if self.source_type == "pytorch":
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
# 动转静
code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py")
print("Exporting inference model from python code ('{}')... \n".format(code_path))
......
......@@ -4511,10 +4511,10 @@ def aten_upsample_bilinear2d(mapper, graph, node):
current_outputs, scope_name)
layer_inputs["align_corners"] = inputs_name[2]
current_inputs.append(inputs_name[2])
# if "size" in layer_attrs and layer_attrs["size"] is None:
# mapper._check_input(graph, inputs_node[3], inputs_name[3],
# current_outputs, scope_name)
# layer_inputs["scale_factor"] = inputs_name[3]
if "size" in layer_attrs and layer_attrs["size"] is None:
mapper._check_input(graph, inputs_node[3], inputs_name[3],
current_outputs, scope_name)
layer_inputs["scale_factor"] = inputs_name[3]
layer_attrs["align_mode"] = 0
layer_attrs["mode"] = string("bilinear")
graph.add_layer(
......
......@@ -28,11 +28,11 @@ class GraphOptimizer(object):
"dygraph_constant_fuse_pass",
"dygraph_batchnorm2d_fuse_pass",
"dygraph_interpolate_bilinear_fuse_pass",
# "dygraph_fc_fuse_pass",
# "dygraph_adaptive_pool2d_fuse_pass",
# "dygraph_reshape_fuse_pass",
# "dygraph_dropout_fuse_pass",
# "dygraph_if_fuse_pass"
"dygraph_fc_fuse_pass",
"dygraph_adaptive_pool2d_fuse_pass",
"dygraph_reshape_fuse_pass",
"dygraph_dropout_fuse_pass",
"dygraph_if_fuse_pass"
]
elif source_frame == "caffe":
if paddle_type == "dygraph":
......
......@@ -130,8 +130,6 @@ class PatternMatcher(object):
if is_pop:
subgraph_id2layers.pop(layer_id)
continue
if layer_id not in subgraph_id2layers:
continue
# 当为控制流时的处理
if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if len(pattern_layer.blocks) != len(layer.blocks):
......@@ -156,7 +154,6 @@ class PatternMatcher(object):
if pattern_index == 0 or is_subblock:
return False
else:
print(subgraph_id2layers.keys())
index = list(subgraph_id2layers.keys()).index(
layer_id)
for key in list(subgraph_id2layers.keys())[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册