scatter_nd_add fails when the data dtype is int32/int64
Created by: AoZhang
-
Environment: 1) PaddlePaddle version: 1.6.2 2) CPU: Intel(R) Xeon(R) CPU E5-2620 v2 3) GPU: no 4) OS: CentOS 6.3, Python version: 3.6.5
-
Reproduction:
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[3, 3], dtype='int64', append_batch_size=False)
index = fluid.layers.data(name='index', shape=[2, 2], dtype='int64', append_batch_size=False)
update = fluid.layers.data(name='update', shape=[2], dtype='int64', append_batch_size=False)
output = fluid.layers.scatter_nd_add(x, index, update)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

in_data = np.array([[1, 1, 1], [2, 2, 1], [3, 3, 1]]).astype(np.int64)
index_data = np.array([[1, 1], [0, 1]]).astype(np.int64)
update_data = np.array([5, 3]).astype(np.int64)
exe.run(feed={'x': in_data, "index": index_data, "update": update_data}, fetch_list=[output])
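For reference, the result the snippet should produce can be checked with plain NumPy (a minimal sketch, not part of the original report; np.add.at is used here only to emulate the scatter-add semantics of the op):

import numpy as np

expected = np.array([[1, 1, 1], [2, 2, 1], [3, 3, 1]], dtype=np.int64)
index_data = np.array([[1, 1], [0, 1]], dtype=np.int64)
update_data = np.array([5, 3], dtype=np.int64)
# Add update_data[i] at the coordinate given by index_data[i].
np.add.at(expected, tuple(index_data.T), update_data)
print(expected)  # [[1 4 1], [2 7 1], [3 3 1]]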
- Problem description:
The snippet runs correctly when x and update have dtype float32; changing them to int32 or int64 raises an error. The error message is as follows:
--------------------------------------------
C++ Call Stacks (More useful to developers):
--------------------------------------------
0 std::string paddle::platform::GetTraceBackString<std::string const&>(std::string const&, char const*, int)
1 paddle::platform::EnforceNotMet::EnforceNotMet(std::string const&, char const*, int)
2 paddle::framework::Tensor::Slice(long, long) const
3 std::enable_if<!std::is_floating_point<long>::value, void>::type paddle::operators::elementwise_inner_add<long, long>(paddle::framework::ExecutionContext const&, long const*, long const*, long*, paddle::framework::Tensor const&, paddle::framework::Tensor*, int const&, long const&, int const&, unsigned long const&)
4 void paddle::operators::ScatterNdAdd<long, long>(paddle::framework::ExecutionContext const&, paddle::framework::Tensor const&, paddle::framework::Tensor const&, paddle::framework::Tensor*)
5 paddle::operators::ScatterNdAddOpKernel<long>::Compute(paddle::framework::ExecutionContext const&) const
6 std::_Function_handler<void (paddle::framework::ExecutionContext const&), paddle::framework::OpKernelRegistrarFunctor<paddle::platform::CPUPlace, false, 2ul, paddle::operators::ScatterNdAddOpKernel<float>, paddle::operators::ScatterNdAddOpKernel<double>, paddle::operators::ScatterNdAddOpKernel<long>, paddle::operators::ScatterNdAddOpKernel<int>, paddle::operators::ScatterNdAddOpKernel<unsigned char> >::operator()(char const*, char const*, int) const::{lambda(paddle::framework::ExecutionContext const&)#1}>::_M_invoke(std::_Any_data const&, paddle::framework::ExecutionContext const&)
7 paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, paddle::platform::Place const&, paddle::framework::RuntimeContext*) const
8 paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, paddle::platform::Place const&) const
9 paddle::framework::OperatorBase::Run(paddle::framework::Scope const&, paddle::platform::Place const&)
10 paddle::framework::Executor::RunPreparedContext(paddle::framework::ExecutorPrepareContext*, paddle::framework::Scope*, bool, bool, bool)
11 paddle::framework::Executor::Run(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool, std::vector<std::string, std::allocator<std::string> > const&, bool)
------------------------------------------
Python Call Stacks (More useful to users):
------------------------------------------
File "/xxx/python3.6/site-packages/paddle/fluid/framework.py", line 2488, in append_op
attrs=kwargs.get("attrs", None))
File "/xxx/python3.6/site-packages/paddle/fluid/layer_helper.py", line 43, in append_op
return self.main_program.current_block().append_op(*args, **kwargs)
File "/xxx/python3.6/site-packages/paddle/fluid/layers/nn.py", line 11104, in scatter_nd_add
outputs={"Out": output})
File "./script/cookbook.py", line 455, in scatter_nd_add
output = fluid.layers.scatter_nd_add(x, index, update)
File "./script/cookbook.py", line 581, in <module>
fn()
----------------------
Error Message Summary:
----------------------
Error: The end row index is out of bound.
[Hint: Expected end_idx <= dims_[0], but received end_idx:5 > dims_[0]:3.] at (/paddle/paddle/fluid/framework/tensor.cc:79)
[operator < scatter_nd_add > error]
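A possible workaround until the integer kernels are fixed (a sketch only, untested against 1.6.2, and safe only when the integer values fit exactly in float32): cast the integer inputs to float32, apply scatter_nd_add, and cast the result back to int64.

# Hypothetical workaround built on the reproduction code above;
# assumes the float32 kernel of scatter_nd_add behaves correctly.
x_f = fluid.layers.cast(x, dtype='float32')
update_f = fluid.layers.cast(update, dtype='float32')
out_f = fluid.layers.scatter_nd_add(x_f, index, update_f)
output = fluid.layers.cast(out_f, dtype='int64')  # exact only for values up to 2**24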