Unverified commit 99452af7, authored by chenenquan, committed by GitHub

[PHI] Migrate index_select op (#40260)

* [PHI] Migrate index_select op

* [PHI] Fix bug in test_variable

* [PHI] migrate index_select op
Parent 57f54d3b
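For reference, a minimal usage sketch of the op being migrated (hypothetical values; only the kernel backend moves to PHI in this PR, the Python-level API is unchanged):

import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0, 4.0],
                      [5.0, 6.0, 7.0, 8.0]])
index = paddle.to_tensor([0, 2], dtype='int32')

# Gather columns 0 and 2 along axis 1.
out = paddle.index_select(x, index, axis=1)
print(out.numpy())  # [[1. 3.], [5. 7.]]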
@@ -13,8 +13,13 @@
// limitations under the License.
#include "paddle/fluid/operators/index_select_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {
@@ -24,52 +29,6 @@ class IndexSelectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of IndexSelectOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto index_dim = ctx->GetInputDim("Index");
auto dim = ctx->Attrs().Get<int>("dim");
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), input_dim.size() - 1, dim));
PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true, platform::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim, index_dim.size()));
PADDLE_ENFORCE_EQ(index_dim[0] != 0, true,
platform::errors::InvalidArgument(
"The length of Input(Index) can't be 0."));
auto output_dim = phi::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = index_dim[0];
ctx->SetOutputDim("Out", phi::make_ddim(output_dim));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -148,20 +107,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(index_select, IndexSelectInferShapeFunctor,
PD_INFER_META(phi::IndexSelectInferMeta));
REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>,
IndexSelectInferShapeFunctor);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
index_select_grad,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input, T* output,
const IndexT* index, int64_t N,
int64_t stride, int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index, int64_t nums,
int64_t N, int64_t stride,
int64_t size, int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename DeviceContext, typename T>
class IndexSelectCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* index = context.Input<LoDTensor>("Index");
auto* out = context.Output<LoDTensor>("Out");
int dim = context.Attr<int>("dim");
auto input_dim = in->dims();
auto output_dim = out->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data,
numel, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
const int* index_data = index->data<int>();
index_select_cuda_kernel<T, int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
platform::GpuStreamSync(stream);
}
}
};
template <typename DeviceContext, typename T>
class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_grad = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<LoDTensor>("Index");
auto* output_grad_data = output_grad->data<T>();
auto* in_grad_data = in_grad->mutable_data<T>(context.GetPlace());
int dim = context.Attr<int>("dim");
auto input_dim = in_grad->dims();
auto output_dim = output_grad->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
int64_t numel = in_grad->numel();
int64_t index_nums = index->numel();
int64_t out_nums = output_grad->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
const int* index_data = index->data<int>();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
@@ -91,41 +91,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,
output->Resize(output_dim);
}
template <typename DeviceContext, typename T>
class IndexSelectKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto inputs = *context.Input<framework::LoDTensor>("X");
auto* index = context.Input<framework::LoDTensor>("Index");
auto* output = context.Output<framework::LoDTensor>("Out");
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectInner<DeviceContext, T, int>(context, &inputs, *index, output,
dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectInner<DeviceContext, T, int64_t>(context, &inputs, *index,
output, dim);
}
}
};
template <typename DeviceContext, typename T, class Enable = void>
struct IndexSelectAdd {
void operator()(const framework::ExecutionContext& ctx, int slice_size,
@@ -197,43 +162,5 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
x_grad->Resize(output_dim);
}
template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<framework::LoDTensor>("Index");
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += out_grad->dims().size();
}
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, *index,
x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<DeviceContext, T, int64_t>(context, *out_grad,
*index, x_grad, dim);
}
}
};
} // namespace operators
} // namespace paddle
@@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class IndexSelectNPUKernel : public framework::OpKernel<T> {
public:
...
@@ -643,6 +643,49 @@ void IndexSampleInferMeta(const MetaTensor& x,
out->share_lod(y);
}
void IndexSelectInferMeta(const MetaTensor& x,
const MetaTensor& index,
int dim,
MetaTensor* output) {
auto input_dim = x.dims();
auto index_dim = index.dims();
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()),
true,
phi::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(),
input_dim.size() - 1,
dim));
PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true,
phi::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim,
index_dim.size()));
PADDLE_ENFORCE_EQ(
index_dim[0] != 0,
true,
phi::errors::InvalidArgument("The length of Input(Index) can't be 0."));
auto output_dim = phi::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = index_dim[0];
output->set_dims(phi::make_ddim(output_dim));
output->set_dtype(x.dtype());
output->set_layout(x.layout());
output->share_lod(x);
}
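A quick sanity check of the shape rule above, with hypothetical shapes (not taken from the tests): a negative dim is wrapped, and the output keeps x's shape except along dim, where it takes the index length.

import numpy as np
# Hypothetical shapes illustrating IndexSelectInferMeta.
x_shape = [5, 6]
index_shape = [3]
dim = -1
if dim < 0:
    dim += len(x_shape)          # dim becomes 1, same wrap as the C++ code above
out_shape = list(x_shape)
out_shape[dim] = index_shape[0]  # output along dim is the index length
print(out_shape)                 # [5, 3]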
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
...
@@ -113,6 +113,11 @@ void IndexSampleInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void IndexSelectInferMeta(const MetaTensor& x,
const MetaTensor& index,
int dim,
MetaTensor* output);
void LogLossInferMeta(const MetaTensor& input,
const MetaTensor& label,
float epsilon,
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad) {
if (dim < 0) {
dim += out_grad.dims().size();
}
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
if (index_type == phi::DataType::INT32) {
IndexSelectGradInner<Context, T, int>(ctx, out_grad, index, x_grad, dim);
} else if (index_type == phi::DataType::INT64) {
IndexSelectGradInner<Context, T, int64_t>(
ctx, out_grad, index, x_grad, dim);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select_grad,
CPU,
ALL_LAYOUT,
phi::IndexSelectGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename Context, typename T, class Enable = void>
struct IndexSelectAdd {
void operator()(const Context& ctx,
int slice_size,
const T* src_pointer,
const T* p_pointer,
T* dist_pointer) {
for (int i = 0; i < slice_size; i++) {
dist_pointer[i] = src_pointer[i] + p_pointer[i];
}
}
};
template <typename Context, typename T>
struct IndexSelectAdd<
Context,
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const Context& ctx,
int slice_size,
const T* src_pointer,
const T* p_pointer,
T* dist_pointer) {
auto blas = phi::funcs::GetBlas<Context, T>(ctx);
blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
}
};
template <typename Context, typename T, typename IndexT = int>
void IndexSelectInner(const Context& ctx,
DenseTensor* input,
const DenseTensor& index,
DenseTensor* output,
int dim) {
auto input_dim = input->dims();
auto input_dim_size = input_dim.size();
auto output_dim = output->dims();
auto index_size = index.dims()[0];
DenseTensor index_cpu_copy;
if (!paddle::platform::is_cpu_place(index.place())) {
phi::Copy(ctx, index, phi::CPUPlace(), true, &index_cpu_copy);
}
const IndexT* index_data = paddle::platform::is_cpu_place(index.place())
? index.data<IndexT>()
: index_cpu_copy.data<IndexT>();
ctx.template Alloc<T>(output);
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim],
index_data[i]));
PADDLE_ENFORCE_LT(
index_data[i],
input_dim[dim],
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
input_dim[dim],
index_data[i]));
}
VLOG(3) << "Index_Select_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; index_size: " << index_size;
input->Resize(phi::make_ddim({outer_nums, input_dim[dim], slice_size}));
output->Resize(phi::make_ddim({outer_nums, index_size, slice_size}));
auto input_tensor = EigenTensor<T, 3>::From(*input);
auto output_tensor = EigenTensor<T, 3>::From(*output);
auto& place = *ctx.eigen_device();
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_data[j];
auto output_t = output_tensor.chip(j, 1);
output_t.device(place) = input_tensor.chip(index_value, 1);
}
input->Resize(input_dim);
output->Resize(output_dim);
}
template <typename Context, typename T, typename IndexT = int>
void IndexSelectGradInner(const Context& ctx,
const DenseTensor& out_grad,
const DenseTensor& index,
DenseTensor* x_grad,
int dim) {
const T* input_data = out_grad.data<T>();
const IndexT* index_data = index.data<IndexT>();
const T* p_output = ctx.template Alloc<T>(x_grad);
T* out_data = ctx.template Alloc<T>(x_grad);
auto input_dim = out_grad.dims();
auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims();
phi::funcs::SetConstant<Context, T> set_constant;
set_constant(ctx, x_grad, static_cast<T>(0.0));
auto slice_size = 1;
for (auto i = dim + 1; i < input_dim_size; i++) {
slice_size *= input_dim[i];
}
auto input_width = slice_size * input_dim[dim];
auto output_width = slice_size * output_dim[dim];
auto outer_nums = 1;
for (auto i = 0; i < dim; i++) {
outer_nums *= input_dim[i];
}
auto index_size = index.dims()[0];
VLOG(3) << "Index_Select_Grad_Debug; outer_nums: " << outer_nums
<< "; slice_size: " << slice_size << "; input_width: " << input_width
<< "; output_width: " << output_width
<< "; index_size: " << index_size;
for (auto i = 0; i < outer_nums; i++) {
auto input_start_offset = i * input_width;
auto output_start_offset = i * output_width;
for (auto j = 0; j < index_size; j++) {
IndexT index_value = index_data[j];
auto src = input_data + input_start_offset + j * slice_size;
auto p_out = p_output + output_start_offset + index_value * slice_size;
auto dst = out_data + output_start_offset + index_value * slice_size;
IndexSelectAdd<Context, T> index_select_add;
index_select_add(ctx, slice_size, src, p_out, dst);
}
}
x_grad->Resize(output_dim);
}
} // namespace phi
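The scatter-add in IndexSelectGradInner above is, mathematically, an index_add: each slice of out_grad is accumulated into x_grad at the position given by the index, so repeated indices sum their contributions. A small NumPy equivalence sketch (hypothetical values, dim = 0 for brevity; this is not the implementation itself):

import numpy as np

index = np.array([1, 3, 1], dtype=np.int64)
out_grad = np.ones((3, 4), dtype=np.float32)
x_grad = np.zeros((5, 4), dtype=np.float32)

# Accumulate each out_grad slice into the row selected by index;
# row 1 is hit twice, so it receives two contributions.
np.add.at(x_grad, index, out_grad)
print(x_grad[1])  # [2. 2. 2. 2.]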
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/cpu/index_select_impl.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output) {
auto inputs = x;
if (dim < 0) {
dim += inputs.dims().size();
}
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
if (index_type == phi::DataType::INT32) {
IndexSelectInner<Context, T, int>(ctx, &inputs, index, output, dim);
} else if (index_type == phi::DataType::INT64) {
IndexSelectInner<Context, T, int64_t>(ctx, &inputs, index, output, dim);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select,
CPU,
ALL_LAYOUT,
phi::IndexSelectKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_grad_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index,
int64_t nums,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad) {
auto* output_grad_data = out_grad.data<T>();
auto* in_grad_data = ctx.template Alloc<T>(x_grad);
auto input_dim = x_grad->dims();
auto output_dim = out_grad.dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
int64_t numel = x_grad->numel();
int64_t index_nums = index.numel();
int64_t out_nums = out_grad.numel();
auto stream = ctx.stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_grad_data, numel);
if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
phi::backends::gpu::GpuStreamSync(stream);
} else {
const int* index_data = index.data<int>();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(output_grad_data,
in_grad_data,
index_data,
index_nums,
out_nums,
stride,
size,
delta);
phi::backends::gpu::GpuStreamSync(stream);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select_grad,
GPU,
ALL_LAYOUT,
phi::IndexSelectGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/index_select_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input,
T* output,
const IndexT* index,
int64_t N,
int64_t stride,
int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
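The flat-index arithmetic in the kernel above can be checked on the host: stride is the number of elements in one step along dim (product of the trailing dims), size is the output extent along dim, and delta is how much wider the input is along dim. A NumPy sketch with hypothetical shapes, verified against np.take:

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 4, 3)
index = np.array([3, 1], dtype=np.int64)
dim = 1
stride = 3                   # elements per step along dim
size = len(index)            # output extent along dim
delta = x.shape[dim] - size  # input extent minus output extent along dim

ref = np.take(x, index, axis=dim)
flat_x = x.ravel()
out = np.empty(ref.size, dtype=x.dtype)
for idx in range(ref.size):
    pre_idx = idx // (stride * size)
    dim_idx = idx % (stride * size) // stride
    src_dim_idx = index[dim_idx]
    input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride
    out[idx] = flat_x[input_idx]
print(np.array_equal(out, ref.ravel()))  # True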
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output) {
auto input_dim = x.dims();
auto output_dim = output->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = index.dtype();
bool index_type_match =
index_type == phi::DataType::INT64 || index_type == phi::DataType::INT32;
PADDLE_ENFORCE_EQ(index_type_match,
true,
phi::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
index_type,
phi::DataType::INT32,
phi::DataType::INT64));
auto* in_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(output);
int64_t numel = output->numel();
auto stream = ctx.stream();
if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
index_select_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(in_data, out_data, index_data, numel, stride, size, delta);
phi::backends::gpu::GpuStreamSync(stream);
} else {
const int* index_data = index.data<int>();
index_select_cuda_kernel<
T,
int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS,
0,
stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
phi::backends::gpu::GpuStreamSync(stream);
}
}
} // namespace phi
PD_REGISTER_KERNEL(index_select,
GPU,
ALL_LAYOUT,
phi::IndexSelectKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
int dim,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void IndexSelectKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& index,
int dim,
DenseTensor* output);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature IndexSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("index_select_grad",
{"X", "Index", GradVarName("Out")},
{"dim"},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(index_select_grad,
phi::IndexSelectGradOpArgumentMapping);
@@ -333,7 +333,8 @@ class TestVariable(unittest.TestCase):
        with self.assertRaises(IndexError):
            res = x[[True, False, False]]
        with self.assertRaises(ValueError):
            with paddle.static.program_guard(prog):
                res = x[[False, False]]

    def test_slice(self):
        places = [fluid.CPUPlace()]
...