未验证 提交 4a8b97ee 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero Dim] hack process Tensor.numpy() from 0D to 1D to avoid much incompatible (#51586)

上级 145a6cbb
...@@ -120,11 +120,24 @@ static PyObject* tensor_method_numpy(TensorObject* self, ...@@ -120,11 +120,24 @@ static PyObject* tensor_method_numpy(TensorObject* self,
auto sizeof_dtype = phi::SizeOf(self->tensor.type()); auto sizeof_dtype = phi::SizeOf(self->tensor.type());
Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank]; Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank]; Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
size_t py_rank = tensor_dims.size();
size_t numel = 1; size_t numel = 1;
for (int i = tensor_dims.size() - 1; i >= 0; --i) { if (py_rank == 0) {
py_dims[i] = static_cast<size_t>(tensor_dims[i]); // 0D Tensor hack process to 1D numpy, will remove in future
py_strides[i] = sizeof_dtype * numel; VLOG(0) << "Warning:: 0D Tensor cannot be used as Tensor.numpy()[0], Now "
numel *= py_dims[i]; "0D will be changed to 1D numpy to avoid this problem, but it's "
"not correct and will be removed in future. Please change "
"'Tensor.numpy()[0]' to 'float(Tensor)' or "
"'Tensor.numpy().item()' as soon as possible.";
py_rank = 1;
py_dims[0] = 1;
py_strides[0] = sizeof_dtype * numel;
} else {
for (int i = tensor_dims.size() - 1; i >= 0; --i) {
py_dims[i] = static_cast<size_t>(tensor_dims[i]);
py_strides[i] = sizeof_dtype * numel;
numel *= py_dims[i];
}
} }
PyObject* array = api.PyArray_NewFromDescr_( PyObject* array = api.PyArray_NewFromDescr_(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册