Unverified commit 73aa98cf, authored by Roc, committed by GitHub

[0d Tensor] update scatter for zero-dimension tensor (#49279)

* revert concat and change concat to stack

* let stack kernel support int8, uint8 and bool type
Parent commit: 1c0afa79
......@@ -255,9 +255,9 @@ void BindDistributed(py::module *m) {
bool sync_op) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
stack_out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
......@@ -307,16 +307,16 @@ void BindDistributed(py::module *m) {
bool sync_op) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
stack_out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
stack_in_tensor.impl());
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
......@@ -430,9 +430,9 @@ void BindDistributed(py::module *m) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
stack_in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
......@@ -484,9 +484,9 @@ void BindDistributed(py::module *m) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
stack_in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
......@@ -746,9 +746,9 @@ void BindDistributed(py::module *m) {
py::handle py_in_tensor) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
stack_out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
......@@ -854,16 +854,16 @@ void BindDistributed(py::module *m) {
py::handle py_in_tensor_list) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
stack_out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
stack_in_tensor.impl());
auto in_dense = *p_in_tensor;
// in_tensor_list should not be empty
......@@ -999,9 +999,9 @@ void BindDistributed(py::module *m) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
stack_in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ReduceScatterOptions opts{op};
......@@ -1057,9 +1057,9 @@ void BindDistributed(py::module *m) {
auto in_tensor_list =
CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0);
Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0);
Tensor stack_in_tensor = paddle::stack(in_tensor_list, 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_in_tensor.impl());
stack_in_tensor.impl());
auto in_dense = *p_in_tensor;
distributed::ScatterOptions opts{src};
......
......@@ -911,14 +911,13 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
// 1. calculate axis
int rank = x.at(0)->dims().size();
PADDLE_ENFORCE_EQ(
!rank || (axis >= -rank && axis < rank),
axis >= -rank && axis < rank,
true,
phi::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
axis));
axis = rank ? axis : 0;
if (axis < 0) {
axis = axis + rank;
}
......
......@@ -54,6 +54,10 @@ PD_REGISTER_KERNEL(stack_grad,
phi::StackGradKernel,
float,
double,
bool,
int64_t,
int,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -57,6 +57,10 @@ PD_REGISTER_KERNEL(stack,
phi::StackKernel,
float,
double,
int,
bool,
int64_t,
int,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -21,14 +21,13 @@ namespace funcs {
static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
PADDLE_ENFORCE_EQ(
!rank || (axis >= -rank && axis < rank),
axis >= -rank && axis < rank,
true,
phi::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
axis));
axis = rank ? axis : 0;
if (axis < 0) {
axis = axis + rank;
}
......
......@@ -34,35 +34,6 @@ void ConcatKernel(const Context& dev_ctx,
DenseTensor* out) {
int64_t axis = axis_scalar.to<int64_t>();
if (UNLIKELY(x[0]->dims().size() == 0)) {
// for dims is 0 specially
phi::DDim tmp_1dim, out_dims;
out_dims[0] = x.size();
tmp_1dim[0] = 1;
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
size_t output_offset = 0;
for (auto* in : x) {
if (in->numel() == 0UL) {
continue;
}
auto in_stride = phi::stride_numel(tmp_1dim);
auto out_stride = phi::stride_numel(out->dims());
paddle::operators::StridedNumelCopyWithAxis<T>(
dev_ctx,
axis,
out->data<T>() + output_offset,
out_stride,
in->data<T>(),
in_stride,
in_stride[axis]);
output_offset += in_stride[axis];
}
return;
}
axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size());
std::vector<phi::DDim> x_dims;
......
......@@ -139,7 +139,10 @@ PD_REGISTER_KERNEL(stack_grad,
phi::StackGradKernel,
float,
double,
bool,
int64_t,
int,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -175,7 +175,10 @@ PD_REGISTER_KERNEL(stack,
phi::StackKernel,
float,
double,
bool,
int64_t,
int,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.