提交 bee47161 编写于 作者: S SunAhong1993

fix: temporarily disable the PyTorch ModuleGraph code path, several dygraph fuse passes, and the upsample_bilinear2d scale_factor branch; use per-call output lists in aten_dim; cast alpha via float in prim_sub; guard PatternMatcher against missing layer_id keys

上级 2186c6d6
...@@ -285,12 +285,12 @@ class PaddleGraph(object): ...@@ -285,12 +285,12 @@ class PaddleGraph(object):
hierarchical_tree.save_source_files(save_dir) hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
else: else:
if self.source_type == "pytorch": # if self.source_type == "pytorch":
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph # from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
module_graph = ModuleGraph(self) # module_graph = ModuleGraph(self)
module_graph.save_source_files(save_dir) # module_graph.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir) # self.dump_dygraph_parameter(save_dir)
else: # else:
self.gen_dygraph_code(save_dir) self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
# 动转静 # 动转静
......
...@@ -1310,7 +1310,6 @@ def aten_dim(mapper, graph, node): ...@@ -1310,7 +1310,6 @@ def aten_dim(mapper, graph, node):
""" """
scope_name = mapper.normalize_scope_name(node) scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {} layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list # 获取当前节点输出的list
...@@ -1322,9 +1321,9 @@ def aten_dim(mapper, graph, node): ...@@ -1322,9 +1321,9 @@ def aten_dim(mapper, graph, node):
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.shape", inputs=layer_inputs, outputs=[output_name], scope_name=scope_name)
graph.add_layer( graph.add_layer(
"prim.len", inputs={"input": output_name}, outputs=layer_outputs, scope_name=scope_name) "prim.len", inputs={"input": output_name}, outputs=[output_name], scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4512,10 +4511,10 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -4512,10 +4511,10 @@ def aten_upsample_bilinear2d(mapper, graph, node):
current_outputs, scope_name) current_outputs, scope_name)
layer_inputs["align_corners"] = inputs_name[2] layer_inputs["align_corners"] = inputs_name[2]
current_inputs.append(inputs_name[2]) current_inputs.append(inputs_name[2])
if "size" in layer_attrs and layer_attrs["size"] is None: # if "size" in layer_attrs and layer_attrs["size"] is None:
mapper._check_input(graph, inputs_node[3], inputs_name[3], # mapper._check_input(graph, inputs_node[3], inputs_name[3],
current_outputs, scope_name) # current_outputs, scope_name)
layer_inputs["scale_factor"] = inputs_name[3] # layer_inputs["scale_factor"] = inputs_name[3]
layer_attrs["align_mode"] = 0 layer_attrs["align_mode"] = 0
layer_attrs["mode"] = string("bilinear") layer_attrs["mode"] = string("bilinear")
graph.add_layer( graph.add_layer(
......
...@@ -471,7 +471,7 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff ...@@ -471,7 +471,7 @@ def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diff
def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
if int(get_value(layer, "alpha", different_attrs)) == 1: if int(float(get_value(layer, "alpha", different_attrs))) == 1:
line = "{} = {} - {}".format(layer.outputs[0], line = "{} = {} - {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
......
...@@ -186,6 +186,7 @@ class DygraphInterpolateBilinearFuser(FuseBase): ...@@ -186,6 +186,7 @@ class DygraphInterpolateBilinearFuser(FuseBase):
inputs={ inputs={
"input": "interpolate-input-0", "input": "interpolate-input-0",
"size": "interpolate-input-3", "size": "interpolate-input-3",
# "scale_factor": gen_name(21)
}, },
outputs=[gen_name(23)]) outputs=[gen_name(23)])
pattern_block_block.add_layer( pattern_block_block.add_layer(
......
...@@ -28,11 +28,11 @@ class GraphOptimizer(object): ...@@ -28,11 +28,11 @@ class GraphOptimizer(object):
"dygraph_constant_fuse_pass", "dygraph_constant_fuse_pass",
"dygraph_batchnorm2d_fuse_pass", "dygraph_batchnorm2d_fuse_pass",
"dygraph_interpolate_bilinear_fuse_pass", "dygraph_interpolate_bilinear_fuse_pass",
"dygraph_fc_fuse_pass", # "dygraph_fc_fuse_pass",
"dygraph_adaptive_pool2d_fuse_pass", # "dygraph_adaptive_pool2d_fuse_pass",
"dygraph_reshape_fuse_pass", # "dygraph_reshape_fuse_pass",
"dygraph_dropout_fuse_pass", # "dygraph_dropout_fuse_pass",
"dygraph_if_fuse_pass" # "dygraph_if_fuse_pass"
] ]
elif source_frame == "caffe": elif source_frame == "caffe":
if paddle_type == "dygraph": if paddle_type == "dygraph":
......
...@@ -130,6 +130,8 @@ class PatternMatcher(object): ...@@ -130,6 +130,8 @@ class PatternMatcher(object):
if is_pop: if is_pop:
subgraph_id2layers.pop(layer_id) subgraph_id2layers.pop(layer_id)
continue continue
if layer_id not in subgraph_id2layers:
continue
# 当为控制流时的处理 # 当为控制流时的处理
if layer.kernel == "prim.if" or layer.kernel == "prim.loop": if layer.kernel == "prim.if" or layer.kernel == "prim.loop":
if len(pattern_layer.blocks) != len(layer.blocks): if len(pattern_layer.blocks) != len(layer.blocks):
...@@ -154,6 +156,7 @@ class PatternMatcher(object): ...@@ -154,6 +156,7 @@ class PatternMatcher(object):
if pattern_index == 0 or is_subblock: if pattern_index == 0 or is_subblock:
return False return False
else: else:
print(subgraph_id2layers.keys())
index = list(subgraph_id2layers.keys()).index( index = list(subgraph_id2layers.keys()).index(
layer_id) layer_id)
for key in list(subgraph_id2layers.keys())[ for key in list(subgraph_id2layers.keys())[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册