提交 bee47161 编写于 作者: S SunAhong1993

fix

上级 2186c6d6
......@@ -285,14 +285,14 @@ class PaddleGraph(object):
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
if self.source_type == "pytorch":
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
# if self.source_type == "pytorch":
# from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
# module_graph = ModuleGraph(self)
# module_graph.save_source_files(save_dir)
# self.dump_dygraph_parameter(save_dir)
# else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
# 动转静
code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py")
print("Exporting inference model from python code ('{}')... \n".format(code_path))
......
......@@ -1310,7 +1310,6 @@ def aten_dim(mapper, graph, node):
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
......@@ -1322,9 +1321,9 @@ def aten_dim(mapper, graph, node):
current_inputs = list(layer_inputs.values())
graph.add_layer(
"prim.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
"prim.shape", inputs=layer_inputs, outputs=[output_name], scope_name=scope_name)
graph.add_layer(
"prim.len", inputs={"input": output_name}, outputs=layer_outputs, scope_name=scope_name)
"prim.len", inputs={"input": output_name}, outputs=[output_name], scope_name=scope_name)
return current_inputs, current_outputs
......@@ -4512,10 +4511,10 @@ def aten_upsample_bilinear2d(mapper, graph, node):
current_outputs, scope_name)
layer_inputs["align_corners"] = inputs_name[2]
current_inputs.append(inputs_name[2])
if "size" in layer_attrs and layer_attrs["size"] is None:
mapper._check_input(graph, inputs_node[3], inputs_name[3],
current_outputs, scope_name)
layer_inputs["scale_factor"] = inputs_name[3]
# if "size" in layer_attrs and layer_attrs["size"] is None:
# mapper._check_input(graph, inputs_node[3], inputs_name[3],
# current_outputs, scope_name)
# layer_inputs["scale_factor"] = inputs_name[3]
layer_attrs["align_mode"] = 0
layer_attrs["mode"] = string("bilinear")
graph.add_layer(
......
......@@ -471,7 +471,7 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
if int(get_value(layer, "alpha", different_attrs)) == 1:
if int(float(get_value(layer, "alpha", different_attrs))) == 1:
line = "{} = {} - {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
......
......@@ -186,6 +186,7 @@ class DygraphInterpolateBilinearFuser(FuseBase):
inputs={
"input": "interpolate-input-0",
"size": "interpolate-input-3",
# "scale_factor": gen_name(21)
},
outputs=[gen_name(23)])
pattern_block_block.add_layer(
......
......@@ -28,11 +28,11 @@ class GraphOptimizer(object):
"dygraph_constant_fuse_pass",
"dygraph_batchnorm2d_fuse_pass",
"dygraph_interpolate_bilinear_fuse_pass",
"dygraph_fc_fuse_pass",
"dygraph_adaptive_pool2d_fuse_pass",
"dygraph_reshape_fuse_pass",
"dygraph_dropout_fuse_pass",
"dygraph_if_fuse_pass"
# "dygraph_fc_fuse_pass",
# "dygraph_adaptive_pool2d_fuse_pass",
# "dygraph_reshape_fuse_pass",
# "dygraph_dropout_fuse_pass",
# "dygraph_if_fuse_pass"
]
elif source_frame == "caffe":
if paddle_type == "dygraph":
......
......@@ -130,6 +130,8 @@ class PatternMatcher(object):
if is_pop:
subgraph_id2layers.pop(layer_id)
continue
if layer_id not in subgraph_id2layers:
continue
# 当为控制流时的处理
if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if len(pattern_layer.blocks) != len(layer.blocks):
......@@ -154,6 +156,7 @@ class PatternMatcher(object):
if pattern_index == 0 or is_subblock:
return False
else:
print(subgraph_id2layers.keys())
index = list(subgraph_id2layers.keys()).index(
layer_id)
for key in list(subgraph_id2layers.keys())[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册