Unverified commit f3393f49, authored by 傅剑寒, committed by GitHub

add index_put api (#52886)

* add index_put api

* fix value broadcast in backward and add test case in static

* add timeout=120s for index_put

* add op_compat for index_put

* add inplace index_put test

* add test case where the index tensors in indices are int32 and indices.size is less than x.dims

* add index_put api backward in cpu place

* add backward test case

* refactor code to delete some duplicated code

* replace reshape with resize to reduce extra memcpy

* add datatype flag in backward yaml

* fix bug in documentation

* Update python/paddle/tensor/manipulation.py

---------
Co-authored-by: Ligoml <39876205+Ligoml@users.noreply.github.com>
Parent 65a3a584
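For orientation, here is a minimal usage sketch of the API added by this commit (assuming a Paddle build that already contains it); ``accumulate`` switches between assigning and adding at the indexed positions:

import paddle

x = paddle.zeros([3, 3])
value = paddle.ones([3])
ix1 = paddle.to_tensor([0, 1, 2])
ix2 = paddle.to_tensor([1, 2, 1])

# out-of-place: equivalent to x[ix1, ix2] = value, x itself is untouched
out = paddle.index_put(x, (ix1, ix2), value)

# accumulate=True adds value at the indexed positions instead of assigning
out_acc = paddle.index_put(x, (ix1, ix2), value, accumulate=True)

# in-place variant: modifies x and returns it
paddle.index_put_(x, (ix1, ix2), value)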
@@ -796,6 +796,17 @@
data_type : out_grad
inplace : (out_grad -> x_grad)
- backward_op : index_put_grad
forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
output : Tensor(x_grad), Tensor(value_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, value]
kernel :
func : index_put_grad
data_type : out_grad
- backward_op : index_sample_grad
forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad)
@@ -940,6 +940,17 @@
inplace : (x -> out)
backward : index_add_grad
- op : index_put
args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)
output : Tensor(out)
infer_meta :
func : IndexPutInferMeta
kernel :
func : index_put
data_type : x
inplace : (x -> out)
backward : index_put_grad
- op : index_sample
args : (Tensor x, Tensor index)
output : Tensor
@@ -1962,6 +1962,21 @@ void InterpolateInferMeta(
}
}
void IndexPutInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
bool accumulate,
MetaTensor* out) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
out->share_meta(x);
}
void LambInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
@@ -3295,6 +3310,5 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
@@ -332,6 +332,12 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());
void IndexPutInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& indices,
const MetaTensor& value,
bool accumulate,
MetaTensor* out);
void LambInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_grad_kernel.h"
#include <numeric>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
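// When accumulate == false the forward op overwrites x at the gathered
// positions, so those positions contribute nothing to x_grad; this helper
// zeroes them in a copy of out_grad.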
template <typename T>
void set_zero_kernel(const int64_t N,
const int64_t** indices,
const phi::DDim& stride,
const phi::DDim& shape,
T* out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t idx = 0; idx < N; ++idx) {
int64_t cur_ix = 0;
int64_t offset = 0;
for (int i = 0; i < shape.size(); ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
*(out + offset) = 0;
}
}
template <typename T>
void index_put_grad_kernel(const int64_t N,
const T* out_grad,
const int64_t** indices,
const phi::DDim& stride,
const phi::DDim& shape,
T* value_grad) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t idx = 0; idx < N; ++idx) {
int64_t cur_ix = 0;
int64_t offset = 0;
for (int i = 0; i < shape.size(); ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
*(value_grad + idx) = *(out_grad + offset);
}
}
template <typename T, typename Context>
void LaunchIndexPutGradKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* value_grad,
DenseTensor* x_grad) {
const int64_t* pd_indices[7];
for (size_t i = 0; i < indices.size(); ++i) {
pd_indices[i] = indices[i]->data<int64_t>();
}
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (!accumulate) {
T* x_grad_data = x_grad->data<T>();
auto x_grad_dims = x_grad->dims();
const int64_t numel = indices[0]->numel();
auto x_grad_stride = phi::stride(x_grad_dims);
set_zero_kernel<T>(
numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data);
}
}
auto out_grad_dims = out_grad.dims();
const int64_t numel = indices[0]->numel();
auto out_grad_stride = phi::stride(out_grad_dims);
if (value_grad) {
if (value_grad->numel() == 1) {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(indices[0]->dims());
T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
auto out_grad_data = out_grad.data<T>();
index_put_grad_kernel<T>(numel,
out_grad_data,
pd_indices,
out_grad_stride,
out_grad_dims,
tmp_value_grad_data);
std::vector<int> v_dims(tmp_value_grad.dims().size());
std::iota(v_dims.begin(), v_dims.end(), 0);
IntArray v_axis(v_dims);
SumKernel<T>(dev_ctx,
tmp_value_grad,
v_axis,
value_grad->dtype(),
false,
value_grad);
} else if (value_grad->numel() == indices[0]->numel()) {
T* value_grad_data = dev_ctx.template Alloc<T>(value_grad);
auto out_grad_data = out_grad.data<T>();
index_put_grad_kernel<T>(numel,
out_grad_data,
pd_indices,
out_grad_stride,
out_grad_dims,
value_grad_data);
} else {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(indices[0]->dims());
T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
auto out_grad_data = out_grad.data<T>();
index_put_grad_kernel<T>(numel,
out_grad_data,
pd_indices,
out_grad_stride,
out_grad_dims,
tmp_value_grad_data);
std::vector<int64_t> after_dims = phi::vectorize(tmp_value_grad.dims());
std::vector<int64_t> before_dims = phi::vectorize(value_grad->dims());
std::vector<int64_t> compress_dims;
std::vector<int64_t> dims_without_1;
funcs::CalCompressedDimsWith1AndWithout1(
&after_dims, &before_dims, &compress_dims, &dims_without_1);
auto pre_dims = value_grad->dims();
value_grad->Resize(phi::make_ddim(dims_without_1));
IntArray v_axis(compress_dims);
SumKernel<T>(dev_ctx,
tmp_value_grad,
v_axis,
value_grad->dtype(),
false,
value_grad);
value_grad->Resize(pre_dims);
}
}
}
template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"of tensor x."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> range_tensor_v;
for (int i = indices.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
funcs::DealWithIndices<T, Context>(dev_ctx,
x,
int_indices_v,
&res_indices_v,
&tmp_res_indices_v,
range_tensor_v,
bd_dim,
&res_dim_v);
LaunchIndexPutGradKernel<T, Context>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
}
} // namespace phi
PD_REGISTER_KERNEL(index_put_grad,
CPU,
ALL_LAYOUT,
phi::IndexPutGradKernel,
float,
double,
int,
int64_t,
bool) {}
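The gradient rule implemented by the CPU kernel above can be summarized with a small NumPy sketch (illustrative only; it covers the same-shape and scalar-value cases, while the general broadcast case additionally sums over the broadcast axes):

import numpy as np

def index_put_grad_ref(out_grad, indices, value_shape, accumulate=False):
    # Gradient w.r.t. x: the upstream gradient, with the indexed positions
    # zeroed when the forward op overwrote them (accumulate == False).
    x_grad = out_grad.copy()
    if not accumulate:
        x_grad[indices] = 0
    # Gradient w.r.t. value: the upstream gradient gathered at the indexed
    # positions; a single-element value receives the sum of all gathered entries.
    gathered = out_grad[indices]
    if np.prod(value_shape) == 1:
        value_grad = gathered.sum().reshape(value_shape)
    else:
        value_grad = gathered.reshape(value_shape)
    return x_grad, value_grad

# Example: a 3x3 input indexed at (0,1), (1,2), (2,1).
out_grad = np.arange(9.0).reshape(3, 3)
ix = (np.array([0, 1, 2]), np.array([1, 2, 1]))
x_grad, value_grad = index_put_grad_ref(out_grad, ix, (3,))
# x_grad has zeros at the three indexed positions; value_grad == [1., 5., 7.]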
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
namespace phi {
template <typename T>
void index_put_kernel(const int64_t N,
const T* x,
const T* vals,
const int64_t** indices,
const phi::DDim& stride,
const phi::DDim& shape,
int64_t is_single_val_tensor,
bool accumulate,
T* out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t idx = 0; idx < N; ++idx) {
int64_t cur_ix = 0;
int64_t offset = 0;
for (int i = 0; i < shape.size(); ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
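// When value has a single element, is_single_val_tensor is 0, so
// idx & 0 == 0 and every position reads vals[0]; otherwise it is
// INT64_MAX and idx & INT64_MAX == idx.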
if (accumulate) {
*(out + offset) += *(vals + (idx & is_single_val_tensor));
} else {
*(out + offset) = *(vals + (idx & is_single_val_tensor));
}
}
}
template <typename T, typename Context>
void LaunchIndexPutKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
bool accumulate,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto* val_data = value.data<T>();
bool is_initialized = out->initialized();
T* out_data = dev_ctx.template Alloc<T>(out);
if (!is_initialized) {
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
auto x_dims = x.dims();
const int64_t numel = indices[0]->numel();
auto x_stride = phi::stride(x_dims);
int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX;
const int64_t* pd_indices[7];
for (size_t i = 0; i < indices.size(); ++i) {
pd_indices[i] = indices[i]->data<int64_t>();
}
index_put_kernel<T>(numel,
x_data,
val_data,
pd_indices,
x_stride,
x_dims,
is_single_val_tensor,
accumulate,
out_data);
}
template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
bool accumulate,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"of tensor x."));
PADDLE_ENFORCE_EQ(indices.empty(),
false,
phi::errors::InvalidArgument("Indices cannot be empty."));
const size_t total_dims = x.dims().size();
PADDLE_ENFORCE_LE(total_dims,
6,
phi::errors::InvalidArgument(
"Dims of input tensor should be less than 7."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> tmp_value_v;
std::vector<DenseTensor> range_tensor_v;
const DenseTensor* ptr_value = nullptr;
for (int i = indices.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
funcs::DealWithIndices<T, Context>(dev_ctx,
x,
int_indices_v,
&res_indices_v,
&tmp_res_indices_v,
range_tensor_v,
bd_dim,
&res_dim_v);
if (value.numel() != 1) {
tmp_value_v.emplace_back(
DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v)));
ExpandKernel<T, Context>(
dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]);
ptr_value = &tmp_value_v[0];
} else {
ptr_value = &value;
}
LaunchIndexPutKernel<T, Context>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
}
} // namespace phi
PD_REGISTER_KERNEL(index_put,
CPU,
ALL_LAYOUT,
phi::IndexPutKernel,
float,
double,
int,
int64_t,
bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/nonzero_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include <cuda.h>
#include <cuda_runtime.h>
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#endif
#endif
namespace phi {
namespace funcs {
template <typename T, typename Context>
phi::DenseTensor GetReshapeAndExpandTensor(const Context& dev_ctx,
const phi::DenseTensor& tensor,
const phi::DDim& res_dim,
const phi::DDim& bd_dim,
int index) {
std::vector<int64_t> before_dims = phi::vectorize(tensor.dims());
std::vector<int64_t> mid_dims(res_dim.size(), 1);
if (index == 0) {
for (size_t i = 0; i < before_dims.size(); ++i) {
mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1];
}
} else {
mid_dims[index] = before_dims[0];
}
phi::DenseTensor mid_tensor(tensor.dtype());
mid_tensor.Resize(phi::make_ddim(mid_dims));
ReshapeInferKernel<Context>(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor);
phi::DenseTensor res_tensor(tensor.dtype());
res_tensor.Resize(res_dim);
ExpandKernel<T, Context>(
dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor);
return res_tensor;
}
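// If indices holds a single boolean mask, convert it to per-axis int64 index
// tensors: NonZero yields the coordinates of the true elements and
// SplitWithNum separates them into one index tensor per axis.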
template <typename T, typename Context>
std::vector<const phi::DenseTensor*> DealWithBoolIndices(
const Context& dev_ctx,
const std::vector<const phi::DenseTensor*>& indices_v,
std::vector<phi::DenseTensor>* tmp_indices_v) {
std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end());
bool contains_bool_tensor = false;
for (size_t i = 0; i < indices_v.size(); ++i) {
if (indices_v[i]->dtype() == phi::DataType::BOOL) {
contains_bool_tensor = true;
} else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
(indices_v[i]->dtype() == phi::DataType::INT32)) {
PADDLE_ENFORCE_EQ(
contains_bool_tensor,
false,
phi::errors::InvalidArgument(
"indices contains bool tensor and int32/int64 tensor at the same "
"time"));
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"data type of tensor in indices must be int32, int64 or bool"));
}
}
if (contains_bool_tensor) {
if (indices_v.size() != 1) {
PADDLE_THROW(phi::errors::InvalidArgument(
"the size of indices must be 1 when it containts bool tensor"));
}
int rank = indices_v[0]->dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument("the only bool tensor in indices should "
"have number of dimension at least 1"));
phi::DenseTensor nonzero_indices(phi::DataType::INT64);
nonzero_indices.Resize(phi::make_ddim({-1, rank}));
NonZeroKernel<bool, Context>(dev_ctx, *indices_v[0], &nonzero_indices);
std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
for (int i = 0; i < rank; ++i) {
tmp_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64)
.Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
}
for (int i = 0; i < rank; ++i) {
integer_indices[i] = &((*tmp_indices_v)[i]);
}
SplitWithNumKernel<int64_t, Context>(
dev_ctx, nonzero_indices, rank, 1, integer_indices);
std::vector<const phi::DenseTensor*> res_tmp(integer_indices.size(),
nullptr);
for (int i = 0; i < rank; ++i) {
res_tmp[i] = &((*tmp_indices_v)[i]);
}
res.swap(res_tmp);
}
return res;
}
static phi::DDim BroadCastTensorsDims(
const std::vector<const phi::DenseTensor*>& tensors) {
int target_rank = 0;
for (const auto& tensor : tensors) {
target_rank = std::max(target_rank, tensor->dims().size());
}
PADDLE_ENFORCE_GT(target_rank,
0,
errors::InvalidArgument("BroadCastTensorsDims requires at "
"least one input tensor to have "
"rank greater than zero"));
std::vector<int64_t> target_dims(target_rank, 0);
for (int index = 0; index < target_rank; index++) {
int target_dim_size = 1;
for (const auto& tensor : tensors) {
auto input_ddim = tensor->dims();
int axis = static_cast<int>(input_ddim.size()) - index - 1;
int dim_size = 1;
if (axis >= 0) {
dim_size = input_ddim[axis];
}
if (target_dim_size != 1 && dim_size != 1 &&
target_dim_size != dim_size) {
PADDLE_THROW(errors::InvalidArgument(
"BroadCastTensorsDims inputs does not satisfy bcast semantics, "
"please check axis = %d in reverse order",
index));
}
target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
}
target_dims[target_rank - index - 1] = target_dim_size;
}
return phi::make_ddim(target_dims);
}
template <typename T, typename Context>
T** GetDevicePointerArray(const Context& ctx,
const std::vector<const DenseTensor*>& indices_v) {
std::vector<const T*> h_indices_v(indices_v.size());
for (int i = 0; i < indices_v.size(); ++i) {
h_indices_v[i] = indices_v[i]->data<T>();
}
auto d_indices_data = phi::memory_utils::Alloc(
ctx.GetPlace(),
h_indices_v.size() * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
phi::memory_utils::Copy(ctx.GetPlace(),
d_indices_data->ptr(),
phi::CPUPlace(),
reinterpret_cast<void*>(h_indices_v.data()),
h_indices_v.size() * sizeof(T*),
ctx.stream());
return reinterpret_cast<T**>(d_indices_data->ptr());
}
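// Normalizes the index tensors for the element-wise kernels: casts int32
// indices to int64, appends range tensors for any trailing axes of x that
// were not indexed, and reshapes/expands everything to a common broadcast
// shape.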
template <typename T, typename Context>
void DealWithIndices(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const phi::DenseTensor*>& int_indices_v,
std::vector<const phi::DenseTensor*>* res_indices_v,
std::vector<DenseTensor>* tmp_res_indices_v,
const std::vector<DenseTensor>& range_tensor_v,
const phi::DDim& bd_dim,
std::vector<int64_t>* res_dim_v) {
size_t total_dims = x.dims().size();
if (int_indices_v.size() < total_dims) {
std::vector<int64_t> tmp_x_dims = phi::vectorize(x.dims());
int len_bd_dim = bd_dim.size();
res_dim_v->insert(res_dim_v->end(),
tmp_x_dims.begin() + int_indices_v.size(),
tmp_x_dims.end());
std::vector<DenseTensor> reshaped_indices_v;
for (size_t i = 0; i < int_indices_v.size(); ++i) {
if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
reshaped_indices_v.emplace_back(phi::Cast<int, Context>(
dev_ctx, *int_indices_v[i], phi::DataType::INT64));
} else {
reshaped_indices_v.emplace_back(*int_indices_v[i]);
}
}
reshaped_indices_v.insert(
reshaped_indices_v.end(), range_tensor_v.begin(), range_tensor_v.end());
phi::DDim res_dim = phi::make_ddim(*res_dim_v);
for (size_t i = 0; i < reshaped_indices_v.size(); ++i) {
tmp_res_indices_v->emplace_back(
GetReshapeAndExpandTensor<int64_t, Context>(
dev_ctx,
reshaped_indices_v[i],
res_dim,
bd_dim,
((i < int_indices_v.size())
? 0
: i - int_indices_v.size() + len_bd_dim)));
}
for (size_t i = 0; i < res_indices_v->size(); ++i) {
(*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
}
} else {
std::vector<DenseTensor> int_indices_v_tmp;
for (size_t i = 0; i < int_indices_v.size(); ++i) {
if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
int_indices_v_tmp.emplace_back(phi::Cast<int, Context>(
dev_ctx, *int_indices_v[i], phi::DataType::INT64));
} else {
int_indices_v_tmp.emplace_back(*int_indices_v[i]);
}
}
for (size_t i = 0; i < int_indices_v.size(); ++i) {
if (bd_dim != int_indices_v[i]->dims()) {
tmp_res_indices_v->emplace_back(
DenseTensor(phi::DataType::INT64).Resize(bd_dim));
ExpandKernel<int64_t, Context>(
dev_ctx,
int_indices_v_tmp[i],
IntArray(phi::vectorize<int64_t>(bd_dim)),
&(*tmp_res_indices_v)[i]);
} else {
tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]);
}
}
for (size_t i = 0; i < res_indices_v->size(); ++i) {
(*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
}
}
}
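// Compares the gathered-gradient shape (after_dims) with the value shape
// (before_dims) from the trailing axis backwards: axes where value had size 1
// (or did not exist) go into compress_dims and must be sum-reduced, while the
// matching sizes are collected in dims_without_1.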
static void CalCompressedDimsWith1AndWithout1(
std::vector<int64_t>* after_dims,
std::vector<int64_t>* before_dims,
std::vector<int64_t>* compress_dims,
std::vector<int64_t>* dims_without_1) {
int i = static_cast<int>(after_dims->size()) - 1;
int j = static_cast<int>(before_dims->size()) - 1;
if (i < j) {
PADDLE_THROW(phi::errors::InvalidArgument(
"shape of value can't not be broadcast to shape of x[indices]"));
}
while ((i >= 0) && (j >= 0)) {
if ((*after_dims)[i] == (*before_dims)[j]) {
dims_without_1->push_back((*before_dims)[j]);
i--;
j--;
continue;
} else if ((*before_dims)[j] == 1) {
compress_dims->push_back(i);
i--;
j--;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"shape of value can't not be broadcast to shape of x[indices]"));
}
}
while (i >= 0) {
compress_dims->push_back(i);
i--;
}
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
__global__ void range_cuda_kernel(int64_t N, T* out) {
int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
if (idx >= N) {
return;
}
out[idx] = idx;
}
template <typename T, typename Context>
phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx,
int64_t N,
phi::DataType dtype) {
phi::DenseTensor res(dtype);
res.Resize(phi::make_ddim({N}));
DenseTensor* p_res = &res;
T* out = dev_ctx.template Alloc<T>(p_res);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N);
range_cuda_kernel<T>
<<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
N, out);
return res;
}
#endif
template <typename T>
void range_kernel(int64_t N, T* out) {
for (int64_t idx = 0; idx < N; ++idx) {
out[idx] = idx;
}
}
template <typename T, typename Context>
phi::DenseTensor GetRangeTensor(const Context& dev_ctx,
int64_t N,
phi::DataType dtype) {
phi::DenseTensor res(dtype);
res.Resize(phi::make_ddim({N}));
DenseTensor* p_res = &res;
T* out = dev_ctx.template Alloc<T>(p_res);
range_kernel<T>(N, out);
return res;
}
} // namespace funcs
} // namespace phi
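A short NumPy illustration of the boolean-mask path in DealWithBoolIndices above (the names here are illustrative, not part of the Paddle API):

import numpy as np

mask = np.array([[True, False],
                 [False, True]])
# NonZero: coordinates of the true elements, one row per hit.
coords = np.argwhere(mask)                           # [[0, 0], [1, 1]]
# SplitWithNum along the last axis: one int64 index array per axis of x.
per_axis = [coords[:, d] for d in range(mask.ndim)]  # [array([0, 1]), array([0, 1])]
# x[mask] selects the same elements as x[per_axis[0], per_axis[1]]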
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_grad_kernel.h"
#include <numeric>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
namespace phi {
template <typename T, size_t Rank>
__global__ void set_zero_cuda_kernel(const int64_t N,
int64_t** indices,
phi::Array<int64_t, Rank> stride,
phi::Array<int64_t, Rank> shape,
T* out) {
int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
int64_t cur_ix = 0;
if (idx >= N) {
return;
}
int64_t offset = 0;
for (int i = 0; i < Rank; ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
*(out + offset) = 0;
}
template <typename T, size_t Rank>
__global__ void index_put_grad_cuda_kernel(const int64_t N,
const T* out_grad,
int64_t** indices,
phi::Array<int64_t, Rank> stride,
phi::Array<int64_t, Rank> shape,
T* value_grad) {
int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
int64_t cur_ix = 0;
if (idx >= N) {
return;
}
int64_t offset = 0;
for (int i = 0; i < Rank; ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
*(value_grad + idx) = *(out_grad + offset);
}
template <typename T, typename Context, size_t Rank>
void LaunchIndexPutGradCudaKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* value_grad,
DenseTensor* x_grad) {
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (!accumulate) {
T* x_grad_data = x_grad->data<T>();
auto x_grad_dims = x_grad->dims();
const int64_t numel = indices[0]->numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
auto x_grad_stride = phi::stride(x_grad_dims);
phi::Array<int64_t, Rank> stride_a;
phi::Array<int64_t, Rank> shape_a;
for (size_t idx = 0; idx < Rank; ++idx) {
stride_a[idx] = x_grad_stride[idx];
shape_a[idx] = x_grad_dims[idx];
}
auto pd_indices =
funcs::GetDevicePointerArray<int64_t, Context>(dev_ctx, indices);
set_zero_cuda_kernel<T, Rank><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
numel, pd_indices, stride_a, shape_a, x_grad_data);
}
}
auto out_grad_dims = out_grad.dims();
const int64_t numel = indices[0]->numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
auto out_grad_stride = phi::stride(out_grad_dims);
phi::Array<int64_t, Rank> stride_a;
phi::Array<int64_t, Rank> shape_a;
for (size_t idx = 0; idx < Rank; ++idx) {
stride_a[idx] = out_grad_stride[idx];
shape_a[idx] = out_grad_dims[idx];
}
auto pd_indices =
funcs::GetDevicePointerArray<int64_t, Context>(dev_ctx, indices);
if (value_grad) {
if (value_grad->numel() == 1) {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(indices[0]->dims());
T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
auto out_grad_data = out_grad.data<T>();
index_put_grad_cuda_kernel<T, Rank>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(numel,
out_grad_data,
pd_indices,
stride_a,
shape_a,
tmp_value_grad_data);
std::vector<int> v_dims(tmp_value_grad.dims().size());
std::iota(v_dims.begin(), v_dims.end(), 0);
IntArray v_axis(v_dims);
SumKernel<T, Context>(dev_ctx,
tmp_value_grad,
v_axis,
value_grad->dtype(),
false,
value_grad);
} else if (value_grad->numel() == indices[0]->numel()) {
T* value_grad_data = dev_ctx.template Alloc<T>(value_grad);
auto out_grad_data = out_grad.data<T>();
index_put_grad_cuda_kernel<T, Rank><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
numel, out_grad_data, pd_indices, stride_a, shape_a, value_grad_data);
} else {
DenseTensor tmp_value_grad(value_grad->dtype());
tmp_value_grad.Resize(indices[0]->dims());
T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
auto out_grad_data = out_grad.data<T>();
index_put_grad_cuda_kernel<T, Rank>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(numel,
out_grad_data,
pd_indices,
stride_a,
shape_a,
tmp_value_grad_data);
std::vector<int64_t> after_dims = phi::vectorize(tmp_value_grad.dims());
std::vector<int64_t> before_dims = phi::vectorize(value_grad->dims());
std::vector<int64_t> compress_dims;
std::vector<int64_t> dims_without_1;
funcs::CalCompressedDimsWith1AndWithout1(
&after_dims, &before_dims, &compress_dims, &dims_without_1);
auto pre_dims = value_grad->dims();
value_grad->Resize(phi::make_ddim(dims_without_1));
IntArray v_axis(compress_dims);
SumKernel<T, Context>(dev_ctx,
tmp_value_grad,
v_axis,
value_grad->dtype(),
false,
value_grad);
value_grad->Resize(pre_dims);
}
}
}
template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* x_grad,
DenseTensor* value_grad) {
PADDLE_ENFORCE_EQ(
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"of tensor x."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
const size_t total_dims = x.dims().size();
auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> range_tensor_v;
for (int i = indices.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
funcs::DealWithIndices<T, Context>(dev_ctx,
x,
int_indices_v,
&res_indices_v,
&tmp_res_indices_v,
range_tensor_v,
bd_dim,
&res_dim_v);
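// The CUDA kernels carry stride/shape in fixed-size phi::Array<int64_t, Rank>
// arguments, so the launcher is templated on rank and dispatched here for
// ranks 1 through 6.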
switch (total_dims) {
case 1:
LaunchIndexPutGradCudaKernel<T, Context, 1>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
break;
case 2:
LaunchIndexPutGradCudaKernel<T, Context, 2>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
break;
case 3:
LaunchIndexPutGradCudaKernel<T, Context, 3>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
break;
case 4:
LaunchIndexPutGradCudaKernel<T, Context, 4>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
break;
case 5:
LaunchIndexPutGradCudaKernel<T, Context, 5>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
break;
case 6:
LaunchIndexPutGradCudaKernel<T, Context, 6>(
dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"dims of input tensor should be less than 7, But received"
"%d",
x.dims().size()));
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_put_grad,
GPU,
ALL_LAYOUT,
phi::IndexPutGradKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
namespace phi {
template <typename T, size_t Rank>
__global__ void index_put_cuda_kernel(const int64_t N,
const T* x,
const T* vals,
int64_t** indices,
phi::Array<int64_t, Rank> stride,
phi::Array<int64_t, Rank> shape,
int64_t is_single_val_tensor,
bool accumulate,
T* out) {
int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
int64_t cur_ix = 0;
if (idx >= N) {
return;
}
int64_t offset = 0;
for (int i = 0; i < Rank; ++i) {
cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
if (cur_ix < 0) {
cur_ix += shape[i];
}
offset += stride[i] * cur_ix;
}
if (accumulate) {
*(out + offset) += *(vals + (idx & is_single_val_tensor));
} else {
*(out + offset) = *(vals + (idx & is_single_val_tensor));
}
}
template <typename T, typename Context, size_t Rank>
void LaunchIndexPutCudaKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
bool accumulate,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto* val_data = value.data<T>();
bool is_initialized = out->initialized();
T* out_data = dev_ctx.template Alloc<T>(out);
if (!is_initialized) {
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
auto x_dims = x.dims();
const int64_t numel = indices[0]->numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
auto x_stride = phi::stride(x_dims);
phi::Array<int64_t, Rank> stride_a;
phi::Array<int64_t, Rank> shape_a;
for (size_t idx = 0; idx < Rank; ++idx) {
stride_a[idx] = x_stride[idx];
shape_a[idx] = x_dims[idx];
}
int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX;
auto pd_indices =
funcs::GetDevicePointerArray<int64_t, Context>(dev_ctx, indices);
index_put_cuda_kernel<T, Rank>
<<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
numel,
x_data,
val_data,
pd_indices,
stride_a,
shape_a,
is_single_val_tensor,
accumulate,
out_data);
}
template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices,
const DenseTensor& value,
bool accumulate,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"of tensor x."));
PADDLE_ENFORCE_EQ(indices.empty(),
false,
phi::errors::InvalidArgument("Indices cannot be empty."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
const size_t total_dims = x.dims().size();
auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);
std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> tmp_value_v;
std::vector<DenseTensor> range_tensor_v;
const DenseTensor* ptr_value = nullptr;
for (int i = indices.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
funcs::DealWithIndices<T, Context>(dev_ctx,
x,
int_indices_v,
&res_indices_v,
&tmp_res_indices_v,
range_tensor_v,
bd_dim,
&res_dim_v);
if (value.numel() != 1) {
tmp_value_v.emplace_back(
DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v)));
ExpandKernel<T, Context>(
dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]);
ptr_value = &tmp_value_v[0];
} else {
ptr_value = &value;
}
switch (total_dims) {
case 1:
LaunchIndexPutCudaKernel<T, Context, 1>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
break;
case 2:
LaunchIndexPutCudaKernel<T, Context, 2>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
break;
case 3:
LaunchIndexPutCudaKernel<T, Context, 3>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
break;
case 4:
LaunchIndexPutCudaKernel<T, Context, 4>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
break;
case 5:
LaunchIndexPutCudaKernel<T, Context, 5>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
break;
case 6:
LaunchIndexPutCudaKernel<T, Context, 6>(
dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"dims of input tensor should be less than 7, But received"
"%d",
x.dims().size()));
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_put,
GPU,
ALL_LAYOUT,
phi::IndexPutKernel,
float,
double,
int,
int64_t,
bool,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices_v,
const DenseTensor& value,
const DenseTensor& out_grad,
bool accumulate,
DenseTensor* x_grad,
DenseTensor* value_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& indices_v,
const DenseTensor& value,
bool accumulate,
DenseTensor* out);
} // namespace phi
@@ -198,6 +198,8 @@ from .tensor.manipulation import moveaxis # noqa: F401
from .tensor.manipulation import repeat_interleave # noqa: F401
from .tensor.manipulation import index_add # noqa: F401
from .tensor.manipulation import index_add_ # noqa: F401
from .tensor.manipulation import index_put # noqa: F401
from .tensor.manipulation import index_put_ # noqa: F401
from .tensor.manipulation import unflatten # noqa: F401
from .tensor.math import abs # noqa: F401
from .tensor.math import acos # noqa: F401
@@ -684,6 +686,8 @@ __all__ = [ # noqa
'tril_indices',
'index_add',
"index_add_",
"index_put",
"index_put_",
'sgn',
'triu_indices',
'take',
@@ -916,6 +916,7 @@ set_tests_properties(test_imperative_selected_rows_to_lod_tensor
PROPERTIES TIMEOUT 200)
set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_add_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_tensordot PROPERTIES TIMEOUT 200)
set_tests_properties(test_partial_eager_deletion_transformer PROPERTIES TIMEOUT
120)
@@ -135,6 +135,8 @@ from .manipulation import moveaxis # noqa: F401
from .manipulation import repeat_interleave # noqa: F401
from .manipulation import index_add # noqa: F401
from .manipulation import index_add_ # noqa: F401
from .manipulation import index_put # noqa: F401
from .manipulation import index_put_ # noqa: F401
from .manipulation import unflatten # noqa: F401
from .math import abs # noqa: F401
from .math import acos # noqa: F401
@@ -534,6 +536,8 @@ tensor_method_func = [ # noqa
'heaviside',
'index_add',
"index_add_",
'index_put',
'index_put_',
'take',
'bucketize',
'sgn',
@@ -4795,6 +4795,108 @@ def index_add_(x, index, axis, value, name=None):
return _C_ops.index_add_(x, index, value, axis)
@inplace_apis_in_dygraph_only
def index_put_(x, indices, value, accumulate=False, name=None):
"""
Puts values from the tensor ``value`` into the tensor ``x`` using the indices specified in ``indices`` (which is a tuple of Tensors).
The expression ``paddle.index_put_(x, indices, value)`` is equivalent to ``x[indices] = value``. Returns ``x``.
If ``accumulate`` is True, the elements in ``value`` are added to ``x``. If ``accumulate`` is False, the behavior is undefined if ``indices`` contains duplicate elements.
Args:
x (Tensor): The source Tensor. Supported data types are int32, int64, float16, float32, float64, bool.
indices (Tuple of Tensor): The tuple of Tensors containing the indices to index. The data type of the tensors in ``indices`` must be int32, int64 or bool.
value (Tensor): The tensor whose elements are assigned to ``x``.
accumulate (bool, optional): Whether the elements in ``value`` are added to ``x``. Default: False.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Tensor, with the same dimension and dtype as ``x``.
Examples:
.. code-block:: python
import paddle
x = paddle.zeros([3, 3])
value = paddle.ones([3])
ix1 = paddle.to_tensor([0,1,2])
ix2 = paddle.to_tensor([1,2,1])
indices=(ix1,ix2)
out = paddle.index_put_(x,indices,value)
print(x)
# Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0., 1., 0.],
# [0., 0., 1.],
# [0., 1., 0.]])
print(out)
# Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0., 1., 0.],
# [0., 0., 1.],
# [0., 1., 0.]])
"""
return _C_ops.index_put_(x, indices, value, accumulate)
def index_put(x, indices, value, accumulate=False, name=None):
"""
Outplace version of ``index_put_`` API; the input ``x`` will not be modified in place.
Please refer to :ref:`api_paddle_index_put`.
Examples:
.. code-block:: python
import paddle
x = paddle.zeros([3, 3])
value = paddle.ones([3])
ix1 = paddle.to_tensor([0,1,2])
ix2 = paddle.to_tensor([1,2,1])
indices=(ix1,ix2)
out = paddle.index_put(x,indices,value)
print(x)
# Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]])
print(out)
# Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[0., 1., 0.],
# [0., 0., 1.],
# [0., 1., 0.]])
"""
if in_dygraph_mode():
return _C_ops.index_put(x, indices, value, accumulate)
helper = LayerHelper("index_put", **locals())
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'paddle.tensor.manipulation.index_put',
)
check_variable_and_dtype(
value,
'value',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'paddle.tensor.manipulation.index_put',
)
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='index_put',
inputs={
'x': x,
'indices': indices,
'value': value,
},
outputs={'out': out},
attrs={'accumulate': accumulate},
)
return out
def unflatten(x, axis, shape, name=None):
"""
Expand a certain dimension of the input x Tensor into a desired shape.
@@ -4840,7 +4942,6 @@ def unflatten(x, axis, shape, name=None):
# determine whether the input axis is valid.
axis = non_negative_axis(x, axis)
if isinstance(shape, (list, tuple)):
new_shape = (
list(x.shape[:axis]) + list(shape) + list(x.shape[axis + 1 :])