Unverified commit 99452af7, authored by chenenquan, committed by GitHub

[PHI] Migrate index_select op (#40260)

* [PHI] Migrate index_select op

* [PHI] Fix bug in test_variable

* [PHI] migrate index_select op
Parent 57f54d3b
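This change moves the index_select CPU/GPU kernels and shape inference out of the fluid operator files and into the PHI library. For reference, index_select gathers slices of the input along dim at the positions given by index; a minimal standalone sketch of that behaviour for dim = 0 (plain C++, illustrative shapes, not the Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

// Gather rows of a 2-D matrix: the dim == 0 case of index_select.
std::vector<std::vector<float>> IndexSelectDim0(
    const std::vector<std::vector<float>>& x,
    const std::vector<int64_t>& index) {
  std::vector<std::vector<float>> out;
  out.reserve(index.size());
  for (int64_t i : index) {
    out.push_back(x[i]);  // copy the selected slice
  }
  return out;
}

int main() {
  std::vector<std::vector<float>> x = {{1, 2}, {3, 4}, {5, 6}};
  auto out = IndexSelectDim0(x, {2, 0, 2});  // expect {{5, 6}, {1, 2}, {5, 6}}
  std::cout << "selected " << out.size() << " rows\n";
  return 0;
}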
@@ -13,8 +13,13 @@
// limitations under the License.
#include "paddle/fluid/operators/index_select_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -24,52 +29,6 @@ class IndexSelectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of IndexSelectOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto index_dim = ctx->GetInputDim("Index");
auto dim = ctx->Attrs().Get<int>("dim");
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), input_dim.size() - 1, dim));
PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true, platform::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim, index_dim.size()));
PADDLE_ENFORCE_EQ(index_dim[0] != 0, true,
platform::errors::InvalidArgument(
"The length of Input(Index) can't be 0."));
auto output_dim = phi::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = index_dim[0];
ctx->SetOutputDim("Out", phi::make_ddim(output_dim));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -148,20 +107,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(index_select, IndexSelectInferShapeFunctor,
PD_INFER_META(phi::IndexSelectInferMeta));
REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>);
ops::IndexSelectGradMaker<paddle::imperative::OpBase>,
IndexSelectInferShapeFunctor);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
index_select_grad,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input, T* output,
const IndexT* index, int64_t N,
int64_t stride, int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index, int64_t nums,
int64_t N, int64_t stride,
int64_t size, int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename DeviceContext, typename T>
class IndexSelectCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* index = context.Input<LoDTensor>("Index");
auto* out = context.Output<LoDTensor>("Out");
int dim = context.Attr<int>("dim");
auto input_dim = in->dims();
auto output_dim = out->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data,
numel, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
const int* index_data = index->data<int>();
index_select_cuda_kernel<T, int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
platform::GpuStreamSync(stream);
}
}
};
template <typename DeviceContext, typename T>
class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_grad = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<LoDTensor>("Index");
auto* output_grad_data = output_grad->data<T>();
auto* in_grad_data = in_grad->mutable_data<T>(context.GetPlace());
int dim = context.Attr<int>("dim");
auto input_dim = in_grad->dims();
auto output_dim = output_grad->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
int64_t numel = in_grad->numel();
int64_t index_nums = index->numel();
int64_t out_nums = output_grad->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
const int* index_data = index->data<int>();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
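The stride/size/delta arithmetic in the CUDA kernels above maps each output element index back to the input element it reads (and, in the grad kernel, the input_grad element it accumulates into). A small host-side sketch of the same mapping, assuming an example input shape [2, 3, 4] with index = {2, 0} and dim = 1:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Assumed example: input shape [2, 3, 4], index = {2, 0}, dim = 1.
  const std::vector<int64_t> index = {2, 0};
  const int64_t stride = 4;              // product of dims after `dim`
  const int64_t size = 2;                // output extent along `dim` (= index length)
  const int64_t delta = 3 - size;        // input extent along `dim` minus output extent
  const int64_t N = 2 * size * stride;   // number of output elements

  for (int64_t idx = 0; idx < N; ++idx) {
    int64_t pre_idx = idx / (stride * size);           // outer index
    int64_t dim_idx = idx % (stride * size) / stride;  // output position along dim
    int64_t src_dim_idx = index[dim_idx];              // selected input position
    int64_t input_idx =
        idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
    // Same as addressing the input directly:
    // pre_idx * (3 * 4) + src_dim_idx * 4 + idx % 4.
    assert(input_idx == pre_idx * 12 + src_dim_idx * 4 + idx % 4);
  }
  return 0;
}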
@@ -91,41 +91,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,
output->Resize(output_dim);
}
template <typename DeviceContext, typename T>
class IndexSelectKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto inputs = *context.Input<framework::LoDTensor>("X");
auto* index = context.Input<framework::LoDTensor>("Index");
auto* output = context.Output<framework::LoDTensor>("Out");
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectInner<DeviceContext, T, int>(context, &inputs, *index, output,
dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectInner<DeviceContext, T, int64_t>(context, &inputs, *index,
output, dim);
}
}
};
template <typename DeviceContext, typename T, class Enable = void>
struct IndexSelectAdd {
void operator()(const framework::ExecutionContext& ctx, int slice_size,
@@ -197,43 +162,5 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
x_grad->Resize(output_dim);
}
template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<framework::LoDTensor>("Index");
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += out_grad->dims().size();
}
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, *index,
x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<DeviceContext, T, int64_t>(context, *out_grad,
*index, x_grad, dim);
}
}
};
} // namespace operators
} // namespace paddle
@@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class IndexSelectNPUKernel : public framework::OpKernel<T> {
public:
......
@@ -643,6 +643,49 @@ void IndexSampleInferMeta(const MetaTensor& x,
out->share_lod(y);
}
void IndexSelectInferMeta(const MetaTensor& x,
const MetaTensor& index,
int dim,
MetaTensor* output) {
auto input_dim = x.dims();
auto index_dim = index.dims();
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(),
input_dim.size() - 1,
dim));
PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true,
phi::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim,
index_dim.size()));
PADDLE_ENFORCE_EQ(
index_dim[0] != 0,
true,
phi::errors::InvalidArgument("The length of Input(Index) can't be 0."));
auto output_dim = phi::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = index_dim[0];
output->set_dims(phi::make_ddim(output_dim));
output->set_dtype(x.dtype());
output->set_layout(x.layout());
output->share_lod(x);
}
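As a concrete check of the shape rule above: for x with dims [6, 4], an index of length 3, and dim = -1, the inferred output dims are [6, 3]. A standalone sketch of just that computation (plain C++, illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> InferIndexSelectShape(std::vector<int64_t> x_dims,
                                           int64_t index_len, int dim) {
  if (dim < 0) dim += static_cast<int>(x_dims.size());
  x_dims[dim] = index_len;  // only the selected axis changes length
  return x_dims;
}

int main() {
  auto out = InferIndexSelectShape({6, 4}, /*index_len=*/3, /*dim=*/-1);
  assert((out == std::vector<int64_t>{6, 3}));
  return 0;
}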
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
......
@@ -113,6 +113,11 @@ void IndexSampleInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void IndexSelectInferMeta(const MetaTensor& x,
const MetaTensor& index,
int dim,
MetaTensor* output);
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad) {
if (dim < 0) {
dim += out_grad.dims().size();
}
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
if (index_type == phi::DataType::INT32) {
IndexSelectGradInner<Context, T, int>(ctx, out_grad, index, x_grad, dim);
} else if (index_type == phi::DataType::INT64) {
IndexSelectGradInner<Context, T, int64_t>(
ctx, out_grad, index, x_grad, dim);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select_grad,
CPU,
ALL_LAYOUT,
phi::IndexSelectGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename Context, typename T, class Enable = void>
struct IndexSelectAdd {
void operator()(const Context& ctx,
int slice_size,
const T* src_pointer,
const T* p_pointer,
T* dist_pointer) {
for (int i = 0; i < slice_size; i++) {
dist_pointer[i] = src_pointer[i] + p_pointer[i];
}
}
};
template <typename Context, typename T>
struct IndexSelectAdd<
Context,
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const Context& ctx,
int slice_size,
const T* src_pointer,
const T* p_pointer,
T* dist_pointer) {
auto blas = phi::funcs::GetBlas<Context, T>(ctx);
blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
}
};
template <typename Context, typename T, typename IndexT = int>
void IndexSelectInner(const Context& ctx,
DenseTensor* input,
const DenseTensor& index,
DenseTensor* output,
int dim) {
auto input_dim = input->dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();
auto index_size = index.dims()[0];
DenseTensor index_cpu_copy;
if (!paddle::platform::is_cpu_place(index.place())) {
phi::Copy(ctx, index, phi::CPUPlace(), true, &index_cpu_copy);
}
const IndexT* index_data = paddle::platform::is_cpu_place(index.place())
? index.data<IndexT>()
: index_cpu_copy.data<IndexT>();
ctx.template Alloc<T>(output);
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim],
index_data[i]));
PADDLE_ENFORCE_LT(
index_data[i],
input_dim[dim],
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim],
index_data[i]));
}
VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; index_size: " << index_size;
input->Resize(phi::make_ddim({outer_nums, input_dim[dim], slice_size}));
output->Resize(phi::make_ddim({outer_nums, index_size, slice_size}));
auto input_tensor = EigenTensor<T, 3>::From(*input);
auto output_tensor = EigenTensor<T, 3>::From(*output);
auto& place = *ctx.eigen_device();
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_data[j];
auto output_t = output_tensor.chip(j, 1);
output_t.device(place) = input_tensor.chip(index_value, 1);
}
input->Resize(input_dim);
output->Resize(output_dim);
}
template <typename Context, typename T, typename IndexT = int>
void IndexSelectGradInner(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& index,
DenseTensor* x_grad,
int dim) {
const T* input_data = out_grad.data<T>();
const IndexT* index_data = index.data<IndexT>();
const T* p_output = ctx.template Alloc<T>(x_grad);
T* out_data = ctx.template Alloc<T>(x_grad);
auto input_dim = out_grad.dims();
auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims();
phi::funcs::SetConstant<Context, T> set_constant;
set_constant(ctx, x_grad, static_cast<T>(0.0));
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
auto index_size = index.dims()[0];
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
<< "; index_size: " << index_size;
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_data[j];
auto src = input_data + input_start_offset + j * slice_size;
auto p_out = p_output + output_start_offset + index_value * slice_size;
auto dst = out_data + output_start_offset + index_value * slice_size;
IndexSelectAdd<Context, T> index_select_add;
index_select_add(ctx, slice_size, src, p_out, dst);
}
}
x_grad->Resize(output_dim);
}
} // namespace phi
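Because the same input slice may be selected more than once, IndexSelectGradInner accumulates: each out_grad slice is added into the x_grad slice named by its index entry (IndexSelectAdd here, CudaAtomicAdd in the GPU grad kernel). A tiny standalone sketch of that accumulation for dim = 0, with assumed shapes:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Assumed example: forward selected rows {0, 2, 0} of a 3-row input, slice_size = 2.
  const std::vector<int64_t> index = {0, 2, 0};
  const std::vector<float> out_grad = {1, 1, 2, 2, 3, 3};  // one 2-wide row per index entry
  std::vector<float> x_grad(3 * 2, 0.0f);                  // zero-initialized first

  for (size_t j = 0; j < index.size(); ++j) {
    for (int64_t k = 0; k < 2; ++k) {
      x_grad[index[j] * 2 + k] += out_grad[j * 2 + k];  // accumulate, never overwrite
    }
  }
  // Row 0 was selected twice, so it receives the sum of two gradient rows.
  assert((x_grad == std::vector<float>{4, 4, 0, 0, 2, 2}));
  return 0;
}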
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output) {
auto inputs = x;
if (dim < 0) {
dim += inputs.dims().size();
}
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
if (index_type == phi::DataType::INT32) {
IndexSelectInner<Context, T, int>(ctx, &inputs, index, output, dim);
} else if (index_type == phi::DataType::INT64) {
IndexSelectInner<Context, T, int64_t>(ctx, &inputs, index, output, dim);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select,
CPU,
ALL_LAYOUT,
phi::IndexSelectKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index,
int64_t nums,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad) {
auto* output_grad_data = out_grad.data<T>();
auto* in_grad_data = ctx.template Alloc<T>(x_grad);
auto input_dim = x_grad->dims();
auto output_dim = out_grad.dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
int64_t numel = x_grad->numel();
int64_t index_nums = index.numel();
int64_t out_nums = out_grad.numel();
auto stream = ctx.stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_grad_data, numel);
if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
phi::backends::gpu::GpuStreamSync(stream);
} else {
const int* index_data = index.data<int>();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
phi::backends::gpu::GpuStreamSync(stream);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select_grad,
GPU,
ALL_LAYOUT,
phi::IndexSelectGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input,
T* output,
const IndexT* index,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output) {
auto input_dim = x.dims();
auto output_dim = output->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(output);
int64_t numel = output->numel();
auto stream = ctx.stream();
if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
index_select_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_data, out_data, index_data, numel, stride, size, delta);
phi::backends::gpu::GpuStreamSync(stream);
} else {
const int* index_data = index.data<int>();
index_select_cuda_kernel<
T,
int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
phi::backends::gpu::GpuStreamSync(stream);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select,
GPU,
ALL_LAYOUT,
phi::IndexSelectKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_select_grad",
{"X", "Index", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(index_select_grad,
phi::IndexSelectGradOpArgumentMapping);
@@ -333,6 +333,7 @@ class TestVariable(unittest.TestCase):
with self.assertRaises(IndexError):
res = x[[True, False, False]]
with self.assertRaises(ValueError):
with paddle.static.program_guard(prog):
res = x[[False, False]]
def test_slice(self):
......