提交 16d15aa7 编写于 作者: W wjj19950828

deal with comments

上级 1761e71e
...@@ -462,6 +462,7 @@ class OpSet9(): ...@@ -462,6 +462,7 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = self.graph.get_input_node(node, idx=1, copy=True) axis = self.graph.get_input_node(node, idx=1, copy=True)
axis_values = _const_weight_or_none(axis) axis_values = _const_weight_or_none(axis)
assert axis_values is not None, 'Axis only support constant tensor!'
layer_attrs = {'axis': axis_values} layer_attrs = {'axis': axis_values}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.cumsum', 'paddle.cumsum',
...@@ -740,6 +741,7 @@ class OpSet9(): ...@@ -740,6 +741,7 @@ class OpSet9():
axes = node.get_attr('axes') axes = node.get_attr('axes')
if axes is None: if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True) axes = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes)
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name: if node.name:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -1709,7 +1711,8 @@ class OpSet9(): ...@@ -1709,7 +1711,8 @@ class OpSet9():
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs_dict = {"x": val_x.name, "y": val_y.name} inputs_dict = {"x": val_x.name, "y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1: if len(y_shape) != 0 and y_shape[0] == 1 and len(
x_shape) != 0 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.name + '_squeeze' y_squeeze = val_y.name + '_squeeze'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.squeeze", "paddle.squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册