slice op 使用报错 (error when using the slice op)
Created by: dingsiyu
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#import tensorflow as tf
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def rel_shift(x):
    """TensorFlow reference version of the relative shift (rank-4 input).

    Pads one leading entry on axis 1, reinterprets the padded buffer with
    leading dims (d1 + 1, d0), drops the first entry along axis 0, and
    reshapes the result back to the input's dynamic shape.

    NOTE(review): the ``import tensorflow as tf`` line at the top of the
    snippet is commented out, so calling this as-is raises NameError; it
    serves as the reference that ``rel_shift2`` ports to Paddle.
    """
    orig_shape = tf.shape(x)
    padded = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
    swapped = tf.reshape(
        padded,
        [orig_shape[1] + 1, orig_shape[0], orig_shape[2], orig_shape[3]],
    )
    # Size -1 in tf.slice means "everything to the end of that axis".
    trimmed = tf.slice(swapped, [1, 0, 0, 0], [-1, -1, -1, -1])
    return tf.reshape(trimmed, orig_shape)
def rel_shift2(x):
    """Paddle Fluid port of ``rel_shift`` for a rank-4 tensor.

    Pads one leading entry on axis 1, reshapes to (d1 + 1, d0, d2, d3),
    drops the first entry along axis 0, and reshapes back to the original
    shape. The shapes after padding and after the first reshape are printed.

    NOTE(review): relies on the static ``x.shape``; if any dimension is
    unknown (-1) the computed reshape targets are wrong — confirm inputs
    have fully known shapes.
    """
    in_shape = x.shape
    # Large sentinel used as slice "end" so each axis is kept to its end.
    to_end = 100000000
    # layers.pad takes a flat [before_0, after_0, before_1, after_1, ...]
    # list: here a single leading pad on axis 1.
    out = layers.pad(x, [0, 0, 1, 0, 0, 0, 0, 0])
    print(out.shape)
    swapped_shape = [in_shape[1] + 1, in_shape[0], in_shape[2], in_shape[3]]
    out = layers.reshape(out, swapped_shape)
    print(out.shape)
    all_axes = [0, 1, 2, 3]
    out = layers.slice(out, all_axes, [1, 0, 0, 0], [to_end] * 4)
    return layers.reshape(out, in_shape)
# TensorFlow reference run, kept for comparing against rel_shift2:
# a = tf.Variable(tf.ones([4, 5, 6, 7]))
# b = rel_shift(a)
# with tf.Session() as sess:
#     init = tf.initialize_all_variables()
#     sess.run(init)
#     print(sess.run(b))
a = layers.ones([4, 5, 6, 7], dtype='float32')
b = rel_shift2(a)
# Use the GPU when available; fall back to CPU on builds without CUDA.
if fluid.core.is_compiled_with_cuda():
    place = fluid.CUDAPlace(0)
else:
    place = fluid.CPUPlace()
exe = fluid.Executor(place)
# fetch_list must be a LIST of Variables. Passing the bare Variable `b` made
# the executor iterate it (`for i, var in enumerate(fetch_list)`), which falls
# back to Variable.__getitem__ and appends slice ops for b[0], b[1], ...;
# at index 4 (== size of axis 0) the slice op's shape inference fails with
# "end:4 <= start:4". The slice op in rel_shift2 was never the problem.
re = exe.run(fetch_list=[b])
print(re[0])
报错:
Traceback (most recent call last):
File "test.py", line 44, in <module>
re = exe.run(fetch_list=b)
File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/executor.py", line 651, in run
use_program_cache=use_program_cache)
File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/executor.py", line 745, in _run
fetch_var_name=fetch_var_name)
File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/executor.py", line 430, in _add_feed_fetch_ops
for i, var in enumerate(fetch_list):
File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/framework.py", line 906, in __getitem__
'decrease_axis': decrease_axis
File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/framework.py", line 1771, in append_op
attrs=kwargs.get("attrs", None))
File "/home/dingsiyu/bin/anaconda3/lib/python3.6/site-packages/paddle/fluid/framework.py", line 1164, in __init__
self.desc.infer_shape(self.block.desc)
paddle.fluid.core_avx.EnforceNotMet: Enforce failed. Expected end > start, but received end:4 <= start:4.
end should greater than start at [/paddle/paddle/fluid/operators/slice_op.cc:57]
PaddlePaddle Call Stacks:
0 0x7f266386edb8p void paddle::platform::EnforceNotMet::Init<std::string>(std::string, char const*, int) + 360
1 0x7f266386f107p paddle::platform::EnforceNotMet::EnforceNotMet(std::string const&, char const*, int) + 87
2 0x7f26645dbec9p paddle::operators::SliceOp::InferShape(paddle::framework::InferShapeContext*) const + 2185
3 0x7f26639e0aaep paddle::framework::OpDesc::InferShape(paddle::framework::BlockDesc const&) const + 862
4 0x7f266391a53cp
5 0x7f26638a11a6p
6 0x7f269ed3e744p _PyCFunction_FastCallDict + 340
7 0x7f269edcca8ep
8 0x7f269edf0a7ap _PyEval_EvalFrameDefault + 762
9 0x7f269edc5ac4p
10 0x7f269edc706ap _PyFunction_FastCallDict + 986
11 0x7f269ed3eb0fp _PyObject_FastCallDict + 623
12 0x7f269ed43723p _PyObject_Call_Prepend + 99
13 0x7f269ed3e54ep PyObject_Call + 62
14 0x7f269ed9a09bp
15 0x7f269edccc77p
16 0x7f269ed3e92bp _PyObject_FastCallDict + 139
17 0x7f269edc6c5ap _PyObject_FastCallKeywords + 170
18 0x7f269edcca8ep
19 0x7f269edf184ep _PyEval_EvalFrameDefault + 4302
20 0x7f269edc5ac4p
21 0x7f269edc6971p
22 0x7f269edcca15p
23 0x7f269edf184ep _PyEval_EvalFrameDefault + 4302
24 0x7f269edc6dabp _PyFunction_FastCallDict + 283
25 0x7f269ed3eb0fp _PyObject_FastCallDict + 623
26 0x7f269ed43723p _PyObject_Call_Prepend + 99
27 0x7f269ed3e92bp _PyObject_FastCallDict + 139
28 0x7f269ed9b3e0p
29 0x7f269ed89cd0p
30 0x7f269ed7499fp
31 0x7f269edf0f66p _PyEval_EvalFrameDefault + 2022
32 0x7f269edc5ac4p
33 0x7f269edc6971p
34 0x7f269edcca15p
35 0x7f269edf184ep _PyEval_EvalFrameDefault + 4302
36 0x7f269edc5ac4p
37 0x7f269edc6971p
38 0x7f269edcca15p
39 0x7f269edf184ep _PyEval_EvalFrameDefault + 4302
40 0x7f269edc5ac4p
41 0x7f269edc6971p
42 0x7f269edcca15p
43 0x7f269edf184ep _PyEval_EvalFrameDefault + 4302
44 0x7f269edc7489p PyEval_EvalCodeEx + 809
45 0x7f269edc822cp PyEval_EvalCode + 28
46 0x7f269ee44c94p
47 0x7f269ee45091p PyRun_FileExFlags + 161
48 0x7f269ee45294p PyRun_SimpleFileExFlags + 452
49 0x7f269ee48d6fp Py_Main + 1535
50 0x7f269ed0ff7ep main + 238
51 0x7f269e45325dp __libc_start_main + 253
52 0x7f269edf75c8p