提交 d117c70f 编写于 作者: S SunAhong1993

fix the slice

上级 ce92c536
......@@ -164,6 +164,10 @@ class PaddleGraph(object):
def build(self, inputs=None, outputs=None):
self.clear_edges()
outputs_from_nodes = dict()
# for layer_id, layer in self.layers.items():
# print(layer.kernel)
# print(layer.inputs)
# print(layer.outputs)
for layer_id, layer in self.layers.items():
for input_key, input_var in layer.inputs.items():
vs = input_var
......
......@@ -954,14 +954,12 @@ class OpSet9():
starts_value = starts_value.copy()
ends_value = ends_value.copy()
for idx in range(len(ends_value)):
if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
if starts_value[idx] >= val_x.out_shapes[0][axes[idx]] and val_x.out_shapes[0][axes[idx]] > 0:
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
ends_value[idx] = val_x.out_shapes[0][axes[idx]]
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] == 2**31 - 1:
ends_value[idx] = node.out_shapes[0][axes[idx]] + 1
elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1
layer_attrs = {
"axes": axes,
"starts": starts_value,
......
......@@ -918,12 +918,9 @@ class OpSet9():
# ends_value[idx] = 2**31 - 1
#print(val_x.out_shapes)
for idx in range(len(ends_value)):
if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
if starts_value[idx] >= val_x.out_shapes[0][axes[idx]] and val_x.out_shapes[0][axes[idx]] > 0:
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
ends_value[idx] = val_x.out_shapes[0][axes[idx]]
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] == 2**31 - 1:
ends_value[idx] = node.out_shapes[0][axes[idx]] + 1
elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1
layer_attrs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册