提交 8ec4883c 编写于 作者: S SunAhong1993

fix the unpack

上级 3cf6f557
......@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.source_type = source_type
self.custom_code = None
self.inputs_info = None
self.has_unpack = False
def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_")
......@@ -112,6 +113,8 @@ class PaddleGraph(object):
layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs)
self.layers[layer_id] = layer
if layer.kernel in ["prim.list_unpack" or "prim.tuple_unpack"]:
self.has_unpack = True
return layer_id
def del_layer(self, layer_id):
......@@ -272,12 +275,16 @@ class PaddleGraph(object):
def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace":
from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items():
hierarchical_tree.insert(layer)
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
if not self.has_unpack:
from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
hierarchical_tree = HierarchicalTree(self)
for layer_id, layer in self.layers.items():
hierarchical_tree.insert(layer)
hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
if self.source_type == "pytorch":
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
......
......@@ -313,6 +313,9 @@ class OpSet9():
inputs['size'] = var_hw
attrs = {"align_corners": False,
"mode": string(node.get_attr('mode', 'nearest'))}
val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) == 4:
attrs["mode"] = string("bilinear")
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.interpolate",
inputs=inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册