Commit 4d16cd63 (unverified)
Authored May 10, 2023 by 傅剑寒; committed by GitHub on May 10, 2023.
[cherry pick] add index_put api (#53652)
This PR adds the index_put API for Paddle.
Parent commit: 1ab562ca
Showing 16 changed files with 2,262 additions and 0 deletions (+2262 / -0).
paddle/phi/api/yaml/backward.yaml                           +11   -0
paddle/phi/api/yaml/ops.yaml                                +10   -0
paddle/phi/infermeta/multiary.cc                            +15   -0
paddle/phi/infermeta/multiary.h                              +6   -0
paddle/phi/kernels/cpu/index_put_grad_kernel.cc            +225   -0
paddle/phi/kernels/cpu/index_put_kernel.cc                 +166   -0
paddle/phi/kernels/funcs/index_put_utils.h                 +348   -0
paddle/phi/kernels/gpu/index_put_grad_kernel.cu            +287   -0
paddle/phi/kernels/gpu/index_put_kernel.cu                 +198   -0
paddle/phi/kernels/index_put_grad_kernel.h                  +30   -0
paddle/phi/kernels/index_put_kernel.h                       +29   -0
python/paddle/__init__.py                                    +4   -0
python/paddle/fluid/tests/unittests/CMakeLists.txt           +1   -0
python/paddle/fluid/tests/unittests/test_index_put_op.py   +826   -0
python/paddle/tensor/__init__.py                             +4   -0
python/paddle/tensor/manipulation.py                       +102   -0
paddle/phi/api/yaml/backward.yaml

@@ -796,6 +796,17 @@
    data_type : out_grad
  inplace : (out_grad -> x_grad)

- backward_op : index_put_grad
  forward : index_put (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false) -> Tensor(out)
  args : (Tensor x, Tensor[] indices, Tensor value, Tensor out_grad, bool accumulate=false)
  output : Tensor(x_grad), Tensor(value_grad)
  infer_meta :
    func : GeneralBinaryGradInferMeta
    param : [x, value]
  kernel :
    func : index_put_grad
    data_type : out_grad

- backward_op : index_sample_grad
  forward : index_sample (Tensor x, Tensor index) -> Tensor(out)
  args : (Tensor x, Tensor index, Tensor out_grad)
...
paddle/phi/api/yaml/ops.yaml

@@ -870,6 +870,16 @@
  inplace : (x -> out)
  backward : index_add_grad

- op : index_put
  args : (Tensor x, Tensor[] indices, Tensor value, bool accumulate=false)
  output : Tensor(out)
  infer_meta :
    func : IndexPutInferMeta
  kernel :
    func : index_put
  inplace : (x -> out)
  backward : index_put_grad

- op : index_sample
  args : (Tensor x, Tensor index)
  output : Tensor
...
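The ops.yaml entry above registers index_put with an inplace mapping (x -> out) and ties it to index_put_grad; at the Python level this surfaces as paddle.index_put and paddle.index_put_. A minimal sketch of how the two entry points are called, based on the docstring added later in this commit:

import paddle

x = paddle.zeros([3, 3])
value = paddle.ones([3])
indices = (paddle.to_tensor([0, 1, 2]), paddle.to_tensor([1, 2, 1]))

# Out-of-place op: x is left untouched, out holds the scattered values.
out = paddle.index_put(x, indices, value)
# In-place variant: x itself is modified, matching the inplace (x -> out) entry.
paddle.index_put_(x, indices, value)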
paddle/phi/infermeta/multiary.cc

@@ -1961,6 +1961,21 @@ void InterpolateInferMeta(
  }
}

void IndexPutInferMeta(const MetaTensor& x,
                       const std::vector<const MetaTensor*>& indices,
                       const MetaTensor& value,
                       bool accumulate,
                       MetaTensor* out) {
  auto in_dims = x.dims();
  PADDLE_ENFORCE_LT(
      in_dims.size(),
      7,
      phi::errors::InvalidArgument(
          "The rank of input should be less than 7, but received %d.",
          in_dims.size()));
  out->share_meta(x);
}

void LambInferMeta(const MetaTensor& param,
                   const MetaTensor& grad,
                   const MetaTensor& learning_rate,
...
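IndexPutInferMeta only checks that the input rank is below 7 and then makes out share x's meta, so the op's output always has the shape and dtype of x regardless of how many positions are written. A rough NumPy analogue of that shape rule (a sketch, not Paddle code):

import numpy as np

x = np.zeros((3, 3), dtype=np.float32)
out = x.copy()
# Writing through a tuple of index arrays never changes the container's shape.
out[(np.array([0, 1, 2]), np.array([1, 2, 1]))] = np.ones(3, dtype=np.float32)
assert out.shape == x.shape and out.dtype == x.dtype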
paddle/phi/infermeta/multiary.h

@@ -333,6 +333,12 @@ void InterpolateInferMeta(
    MetaTensor* output,
    MetaConfig config = MetaConfig());

void IndexPutInferMeta(const MetaTensor& x,
                       const std::vector<const MetaTensor*>& indices,
                       const MetaTensor& value,
                       bool accumulate,
                       MetaTensor* out);

void LambInferMeta(const MetaTensor& param,
                   const MetaTensor& grad,
                   const MetaTensor& learning_rate,
...
paddle/phi/kernels/cpu/index_put_grad_kernel.cc (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_grad_kernel.h"
#include <numeric>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

template <typename T>
void set_zero_kernel(const int64_t N,
                     const int64_t** indices,
                     const phi::DDim& stride,
                     const phi::DDim& shape,
                     T* out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int64_t idx = 0; idx < N; ++idx) {
    int64_t cur_ix = 0;
    int64_t offset = 0;

    for (int i = 0; i < shape.size(); ++i) {
      cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
      if (cur_ix < 0) {
        cur_ix += shape[i];
      }
      offset += stride[i] * cur_ix;
    }
    *(out + offset) = 0;
  }
}

template <typename T>
void index_put_grad_kernel(const int64_t N,
                           const T* out_grad,
                           const int64_t** indices,
                           const phi::DDim& stride,
                           const phi::DDim& shape,
                           T* value_grad) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int64_t idx = 0; idx < N; ++idx) {
    int64_t cur_ix = 0;
    int64_t offset = 0;

    for (int i = 0; i < shape.size(); ++i) {
      cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
      if (cur_ix < 0) {
        cur_ix += shape[i];
      }
      offset += stride[i] * cur_ix;
    }
    *(value_grad + idx) = *(out_grad + offset);
  }
}

template <typename T, typename Context>
void LaunchIndexPutGradKernel(const Context& dev_ctx,
                              const std::vector<const DenseTensor*>& indices,
                              const DenseTensor& out_grad,
                              bool accumulate,
                              DenseTensor* value_grad,
                              DenseTensor* x_grad) {
  const int64_t* pd_indices[7];
  for (size_t i = 0; i < indices.size(); ++i) {
    pd_indices[i] = indices[i]->data<int64_t>();
  }

  if (x_grad) {
    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
    if (!accumulate) {
      T* x_grad_data = x_grad->data<T>();
      auto x_grad_dims = x_grad->dims();
      const int64_t numel = indices[0]->numel();
      auto x_grad_stride = phi::stride(x_grad_dims);

      set_zero_kernel<T>(
          numel, pd_indices, x_grad_stride, x_grad_dims, x_grad_data);
    }
  }

  auto out_grad_dims = out_grad.dims();
  const int64_t numel = indices[0]->numel();
  auto out_grad_stride = phi::stride(out_grad_dims);

  if (value_grad) {
    if (value_grad->numel() == 1) {
      DenseTensor tmp_value_grad(value_grad->dtype());
      tmp_value_grad.Resize(indices[0]->dims());

      T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
      auto out_grad_data = out_grad.data<T>();

      index_put_grad_kernel<T>(numel,
                               out_grad_data,
                               pd_indices,
                               out_grad_stride,
                               out_grad_dims,
                               tmp_value_grad_data);

      std::vector<int> v_dims(tmp_value_grad.dims().size());
      std::iota(v_dims.begin(), v_dims.end(), 0);
      IntArray v_axis(v_dims);
      SumKernel<T>(dev_ctx,
                   tmp_value_grad,
                   v_axis,
                   value_grad->dtype(),
                   false,
                   value_grad);
    } else if (value_grad->numel() == indices[0]->numel()) {
      T* value_grad_data = dev_ctx.template Alloc<T>(value_grad);
      auto out_grad_data = out_grad.data<T>();

      index_put_grad_kernel<T>(numel,
                               out_grad_data,
                               pd_indices,
                               out_grad_stride,
                               out_grad_dims,
                               value_grad_data);
    } else {
      DenseTensor tmp_value_grad(value_grad->dtype());
      tmp_value_grad.Resize(indices[0]->dims());

      T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
      auto out_grad_data = out_grad.data<T>();

      index_put_grad_kernel<T>(numel,
                               out_grad_data,
                               pd_indices,
                               out_grad_stride,
                               out_grad_dims,
                               tmp_value_grad_data);

      std::vector<int64_t> after_dims = phi::vectorize(tmp_value_grad.dims());
      std::vector<int64_t> before_dims = phi::vectorize(value_grad->dims());
      std::vector<int64_t> compress_dims;
      std::vector<int64_t> dims_without_1;

      funcs::CalCompressedDimsWith1AndWithout1(
          &after_dims, &before_dims, &compress_dims, &dims_without_1);

      auto pre_dims = value_grad->dims();
      value_grad->Resize(phi::make_ddim(dims_without_1));
      IntArray v_axis(compress_dims);
      SumKernel<T>(dev_ctx,
                   tmp_value_grad,
                   v_axis,
                   value_grad->dtype(),
                   false,
                   value_grad);
      value_grad->Resize(pre_dims);
    }
  }
}

template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const std::vector<const DenseTensor*>& indices,
                        const DenseTensor& value,
                        const DenseTensor& out_grad,
                        bool accumulate,
                        DenseTensor* x_grad,
                        DenseTensor* value_grad) {
  PADDLE_ENFORCE_EQ(
      x.dtype(),
      value.dtype(),
      phi::errors::InvalidArgument(
          "The data type of tensor in indices must be same to the data type "
          "of tensor x."));
  std::vector<DenseTensor> tmp_args;
  std::vector<const phi::DenseTensor*> int_indices_v =
      funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
  auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);

  std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
  std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
  std::vector<DenseTensor> tmp_res_indices_v;
  std::vector<DenseTensor> range_tensor_v;

  for (int i = indices.size(); i < x.dims().size(); ++i) {
    range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
        dev_ctx, x.dims()[i], phi::DataType::INT64));
  }

  funcs::DealWithIndices<T, Context>(dev_ctx,
                                     x,
                                     int_indices_v,
                                     &res_indices_v,
                                     &tmp_res_indices_v,
                                     range_tensor_v,
                                     bd_dim,
                                     &res_dim_v);

  LaunchIndexPutGradKernel<T, Context>(
      dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
}
}  // namespace phi

PD_REGISTER_KERNEL(index_put_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::IndexPutGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool) {}
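The gradient rule implemented above can be summarized as: x_grad is out_grad with the indexed positions zeroed when accumulate is false (and simply copied through otherwise), while value_grad gathers out_grad at the indexed positions and, when value was broadcast, sums back to value's shape. A hedged NumPy sketch of the two simplest cases (scalar value and same-shape value; the general broadcast case is handled dimension-wise by CalCompressedDimsWith1AndWithout1):

import numpy as np

def index_put_grad_ref(out_grad, indices, value, accumulate=False):
    # Reference gradients for out = index_put(x, indices, value).
    x_grad = out_grad.copy()
    if not accumulate:
        # Overwritten slots receive no gradient with respect to x.
        x_grad[indices] = 0
    gathered = out_grad[indices]                    # gradient flowing into value
    if value.size == 1:                             # scalar value broadcast everywhere
        value_grad = gathered.sum(keepdims=True).reshape(value.shape)
    else:                                           # value already matches the gather
        value_grad = gathered
    return x_grad, value_grad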
paddle/phi/kernels/cpu/index_put_kernel.cc (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"

namespace phi {

template <typename T>
void index_put_kernel(const int64_t N,
                      const T* x,
                      const T* vals,
                      const int64_t** indices,
                      const phi::DDim& stride,
                      const phi::DDim& shape,
                      int64_t is_single_val_tensor,
                      bool accumulate,
                      T* out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int64_t idx = 0; idx < N; ++idx) {
    int64_t cur_ix = 0;
    int64_t offset = 0;

    for (int i = 0; i < shape.size(); ++i) {
      cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
      if (cur_ix < 0) {
        cur_ix += shape[i];
      }
      offset += stride[i] * cur_ix;
    }

    if (accumulate) {
      *(out + offset) += *(vals + (idx & is_single_val_tensor));
    } else {
      *(out + offset) = *(vals + (idx & is_single_val_tensor));
    }
  }
}

template <typename T, typename Context>
void LaunchIndexPutKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const std::vector<const DenseTensor*>& indices,
                          const DenseTensor& value,
                          bool accumulate,
                          DenseTensor* out) {
  auto* x_data = x.data<T>();
  auto* val_data = value.data<T>();
  bool is_initialized = out->initialized();
  T* out_data = dev_ctx.template Alloc<T>(out);

  if (!is_initialized) {
    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
  }

  auto x_dims = x.dims();
  const int64_t numel = indices[0]->numel();
  auto x_stride = phi::stride(x_dims);

  int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX;

  const int64_t* pd_indices[7];
  for (size_t i = 0; i < indices.size(); ++i) {
    pd_indices[i] = indices[i]->data<int64_t>();
  }

  index_put_kernel<T>(numel,
                      x_data,
                      val_data,
                      pd_indices,
                      x_stride,
                      x_dims,
                      is_single_val_tensor,
                      accumulate,
                      out_data);
}

template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<const DenseTensor*>& indices,
                    const DenseTensor& value,
                    bool accumulate,
                    DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      x.dtype(),
      value.dtype(),
      phi::errors::InvalidArgument(
          "The data type of tensor in indices must be same to the data type "
          "of tensor x."));
  PADDLE_ENFORCE_EQ(indices.empty(),
                    false,
                    phi::errors::InvalidArgument("Indices cannot be empty."));
  const size_t total_dims = x.dims().size();
  PADDLE_ENFORCE_LE(total_dims,
                    6,
                    phi::errors::InvalidArgument(
                        "Dims of input tensor should be less than 7."));

  std::vector<DenseTensor> tmp_args;
  std::vector<const phi::DenseTensor*> int_indices_v =
      funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
  auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);

  std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
  std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
  std::vector<DenseTensor> tmp_res_indices_v;
  std::vector<DenseTensor> tmp_value_v;
  std::vector<DenseTensor> range_tensor_v;
  const DenseTensor* ptr_value = nullptr;

  for (int i = indices.size(); i < x.dims().size(); ++i) {
    range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
        dev_ctx, x.dims()[i], phi::DataType::INT64));
  }

  funcs::DealWithIndices<T, Context>(dev_ctx,
                                     x,
                                     int_indices_v,
                                     &res_indices_v,
                                     &tmp_res_indices_v,
                                     range_tensor_v,
                                     bd_dim,
                                     &res_dim_v);

  if (value.numel() != 1) {
    tmp_value_v.emplace_back(
        DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v)));
    ExpandKernel<T, Context>(
        dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]);
    ptr_value = &tmp_value_v[0];
  } else {
    ptr_value = &value;
  }

  LaunchIndexPutKernel<T, Context>(
      dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
}
}  // namespace phi

PD_REGISTER_KERNEL(index_put,
                   CPU,
                   ALL_LAYOUT,
                   phi::IndexPutKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool) {}
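The inner loop of index_put_kernel turns one multi-dimensional index per element into a flat offset using the row-major strides of x, wrapping negative indices first. The same offset arithmetic in NumPy terms (flat_offsets is a hypothetical helper written for illustration, not part of the kernel):

import numpy as np

def flat_offsets(index_arrays, shape):
    # index_arrays: one int64 array per dimension, all of length N,
    # playing the role of the pd_indices pointer table in the kernel above.
    strides = [int(np.prod(shape[i + 1:])) for i in range(len(shape))]
    offset = np.zeros_like(index_arrays[0])
    for i, ix in enumerate(index_arrays):
        ix = np.where(ix < 0, ix + shape[i], ix)   # wrap negative indices
        offset += strides[i] * ix
    return offset

x = np.arange(12.0).reshape(3, 4)
rows, cols = np.array([0, 2, -1]), np.array([1, 3, 0])
assert np.array_equal(x.ravel()[flat_offsets([rows, cols], x.shape)],
                      x[rows, cols])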
paddle/phi/kernels/funcs/index_put_utils.h (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/nonzero_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include <cuda.h>
#include <cuda_runtime.h>
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#endif
#endif

namespace phi {
namespace funcs {

template <typename T, typename Context>
phi::DenseTensor GetReshapeAndExpandTensor(const Context& dev_ctx,
                                           const phi::DenseTensor& tensor,
                                           const phi::DDim& res_dim,
                                           const phi::DDim& bd_dim,
                                           int index) {
  std::vector<int64_t> before_dims = phi::vectorize(tensor.dims());
  std::vector<int64_t> mid_dims(res_dim.size(), 1);

  if (index == 0) {
    for (size_t i = 0; i < before_dims.size(); ++i) {
      mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1];
    }
  } else {
    mid_dims[index] = before_dims[0];
  }

  phi::DenseTensor mid_tensor(tensor.dtype());
  mid_tensor.Resize(phi::make_ddim(mid_dims));
  ReshapeInferKernel<Context>(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor);

  phi::DenseTensor res_tensor(tensor.dtype());
  res_tensor.Resize(res_dim);
  ExpandKernel<T, Context>(
      dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor);
  return res_tensor;
}

template <typename T, typename Context>
std::vector<const phi::DenseTensor*> DealWithBoolIndices(
    const Context& dev_ctx,
    const std::vector<const phi::DenseTensor*>& indices_v,
    std::vector<phi::DenseTensor>* tmp_indices_v) {
  std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end());
  bool contains_bool_tensor = false;
  for (size_t i = 0; i < indices_v.size(); ++i) {
    if (indices_v[i]->dtype() == phi::DataType::BOOL) {
      contains_bool_tensor = true;
    } else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
               (indices_v[i]->dtype() == phi::DataType::INT32)) {
      PADDLE_ENFORCE_EQ(
          contains_bool_tensor,
          false,
          phi::errors::InvalidArgument(
              "indices contains bool tensor and int32/int64 tensor at the same "
              "time"));
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "data type of tensor in indices must be int32, int64 or bool"));
    }
  }

  if (contains_bool_tensor) {
    if (indices_v.size() != 1) {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "the size of indices must be 1 when it containts bool tensor"));
    }
    int rank = indices_v[0]->dims().size();
    PADDLE_ENFORCE_GE(
        rank,
        1UL,
        phi::errors::InvalidArgument("the only bool tensor in indices should "
                                     "have number of dimension at least 1"));
    phi::DenseTensor nonzero_indices(phi::DataType::INT64);
    nonzero_indices.Resize(phi::make_ddim({-1, rank}));
    NonZeroKernel<bool, Context>(dev_ctx, *indices_v[0], &nonzero_indices);

    std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
    for (int i = 0; i < rank; ++i) {
      tmp_indices_v->emplace_back(
          DenseTensor(phi::DataType::INT64)
              .Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
    }
    for (int i = 0; i < rank; ++i) {
      integer_indices[i] = &((*tmp_indices_v)[i]);
    }
    SplitWithNumKernel<int64_t, Context>(
        dev_ctx, nonzero_indices, rank, 1, integer_indices);

    std::vector<const phi::DenseTensor*> res_tmp(integer_indices.size(),
                                                 nullptr);
    for (int i = 0; i < rank; ++i) {
      res_tmp[i] = &((*tmp_indices_v)[i]);
    }
    res.swap(res_tmp);
  }
  return res;
}

static phi::DDim BroadCastTensorsDims(
    const std::vector<const phi::DenseTensor*>& tensors) {
  int target_rank = 0;
  for (const auto& tensor : tensors) {
    target_rank = std::max(target_rank, tensor->dims().size());
  }

  PADDLE_ENFORCE_GT(target_rank,
                    0,
                    errors::InvalidArgument("BroadCastTensorsDims requires at "
                                            "least one input tensor to have "
                                            "rank greater than zero"));

  std::vector<int64_t> target_dims(target_rank, 0);
  for (int index = 0; index < target_rank; index++) {
    int target_dim_size = 1;
    for (const auto& tensor : tensors) {
      auto input_ddim = tensor->dims();
      int axis = static_cast<int>(input_ddim.size()) - index - 1;
      int dim_size = 1;
      if (axis >= 0) {
        dim_size = input_ddim[axis];
      }

      if (target_dim_size != 1 && dim_size != 1 &&
          target_dim_size != dim_size) {
        PADDLE_THROW(errors::InvalidArgument(
            "BroadCastTensorsDims inputs does not satisfy bcast semantics, "
            "please check axis = %d in reverse order",
            index));
      }

      target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
    }
    target_dims[target_rank - index - 1] = target_dim_size;
  }
  return phi::make_ddim(target_dims);
}

template <typename T, typename Context>
T** GetDevicePointerArray(const Context& ctx,
                          const std::vector<const DenseTensor*>& indices_v) {
  std::vector<const T*> h_indices_v(indices_v.size());
  for (int i = 0; i < indices_v.size(); ++i) {
    h_indices_v[i] = indices_v[i]->data<T>();
  }
  auto d_indices_data = phi::memory_utils::Alloc(
      ctx.GetPlace(),
      h_indices_v.size() * sizeof(T*),
      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  phi::memory_utils::Copy(ctx.GetPlace(),
                          d_indices_data->ptr(),
                          phi::CPUPlace(),
                          reinterpret_cast<void*>(h_indices_v.data()),
                          h_indices_v.size() * sizeof(T*),
                          ctx.stream());
  return reinterpret_cast<T**>(d_indices_data->ptr());
}

template <typename T, typename Context>
void DealWithIndices(const Context& dev_ctx,
                     const DenseTensor& x,
                     const std::vector<const phi::DenseTensor*>& int_indices_v,
                     std::vector<const phi::DenseTensor*>* res_indices_v,
                     std::vector<DenseTensor>* tmp_res_indices_v,
                     const std::vector<DenseTensor>& range_tensor_v,
                     const phi::DDim& bd_dim,
                     std::vector<int64_t>* res_dim_v) {
  size_t total_dims = x.dims().size();
  if (int_indices_v.size() < total_dims) {
    std::vector<int64_t> tmp_x_dims = phi::vectorize(x.dims());
    int len_bd_dim = bd_dim.size();
    res_dim_v->insert(res_dim_v->end(),
                      tmp_x_dims.begin() + int_indices_v.size(),
                      tmp_x_dims.end());

    std::vector<DenseTensor> reshaped_indices_v;
    for (size_t i = 0; i < int_indices_v.size(); ++i) {
      if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
        reshaped_indices_v.emplace_back(phi::Cast<int, Context>(
            dev_ctx, *int_indices_v[i], phi::DataType::INT64));
      } else {
        reshaped_indices_v.emplace_back(*int_indices_v[i]);
      }
    }
    reshaped_indices_v.insert(reshaped_indices_v.end(),
                              range_tensor_v.begin(),
                              range_tensor_v.end());

    phi::DDim res_dim = phi::make_ddim(*res_dim_v);

    for (size_t i = 0; i < reshaped_indices_v.size(); ++i) {
      tmp_res_indices_v->emplace_back(
          GetReshapeAndExpandTensor<int64_t, Context>(
              dev_ctx,
              reshaped_indices_v[i],
              res_dim,
              bd_dim,
              ((i < int_indices_v.size())
                   ? 0
                   : i - int_indices_v.size() + len_bd_dim)));
    }
    for (size_t i = 0; i < res_indices_v->size(); ++i) {
      (*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
    }
  } else {
    std::vector<DenseTensor> int_indices_v_tmp;

    for (size_t i = 0; i < int_indices_v.size(); ++i) {
      if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
        int_indices_v_tmp.emplace_back(phi::Cast<int, Context>(
            dev_ctx, *int_indices_v[i], phi::DataType::INT64));
      } else {
        int_indices_v_tmp.emplace_back(*int_indices_v[i]);
      }
    }

    for (size_t i = 0; i < int_indices_v.size(); ++i) {
      if (bd_dim != int_indices_v[i]->dims()) {
        tmp_res_indices_v->emplace_back(
            DenseTensor(phi::DataType::INT64).Resize(bd_dim));
        ExpandKernel<int64_t, Context>(
            dev_ctx,
            int_indices_v_tmp[i],
            IntArray(phi::vectorize<int64_t>(bd_dim)),
            &(*tmp_res_indices_v)[i]);
      } else {
        tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]);
      }
    }

    for (size_t i = 0; i < res_indices_v->size(); ++i) {
      (*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
    }
  }
}

static void CalCompressedDimsWith1AndWithout1(
    std::vector<int64_t>* after_dims,
    std::vector<int64_t>* before_dims,
    std::vector<int64_t>* compress_dims,
    std::vector<int64_t>* dims_without_1) {
  int i = static_cast<int>(after_dims->size()) - 1;
  int j = static_cast<int>(before_dims->size()) - 1;
  if (i < j) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "shape of value can't not be broadcast to shape of x[indices]"));
  }

  while ((i >= 0) && (j >= 0)) {
    if ((*after_dims)[i] == (*before_dims)[j]) {
      dims_without_1->push_back((*before_dims)[j]);
      i--;
      j--;
      continue;
    } else if ((*before_dims)[j] == 1) {
      compress_dims->push_back(i);
      i--;
      j--;
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "shape of value can't not be broadcast to shape of x[indices]"));
    }
  }
  while (i >= 0) {
    compress_dims->push_back(i);
    i--;
  }
}

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
__global__ void range_cuda_kernel(int64_t N, T* out) {
  int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;

  if (idx >= N) {
    return;
  }
  out[idx] = idx;
}

template <typename T, typename Context>
phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx,
                                    int64_t N,
                                    phi::DataType dtype) {
  phi::DenseTensor res(dtype);
  res.Resize(phi::make_ddim({N}));
  DenseTensor* p_res = &res;
  T* out = dev_ctx.template Alloc<T>(p_res);
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N);
  range_cuda_kernel<T>
      <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
          N, out);
  return res;
}
#endif

template <typename T>
void range_kernel(int64_t N, T* out) {
  for (int64_t idx = 0; idx < N; ++idx) {
    out[idx] = idx;
  }
}

template <typename T, typename Context>
phi::DenseTensor GetRangeTensor(const Context& dev_ctx,
                                int64_t N,
                                phi::DataType dtype) {
  phi::DenseTensor res(dtype);
  res.Resize(phi::make_ddim({N}));
  DenseTensor* p_res = &res;
  T* out = dev_ctx.template Alloc<T>(p_res);
  range_kernel<T>(N, out);
  return res;
}

}  // namespace funcs
}  // namespace phi
paddle/phi/kernels/gpu/index_put_grad_kernel.cu (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_grad_kernel.h"
#include <numeric>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

template <typename T, size_t Rank>
__global__ void set_zero_cuda_kernel(const int64_t N,
                                     int64_t** indices,
                                     phi::Array<int64_t, Rank> stride,
                                     phi::Array<int64_t, Rank> shape,
                                     T* out) {
  int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  int64_t cur_ix = 0;

  if (idx >= N) {
    return;
  }
  int64_t offset = 0;
  for (int i = 0; i < Rank; ++i) {
    cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
    if (cur_ix < 0) {
      cur_ix += shape[i];
    }
    offset += stride[i] * cur_ix;
  }

  *(out + offset) = 0;
}

template <typename T, size_t Rank>
__global__ void index_put_grad_cuda_kernel(const int64_t N,
                                           const T* out_grad,
                                           int64_t** indices,
                                           phi::Array<int64_t, Rank> stride,
                                           phi::Array<int64_t, Rank> shape,
                                           T* value_grad) {
  int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  int64_t cur_ix = 0;

  if (idx >= N) {
    return;
  }
  int64_t offset = 0;
  for (int i = 0; i < Rank; ++i) {
    cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
    if (cur_ix < 0) {
      cur_ix += shape[i];
    }
    offset += stride[i] * cur_ix;
  }

  *(value_grad + idx) = *(out_grad + offset);
}

template <typename T, typename Context, size_t Rank>
void LaunchIndexPutGradCudaKernel(
    const Context& dev_ctx,
    const std::vector<const DenseTensor*>& indices,
    const DenseTensor& out_grad,
    bool accumulate,
    DenseTensor* value_grad,
    DenseTensor* x_grad) {
  if (x_grad) {
    phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
    if (!accumulate) {
      T* x_grad_data = x_grad->data<T>();

      auto x_grad_dims = x_grad->dims();
      const int64_t numel = indices[0]->numel();
      auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
      auto x_grad_stride = phi::stride(x_grad_dims);

      phi::Array<int64_t, Rank> stride_a;
      phi::Array<int64_t, Rank> shape_a;

      for (size_t idx = 0; idx < Rank; ++idx) {
        stride_a[idx] = x_grad_stride[idx];
        shape_a[idx] = x_grad_dims[idx];
      }

      auto pd_indices =
          funcs::GetDevicePointerArray<int64_t, Context>(dev_ctx, indices);
      set_zero_cuda_kernel<T, Rank><<<config.block_per_grid,
                                      config.thread_per_block,
                                      0,
                                      dev_ctx.stream()>>>(
          numel, pd_indices, stride_a, shape_a, x_grad_data);
    }
  }

  auto out_grad_dims = out_grad.dims();
  const int64_t numel = indices[0]->numel();
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
  auto out_grad_stride = phi::stride(out_grad_dims);

  phi::Array<int64_t, Rank> stride_a;
  phi::Array<int64_t, Rank> shape_a;

  for (size_t idx = 0; idx < Rank; ++idx) {
    stride_a[idx] = out_grad_stride[idx];
    shape_a[idx] = out_grad_dims[idx];
  }

  auto pd_indices =
      funcs::GetDevicePointerArray<int64_t, Context>(dev_ctx, indices);
  if (value_grad) {
    if (value_grad->numel() == 1) {
      DenseTensor tmp_value_grad(value_grad->dtype());
      tmp_value_grad.Resize(indices[0]->dims());

      T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
      auto out_grad_data = out_grad.data<T>();

      index_put_grad_cuda_kernel<T, Rank>
          <<<config.block_per_grid,
             config.thread_per_block,
             0,
             dev_ctx.stream()>>>(numel,
                                 out_grad_data,
                                 pd_indices,
                                 stride_a,
                                 shape_a,
                                 tmp_value_grad_data);

      std::vector<int> v_dims(tmp_value_grad.dims().size());
      std::iota(v_dims.begin(), v_dims.end(), 0);
      IntArray v_axis(v_dims);
      SumKernel<T, Context>(dev_ctx,
                            tmp_value_grad,
                            v_axis,
                            value_grad->dtype(),
                            false,
                            value_grad);
    } else if (value_grad->numel() == indices[0]->numel()) {
      T* value_grad_data = dev_ctx.template Alloc<T>(value_grad);
      auto out_grad_data = out_grad.data<T>();

      index_put_grad_cuda_kernel<T, Rank><<<config.block_per_grid,
                                            config.thread_per_block,
                                            0,
                                            dev_ctx.stream()>>>(numel,
                                                                out_grad_data,
                                                                pd_indices,
                                                                stride_a,
                                                                shape_a,
                                                                value_grad_data);
    } else {
      DenseTensor tmp_value_grad(value_grad->dtype());
      tmp_value_grad.Resize(indices[0]->dims());

      T* tmp_value_grad_data = dev_ctx.template Alloc<T>(&tmp_value_grad);
      auto out_grad_data = out_grad.data<T>();

      index_put_grad_cuda_kernel<T, Rank>
          <<<config.block_per_grid,
             config.thread_per_block,
             0,
             dev_ctx.stream()>>>(numel,
                                 out_grad_data,
                                 pd_indices,
                                 stride_a,
                                 shape_a,
                                 tmp_value_grad_data);

      std::vector<int64_t> after_dims = phi::vectorize(tmp_value_grad.dims());
      std::vector<int64_t> before_dims = phi::vectorize(value_grad->dims());
      std::vector<int64_t> compress_dims;
      std::vector<int64_t> dims_without_1;

      funcs::CalCompressedDimsWith1AndWithout1(
          &after_dims, &before_dims, &compress_dims, &dims_without_1);

      auto pre_dims = value_grad->dims();
      value_grad->Resize(phi::make_ddim(dims_without_1));
      IntArray v_axis(compress_dims);
      SumKernel<T, Context>(dev_ctx,
                            tmp_value_grad,
                            v_axis,
                            value_grad->dtype(),
                            false,
                            value_grad);
      value_grad->Resize(pre_dims);
    }
  }
}

template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const std::vector<const DenseTensor*>& indices,
                        const DenseTensor& value,
                        const DenseTensor& out_grad,
                        bool accumulate,
                        DenseTensor* x_grad,
                        DenseTensor* value_grad) {
  PADDLE_ENFORCE_EQ(
      x.dtype(),
      value.dtype(),
      phi::errors::InvalidArgument(
          "The data type of tensor in indices must be same to the data type "
          "of tensor x."));
  std::vector<DenseTensor> tmp_args;
  std::vector<const phi::DenseTensor*> int_indices_v =
      funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
  const size_t total_dims = x.dims().size();
  auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);

  std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
  std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
  std::vector<DenseTensor> tmp_res_indices_v;
  std::vector<DenseTensor> range_tensor_v;

  for (int i = indices.size(); i < x.dims().size(); ++i) {
    range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
        dev_ctx, x.dims()[i], phi::DataType::INT64));
  }

  funcs::DealWithIndices<T, Context>(dev_ctx,
                                     x,
                                     int_indices_v,
                                     &res_indices_v,
                                     &tmp_res_indices_v,
                                     range_tensor_v,
                                     bd_dim,
                                     &res_dim_v);

  switch (total_dims) {
    case 1:
      LaunchIndexPutGradCudaKernel<T, Context, 1>(
          dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
      break;
    case 2:
      LaunchIndexPutGradCudaKernel<T, Context, 2>(
          dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
      break;
    case 3:
      LaunchIndexPutGradCudaKernel<T, Context, 3>(
          dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
      break;
    case 4:
      LaunchIndexPutGradCudaKernel<T, Context, 4>(
          dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
      break;
    case 5:
      LaunchIndexPutGradCudaKernel<T, Context, 5>(
          dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
      break;
    case 6:
      LaunchIndexPutGradCudaKernel<T, Context, 6>(
          dev_ctx, res_indices_v, out_grad, accumulate, value_grad, x_grad);
      break;
    default:
      PADDLE_THROW(phi::errors::InvalidArgument(
          "dims of input tensor should be less than 7, But received"
          "%d",
          x.dims().size()));
  }
}
}  // namespace phi

PD_REGISTER_KERNEL(index_put_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::IndexPutGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}
paddle/phi/kernels/gpu/index_put_kernel.cu (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/index_put_utils.h"

namespace phi {

template <typename T, size_t Rank>
__global__ void index_put_cuda_kernel(const int64_t N,
                                      const T* x,
                                      const T* vals,
                                      int64_t** indices,
                                      phi::Array<int64_t, Rank> stride,
                                      phi::Array<int64_t, Rank> shape,
                                      int64_t is_single_val_tensor,
                                      bool accumulate,
                                      T* out) {
  int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  int64_t cur_ix = 0;

  if (idx >= N) {
    return;
  }
  int64_t offset = 0;
  for (int i = 0; i < Rank; ++i) {
    cur_ix = (static_cast<int64_t>(*(indices[i] + idx)));
    if (cur_ix < 0) {
      cur_ix += shape[i];
    }
    offset += stride[i] * cur_ix;
  }

  if (accumulate) {
    *(out + offset) += *(vals + (idx & is_single_val_tensor));
  } else {
    *(out + offset) = *(vals + (idx & is_single_val_tensor));
  }
}

template <typename T, typename Context, size_t Rank>
void LaunchIndexPutCudaKernel(const Context& dev_ctx,
                              const DenseTensor& x,
                              const std::vector<const DenseTensor*>& indices,
                              const DenseTensor& value,
                              bool accumulate,
                              DenseTensor* out) {
  auto* x_data = x.data<T>();
  auto* val_data = value.data<T>();
  bool is_initialized = out->initialized();
  T* out_data = dev_ctx.template Alloc<T>(out);

  if (!is_initialized) {
    phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
  }

  auto x_dims = x.dims();
  const int64_t numel = indices[0]->numel();
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel);
  auto x_stride = phi::stride(x_dims);

  phi::Array<int64_t, Rank> stride_a;
  phi::Array<int64_t, Rank> shape_a;

  for (size_t idx = 0; idx < Rank; ++idx) {
    stride_a[idx] = x_stride[idx];
    shape_a[idx] = x_dims[idx];
  }

  int64_t is_single_val_tensor = (value.numel() == 1) ? 0 : INT64_MAX;

  auto pd_indices =
      funcs::GetDevicePointerArray<int64_t, Context>(dev_ctx, indices);

  index_put_cuda_kernel<T, Rank>
      <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
          numel,
          x_data,
          val_data,
          pd_indices,
          stride_a,
          shape_a,
          is_single_val_tensor,
          accumulate,
          out_data);
}

template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<const DenseTensor*>& indices,
                    const DenseTensor& value,
                    bool accumulate,
                    DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      x.dtype(),
      value.dtype(),
      phi::errors::InvalidArgument(
          "The data type of tensor in indices must be same to the data type "
          "of tensor x."));
  PADDLE_ENFORCE_EQ(indices.empty(),
                    false,
                    phi::errors::InvalidArgument("Indices cannot be empty."));
  std::vector<DenseTensor> tmp_args;
  std::vector<const phi::DenseTensor*> int_indices_v =
      funcs::DealWithBoolIndices<T, Context>(dev_ctx, indices, &tmp_args);
  const size_t total_dims = x.dims().size();
  auto bd_dim = funcs::BroadCastTensorsDims(int_indices_v);

  std::vector<int64_t> res_dim_v(phi::vectorize(bd_dim));
  std::vector<const phi::DenseTensor*> res_indices_v(x.dims().size(), nullptr);
  std::vector<DenseTensor> tmp_res_indices_v;
  std::vector<DenseTensor> tmp_value_v;
  std::vector<DenseTensor> range_tensor_v;
  const DenseTensor* ptr_value = nullptr;

  for (int i = indices.size(); i < x.dims().size(); ++i) {
    range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
        dev_ctx, x.dims()[i], phi::DataType::INT64));
  }

  funcs::DealWithIndices<T, Context>(dev_ctx,
                                     x,
                                     int_indices_v,
                                     &res_indices_v,
                                     &tmp_res_indices_v,
                                     range_tensor_v,
                                     bd_dim,
                                     &res_dim_v);

  if (value.numel() != 1) {
    tmp_value_v.emplace_back(
        DenseTensor(value.dtype()).Resize(phi::make_ddim(res_dim_v)));
    ExpandKernel<T, Context>(
        dev_ctx, value, IntArray(res_dim_v), &tmp_value_v[0]);
    ptr_value = &tmp_value_v[0];
  } else {
    ptr_value = &value;
  }

  switch (total_dims) {
    case 1:
      LaunchIndexPutCudaKernel<T, Context, 1>(
          dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
      break;
    case 2:
      LaunchIndexPutCudaKernel<T, Context, 2>(
          dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
      break;
    case 3:
      LaunchIndexPutCudaKernel<T, Context, 3>(
          dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
      break;
    case 4:
      LaunchIndexPutCudaKernel<T, Context, 4>(
          dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
      break;
    case 5:
      LaunchIndexPutCudaKernel<T, Context, 5>(
          dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
      break;
    case 6:
      LaunchIndexPutCudaKernel<T, Context, 6>(
          dev_ctx, x, res_indices_v, *ptr_value, accumulate, out);
      break;
    default:
      PADDLE_THROW(phi::errors::InvalidArgument(
          "dims of input tensor should be less than 7, But received"
          "%d",
          x.dims().size()));
  }
}
}  // namespace phi

PD_REGISTER_KERNEL(index_put,
                   GPU,
                   ALL_LAYOUT,
                   phi::IndexPutKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}
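Both the CPU and GPU forward kernels read the source value with vals + (idx & is_single_val_tensor), where is_single_val_tensor is 0 when value is a scalar and INT64_MAX otherwise; the bitwise AND therefore either pins every element to vals[0] or leaves idx unchanged, avoiding a per-element branch on the value's size. A small Python check of the mask identity (idx is always non-negative and below 2**63):

INT64_MAX = (1 << 63) - 1

for idx in [0, 1, 7, 123456, INT64_MAX]:
    assert (idx & 0) == 0                # scalar value: always read vals[0]
    assert (idx & INT64_MAX) == idx      # full-size value: read vals[idx]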
paddle/phi/kernels/index_put_grad_kernel.h (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void IndexPutGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const std::vector<const DenseTensor*>& indices_v,
                        const DenseTensor& value,
                        const DenseTensor& out_grad,
                        bool accumulate,
                        DenseTensor* x_grad,
                        DenseTensor* value_grad);
}  // namespace phi
paddle/phi/kernels/index_put_kernel.h (new file, mode 100644)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void IndexPutKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<const DenseTensor*>& indices_v,
                    const DenseTensor& value,
                    bool accumulate,
                    DenseTensor* out);
}  // namespace phi
python/paddle/__init__.py

@@ -194,6 +194,8 @@ from .tensor.manipulation import moveaxis  # noqa: F401
from .tensor.manipulation import repeat_interleave  # noqa: F401
from .tensor.manipulation import index_add  # noqa: F401
from .tensor.manipulation import index_add_  # noqa: F401
from .tensor.manipulation import index_put  # noqa: F401
from .tensor.manipulation import index_put_  # noqa: F401
from .tensor.math import abs  # noqa: F401
from .tensor.math import acos  # noqa: F401
from .tensor.math import asin  # noqa: F401
...
@@ -676,6 +678,8 @@ __all__ = [  # noqa
    'tril_indices',
    'index_add',
    "index_add_",
    'index_put',
    'index_put_',
    'sgn',
    'triu_indices',
    'take',
...
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -924,6 +924,7 @@ set_tests_properties(test_imperative_selected_rows_to_lod_tensor
                     PROPERTIES TIMEOUT 200)
set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_add_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_index_put_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_tensordot PROPERTIES TIMEOUT 200)
set_tests_properties(test_partial_eager_deletion_transformer PROPERTIES TIMEOUT
                     120)
...
python/paddle/fluid/tests/unittests/test_index_put_op.py (new file, mode 100644, +826 lines)
This diff is collapsed in the web view.
python/paddle/tensor/__init__.py

@@ -135,6 +135,8 @@ from .manipulation import moveaxis  # noqa: F401
from .manipulation import repeat_interleave  # noqa: F401
from .manipulation import index_add  # noqa: F401
from .manipulation import index_add_  # noqa: F401
from .manipulation import index_put  # noqa: F401
from .manipulation import index_put_  # noqa: F401
from .math import abs  # noqa: F401
from .math import acos  # noqa: F401
from .math import asin  # noqa: F401
...
@@ -530,6 +532,8 @@ tensor_method_func = [  # noqa
    'heaviside',
    'index_add',
    "index_add_",
    'index_put',
    'index_put_',
    'take',
    'bucketize',
    'sgn',
...
python/paddle/tensor/manipulation.py

@@ -4794,6 +4794,108 @@ def index_add_(x, index, axis, value, name=None):
    return _C_ops.index_add_(x, index, value, axis)


@inplace_apis_in_dygraph_only
def index_put_(x, indices, value, accumulate=False, name=None):
    """
    Puts values from the tensor ``value`` into the tensor ``x`` using the indices specified in ``indices`` (a tuple of Tensors).
    The expression ``paddle.index_put_(x, indices, value)`` is equivalent to ``x[indices] = value``. Returns ``x``.
    If ``accumulate`` is True, the elements in ``value`` are added to ``x``. If ``accumulate`` is False, the behavior is undefined if ``indices`` contains duplicate elements.

    Args:
        x (Tensor): The source Tensor. Supported data types are int32, int64, float16, float32, float64, bool.
        indices (Tuple of Tensor): The tuple of Tensors containing the indices to index.
            The data type of the Tensors in ``indices`` must be int32, int64 or bool.
        value (Tensor): The tensor used to be assigned to ``x``.
        accumulate (bool, optional): Whether the elements in ``value`` are added to ``x``. Default: False.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, same dimension and dtype as ``x``.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.zeros([3, 3])
            value = paddle.ones([3])
            ix1 = paddle.to_tensor([0, 1, 2])
            ix2 = paddle.to_tensor([1, 2, 1])
            indices = (ix1, ix2)

            out = paddle.index_put_(x, indices, value)
            print(x)
            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[0., 1., 0.],
            #         [0., 0., 1.],
            #         [0., 1., 0.]])
            print(out)
            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[0., 1., 0.],
            #         [0., 0., 1.],
            #         [0., 1., 0.]])
    """
    return _C_ops.index_put_(x, indices, value, accumulate)


def index_put(x, indices, value, accumulate=False, name=None):
    """
    Out-of-place version of the ``index_put_`` API; the input ``x`` is left unchanged and the result is returned as a new Tensor.
    Please refer to :ref:`api_paddle_index_put`.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.zeros([3, 3])
            value = paddle.ones([3])
            ix1 = paddle.to_tensor([0, 1, 2])
            ix2 = paddle.to_tensor([1, 2, 1])
            indices = (ix1, ix2)

            out = paddle.index_put(x, indices, value)
            print(x)
            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[0., 0., 0.],
            #         [0., 0., 0.],
            #         [0., 0., 0.]])
            print(out)
            # Tensor(shape=[3, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [[0., 1., 0.],
            #         [0., 0., 1.],
            #         [0., 1., 0.]])
    """
    if in_dygraph_mode():
        return _C_ops.index_put(x, indices, value, accumulate)

    helper = LayerHelper("index_put", **locals())
    check_variable_and_dtype(
        x,
        'x',
        ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
        'paddle.tensor.manipulation.index_put',
    )
    check_variable_and_dtype(
        value,
        'value',
        ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
        'paddle.tensor.manipulation.index_put',
    )

    out = helper.create_variable_for_type_inference(x.dtype)

    helper.append_op(
        type='index_put',
        inputs={
            'x': x,
            'indices': indices,
            'value': value,
        },
        outputs={'out': out},
        attrs={'accumulate': accumulate},
    )
    return out


# TODO(dev): We need avoid implementing it by this way.
__METHODS = {
    'fill_': fill_,
...
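A hedged illustration of the accumulate flag, following the docstring's statement that values are added to x when accumulate=True (the expected outputs below are inferred from that rule rather than copied from the PR):

import paddle

x = paddle.ones([3, 3])
value = paddle.full([3], 10.0)
indices = (paddle.to_tensor([0, 1, 2]), paddle.to_tensor([1, 2, 1]))

overwritten = paddle.index_put(x, indices, value)                   # x[i, j] = 10
accumulated = paddle.index_put(x, indices, value, accumulate=True)  # x[i, j] += 10
# overwritten should hold 10. at the indexed positions, accumulated 11.;
# all other entries stay 1.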