提交 422a6b56 编写于 作者: W wjj19950828

deal with comments

上级 a6c820c7
...@@ -740,30 +740,27 @@ class OpSet9(): ...@@ -740,30 +740,27 @@ class OpSet9():
def Unsqueeze(self, node): def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if axes is None: if axes is not None:
axes = self.graph.get_input_node(node, idx=1, copy=True) if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[
axes = _const_weight_or_none(axes) 0] == 0:
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
outputs=[node.name], outputs=[node.name],
shape=[1]) shape=[1])
else: else:
if isinstance(axes, list) or isinstance(axes, tuple) or isinstance(
axes, np.ndarray):
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.unsqueeze', 'paddle.unsqueeze',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
axis=axes, axis=axes,
outputs=[node.name]) outputs=[node.name])
else: else:
self.paddle_graph.add_layer( axes = self.graph.get_input_node(node, idx=1, copy=True)
'paddle.unsqueeze', self.paddle_graph.add_layer(
inputs={"x": val_x.name, 'paddle.unsqueeze',
"axis": axes.name}, inputs={"x": val_x.name,
outputs=[node.name]) "axis": axes.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Shrink(self, node): def Shrink(self, node):
...@@ -896,7 +893,7 @@ class OpSet9(): ...@@ -896,7 +893,7 @@ class OpSet9():
def Gather(self, node): def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_values = _const_weight_or_none(indices) indices_values = _const_weight_or_none(indices, necessary=True)
if isinstance(indices_values, np.ndarray): if isinstance(indices_values, np.ndarray):
indices_values = indices_values.tolist() indices_values = indices_values.tolist()
indices_shape = indices.out_shapes[0] indices_shape = indices.out_shapes[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册