提交 16d15aa7 编写于 作者: W wjj19950828

deal with comments

上级 1761e71e
......@@ -462,6 +462,7 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axis = self.graph.get_input_node(node, idx=1, copy=True)
axis_values = _const_weight_or_none(axis)
assert axis_values is not None, 'Axis only support constant tensor!'
layer_attrs = {'axis': axis_values}
self.paddle_graph.add_layer(
'paddle.cumsum',
......@@ -740,6 +741,7 @@ class OpSet9():
axes = node.get_attr('axes')
if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes)
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name:
self.paddle_graph.add_layer(
......@@ -1709,7 +1711,8 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs_dict = {"x": val_x.name, "y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
if len(y_shape) != 0 and y_shape[0] == 1 and len(
x_shape) != 0 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.name + '_squeeze'
self.paddle_graph.add_layer(
"paddle.squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册