未验证 提交 1bd468e2 编写于 作者: J JYChen 提交者: GitHub

Hack__getitem__ from 0-d to 1-d with FLAGS_set_to_1d (#53358)

上级 2c12abd7
...@@ -136,17 +136,18 @@ static PyObject* tensor_method_numpy(TensorObject* self, ...@@ -136,17 +136,18 @@ static PyObject* tensor_method_numpy(TensorObject* self,
} }
} }
if (set_to_1d) { if (set_to_1d) {
// 0D Tensor hack process to 1D numpy, will remove in future // 0D Tensor hack process to 1D numpy, will remove in release 2.6
VLOG(0) VLOG(0)
<< "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In " << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In "
"order to avoid this problem, " "order to avoid this problem, "
"0D Tensor will be changed to 1D numpy currently, but it's not " "0D Tensor will be changed to 1D numpy currently, but it's not "
"correct and will be " "correct and will be "
"removed in future. For Tensor contain only one element, Please " "removed in release 2.6. For Tensor contain only one element, "
"Please "
"modify " "modify "
" 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as " " 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
"possible, " "possible, "
"otherwise 'Tensor.numpy()[0]' will raise error in future."; "otherwise 'Tensor.numpy()[0]' will raise error in release 2.6.";
py_rank = 1; py_rank = 1;
py_dims[0] = 1; py_dims[0] = 1;
py_strides[0] = sizeof_dtype * numel; py_strides[0] = sizeof_dtype * numel;
...@@ -923,7 +924,16 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self, ...@@ -923,7 +924,16 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
} }
} }
bool set_to_1d = FLAGS_set_to_1d;
if (!none_axes.empty()) { if (!none_axes.empty()) {
if (set_to_1d) {
// NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
// with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
// otherwise the output shape will be not correct.
if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
none_axes.pop_back();
}
}
if (!none_axes.empty()) { if (!none_axes.empty()) {
paddle::Tensor new_out; paddle::Tensor new_out;
{ {
......
...@@ -63,6 +63,7 @@ limitations under the License. */ ...@@ -63,6 +63,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/type_defs.h" #include "paddle/phi/core/type_defs.h"
PHI_DECLARE_bool(set_to_1d);
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
...@@ -1064,7 +1065,19 @@ void BindImperative(py::module *m_ptr) { ...@@ -1064,7 +1065,19 @@ void BindImperative(py::module *m_ptr) {
} }
tracer->TraceOp(op_type, ins, outs, std::move(attrs)); tracer->TraceOp(op_type, ins, outs, std::move(attrs));
} }
bool set_to_1d = FLAGS_set_to_1d;
if (!none_axes.empty()) { if (!none_axes.empty()) {
if (set_to_1d) {
// NOTE(zoooo0820): When all axes are decreased, the output
// will be 1-D with FLAGS_set_to_1d=True. In this case, one
// `None` should be pop out, otherwise the output shape will be
// not correct.
if (static_cast<int>(decrease_axis.size()) ==
tensor->dims().size()) {
none_axes.pop_back();
}
}
if (!none_axes.empty()) { if (!none_axes.empty()) {
// Deal with cases that decrease_axes is not empty // Deal with cases that decrease_axes is not empty
// For example: // For example:
......
...@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h>
#include <paddle/phi/core/ddim.h> #include <paddle/phi/core/ddim.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/phi/core/flags.h"
PHI_DECLARE_bool(set_to_1d);
namespace phi { namespace phi {
...@@ -202,7 +205,23 @@ inline DDim GetDecreasedDims(const DDim slice_dims, ...@@ -202,7 +205,23 @@ inline DDim GetDecreasedDims(const DDim slice_dims,
new_shape.push_back(decreased_dims[i]); new_shape.push_back(decreased_dims[i]);
} }
} }
if (FLAGS_set_to_1d && new_shape.size() == 0) {
// NOTE(zoooo0820): Hack procssing to 1-D, when axes decrease to 0-D in
// slice. This will remove in release 2.6.
VLOG(0)
<< "Warning:: In Tensor '__getitem__', if the number of scalar "
"elements "
"in the index is equal to the rank of the Tensor, the output "
"should "
"be 0-D. In order to be consistent with the behavior of previous "
"versions, it will be processed to 1-D. But it is not correct and "
"will be "
"removed in release 2.6. "
"If 1-D is still wanted, please modify the index element from "
"scalar to slice "
"(e.g. 'x[i]' => 'x[i:i+1]'). ";
new_shape.push_back(1);
}
decreased_dims = phi::make_ddim(new_shape); decreased_dims = phi::make_ddim(new_shape);
} }
return decreased_dims; return decreased_dims;
......
...@@ -1371,7 +1371,7 @@ def fftshift(x, axes=None, name=None): ...@@ -1371,7 +1371,7 @@ def fftshift(x, axes=None, name=None):
elif isinstance(axes, int): elif isinstance(axes, int):
shifts = shape[axes] // 2 shifts = shape[axes] // 2
else: else:
shifts = paddle.stack([shape[ax] // 2 for ax in axes]) shifts = paddle.concat([shape[ax : ax + 1] // 2 for ax in axes])
return paddle.roll(x, shifts, axes, name=name) return paddle.roll(x, shifts, axes, name=name)
...@@ -1416,7 +1416,7 @@ def ifftshift(x, axes=None, name=None): ...@@ -1416,7 +1416,7 @@ def ifftshift(x, axes=None, name=None):
elif isinstance(axes, int): elif isinstance(axes, int):
shifts = -shape[axes] // 2 shifts = -shape[axes] // 2
else: else:
shifts = paddle.stack([-shape[ax] // 2 for ax in axes]) shifts = paddle.concat([-shape[ax : ax + 1] // 2 for ax in axes])
return paddle.roll(x, shifts, axes, name=name) return paddle.roll(x, shifts, axes, name=name)
......
...@@ -574,6 +574,13 @@ def _getitem_impl_(var, item): ...@@ -574,6 +574,13 @@ def _getitem_impl_(var, item):
out = reverse(out, axis=reverse_axes) out = reverse(out, axis=reverse_axes)
# NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
# with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
# otherwise the output shape will be not correct.
set_to_1d = paddle.get_flags('FLAGS_set_to_1d')['FLAGS_set_to_1d']
if set_to_1d and len(decrease_axes) == len(var.shape):
none_axes = none_axes[1:]
if len(none_axes) > 0: if len(none_axes) > 0:
# Deal with cases that decrease_axes is not empty # Deal with cases that decrease_axes is not empty
# For example: # For example:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册