未验证 提交 857069f3 编写于 作者: Jeffrey Chen's avatar Jeffrey Chen 提交者: GitHub

[PHI] Migrate where_index op (#40255)

* [PHI] Migrate where_index op

* [PHI] Fix where_index infermate

* [Phi] set where_index out data type
上级 2747de2b
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/where_index_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -21,16 +24,6 @@ class WhereIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Condition"), "Input", "Condition", "where");
PADDLE_ENFORCE_GE(
ctx->GetInputDim("Condition").size(), 1UL,
platform::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "where");
ctx->SetOutputDim("Out", {-1, ctx->GetInputDim("Condition").size()});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -53,11 +46,10 @@ class WhereIndexOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
ops::WhereIndexOpMaker);
REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>,
ops::CPUWhereIndexKernel<int>,
ops::CPUWhereIndexKernel<int16_t>,
ops::CPUWhereIndexKernel<bool>,
ops::CPUWhereIndexKernel<float>,
ops::CPUWhereIndexKernel<double>);
DECLARE_INFER_SHAPE_FUNCTOR(where_index, WhereIndexInferShapeFunctor,
PD_INFER_META(phi::WhereIndexInferMeta));
REGISTER_OPERATOR(
where_index, ops::WhereIndexOp, ops::WhereIndexOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
WhereIndexInferShapeFunctor);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/where_index_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template <typename T>
__global__ void GetTrueNum(const T *cond_data, const int64_t numel,
int64_t *true_num_array) {
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
true_num_array[idx] =
static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
}
}
template <typename T>
__global__ void SetTrueIndex(int64_t *out_ptr, const T *cond_data,
const int64_t numel, const int64_t *stride_array,
const int64_t rank,
const int64_t *true_num_array) {
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
// true_num_array is calculated by cub::InclusiveSum,
// cause the first element of true_num_array is 1,
// so we need substract 1 to get true index.
const int64_t true_index = true_num_array[idx] - 1;
if (static_cast<bool>(cond_data[idx])) {
int64_t rank_index = idx;
for (int j = 0; j < rank; j++) {
const int64_t out_index = rank_index / stride_array[j];
out_ptr[true_index * rank + j] = out_index;
rank_index -= out_index * stride_array[j];
}
}
}
}
template <typename T>
class CUDAWhereIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *condition = context.Input<framework::Tensor>("Condition");
auto *out = context.Output<framework::Tensor>("Out");
auto &dev_ctx = context.template device_context<CUDADeviceContext>();
const T *cond_data = condition->data<T>();
const int64_t numel = condition->numel();
auto dims = condition->dims();
const int rank = dims.size();
auto d_array_mem = memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
auto h_array_mem =
memory::Alloc(platform::CPUPlace(), (rank + 1) * sizeof(int64_t));
// "stride_array" is an array and len(stride_array)==rank,
// each element is the stride of each dimension -- the length from i to i+1.
int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
// "true_num_array" is an array and len(stride_array)==numel,
// at the beginning,
// "true_num_array" will set 1 if condition[i] == true else 0,
// then it will be calculated by cub::InclusiveSum,
// so that we can get the true number before i as the out index
int64_t *d_true_num_array = d_stride_array + rank;
// the total_true_num is the total number of condition[i] == true
int64_t *h_total_true_num = h_stride_array + rank;
// alloce cub memory
size_t cub_size = 0;
cub::DeviceScan::InclusiveSum(nullptr, cub_size, d_true_num_array,
d_true_num_array, numel, dev_ctx.stream());
auto cub_mem = memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
void *cub_data = cub_mem->ptr();
// set d_true_num_array[i]=1 if cond_data[i]==true else 0
const int threads = std::min(numel, static_cast<int64_t>(128));
const int64_t need_grids = (numel + threads - 1) / threads;
const int grids = std::min(need_grids, static_cast<int64_t>(256));
GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(cond_data, numel,
d_true_num_array);
// calculate the inclusive prefix sum of "true_num_array"
// to get the index of "out" tensor,
// and the total number of cond_data[i]==true.
// Example:
// condition: F T T F F F T T
// before: 0 1 1 0 0 0 1 1
// after: 0 1 2 2 2 2 3 4
// out: 1 2 6 7
cub::DeviceScan::InclusiveSum(cub_data, cub_size, d_true_num_array,
d_true_num_array, numel, dev_ctx.stream());
// calculate each dimension's stride
h_stride_array[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
}
memory::Copy(dev_ctx.GetPlace(), d_stride_array, platform::CPUPlace(),
h_stride_array, rank * sizeof(int64_t), dev_ctx.stream());
// get total ture number and set output size
// the last element of cub::InclusiveSum is the total number
memory::Copy(platform::CPUPlace(), h_total_true_num, dev_ctx.GetPlace(),
d_true_num_array + numel - 1, sizeof(int64_t),
dev_ctx.stream());
dev_ctx.Wait();
int64_t true_num = *h_total_true_num;
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
auto out_data = out->mutable_data<int64_t>(context.GetPlace());
if (true_num == 0) {
return;
}
// using true_num_array and stride_array to calculate the output index
SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>,
ops::CUDAWhereIndexKernel<int>,
ops::CUDAWhereIndexKernel<int16_t>,
ops::CUDAWhereIndexKernel<bool>,
ops::CUDAWhereIndexKernel<float>,
ops::CUDAWhereIndexKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
template <typename T>
struct WhereIndexFunctor {
WhereIndexFunctor(const T* true_index, int true_num, const T* stride,
int rank, T* out)
: true_index_(true_index),
true_num_(true_num),
stride_(stride),
rank_(rank),
out_ptr_(out) {}
HOSTDEVICE void operator()(size_t idx) const {
T index = true_index_[idx];
for (int j = 0; j < rank_; j++) {
out_ptr_[idx * rank_ + j] = index / stride_[j];
index -= out_ptr_[idx * rank_ + j] * stride_[j];
}
}
const T* true_index_;
int true_num_;
const T* stride_;
int rank_;
T* out_ptr_;
};
using CPUDeviceContext = paddle::platform::CPUDeviceContext;
template <typename T>
class CPUWhereIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
auto* out = context.Output<framework::Tensor>("Out");
const T* cond_data = condition->data<T>();
auto numel = condition->numel();
auto dims = condition->dims();
const int rank = dims.size();
std::vector<int64_t> true_index;
for (auto i = 0; i < numel; i++) {
if (static_cast<bool>(cond_data[i])) {
true_index.push_back(i);
}
}
auto true_num = true_index.size();
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
if (true_num == 0) {
return;
}
std::vector<int64_t> stride(rank);
stride[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * dims[i + 1];
}
auto& dev_ctx = context.template device_context<CPUDeviceContext>();
WhereIndexFunctor<int64_t> functor(true_index.data(), true_num,
stride.data(), rank, out_ptr);
platform::ForRange<CPUDeviceContext> for_range(dev_ctx, true_num);
for_range(functor);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,8 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/where_index_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/where_index_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class WhereIndexXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* condition = context.Input<framework::Tensor>("Condition");
auto* out = context.Output<framework::Tensor>("Out");
const T* cond_data = condition->data<T>();
auto numel = condition->numel();
auto dims = condition->dims();
const int rank = dims.size();
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* true_num = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
int true_num_cpu;
int ret =
xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
PADDLE_ENFORCE_EQ(
ret, XPU_SUCCESS,
platform::errors::External(
"XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
ret, XPUAPIErrorMsg[ret]));
memory::Copy(platform::CPUPlace(), static_cast<void*>(&true_num_cpu),
context.GetPlace(), static_cast<void*>(true_num),
sizeof(int32_t));
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num_cpu), rank}));
auto out_data = out->mutable_data<int64_t>(context.GetPlace());
if (true_num_cpu == 0) {
return;
}
auto condition_shape = phi::vectorize<int>(dims);
ret = xpu::where(dev_ctx.x_context(), cond_data, out_data, condition_shape,
true_num_cpu);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
platform::errors::External(
"XPU masked_select kernel return wrong value[%d %s]",
ret, XPUAPIErrorMsg[ret]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(where_index, ops::WhereIndexXPUKernel<int>,
ops::WhereIndexXPUKernel<bool>,
ops::WhereIndexXPUKernel<float>);
#endif
......@@ -1260,6 +1260,17 @@ void EighInferMeta(const MetaTensor& x,
out_v->set_dims(input_dim);
}
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
auto rank = condition.dims().size();
PADDLE_ENFORCE_GE(
rank,
1UL,
phi::errors::InvalidArgument(
"Input(Condition) should have number of dimension at least 1"));
out->set_dims(phi::make_ddim({-1, rank}));
out->set_dtype(DataType::INT64);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
......@@ -182,4 +182,6 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor* out_w,
MetaTensor* out_v);
void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_index_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
struct WhereIndexFunctor {
WhereIndexFunctor(
const T* true_index, int true_num, const T* stride, int rank, T* out)
: true_index_(true_index),
true_num_(true_num),
stride_(stride),
rank_(rank),
out_ptr_(out) {}
HOSTDEVICE void operator()(size_t idx) const {
T index = true_index_[idx];
for (int j = 0; j < rank_; j++) {
out_ptr_[idx * rank_ + j] = index / stride_[j];
index -= out_ptr_[idx * rank_ + j] * stride_[j];
}
}
const T* true_index_;
int true_num_;
const T* stride_;
int rank_;
T* out_ptr_;
};
template <typename T, typename Context>
void WhereIndexKernel(const Context& dev_ctx,
const DenseTensor& condition,
DenseTensor* out) {
const T* cond_data = condition.data<T>();
auto numel = condition.numel();
auto dims = condition.dims();
const int rank = dims.size();
std::vector<int64_t> true_index;
for (auto i = 0; i < numel; i++) {
if (static_cast<bool>(cond_data[i])) {
true_index.push_back(i);
}
}
auto true_num = true_index.size();
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
auto* out_ptr = dev_ctx.template Alloc<int64_t>(out);
if (true_num == 0) {
return;
}
std::vector<int64_t> stride(rank);
stride[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * dims[i + 1];
}
WhereIndexFunctor<int64_t> functor(
true_index.data(), true_num, stride.data(), rank, out_ptr);
phi::funcs::ForRange<phi::CPUContext> for_range(dev_ctx, true_num);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(where_index,
CPU,
ALL_LAYOUT,
phi::WhereIndexKernel,
int64_t,
int,
int16_t,
bool,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/kernels/where_index_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
template <typename T>
__global__ void GetTrueNum(const T *cond_data,
const int64_t numel,
int64_t *true_num_array) {
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
true_num_array[idx] =
static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
}
}
template <typename T>
__global__ void SetTrueIndex(int64_t *out_ptr,
const T *cond_data,
const int64_t numel,
const int64_t *stride_array,
const int64_t rank,
const int64_t *true_num_array) {
const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
// true_num_array is calculated by cub::InclusiveSum,
// cause the first element of true_num_array is 1,
// so we need substract 1 to get true index.
const int64_t true_index = true_num_array[idx] - 1;
if (static_cast<bool>(cond_data[idx])) {
int64_t rank_index = idx;
for (int j = 0; j < rank; j++) {
const int64_t out_index = rank_index / stride_array[j];
out_ptr[true_index * rank + j] = out_index;
rank_index -= out_index * stride_array[j];
}
}
}
}
template <typename T, typename Context>
void WhereIndexKernel(const Context &dev_ctx,
const DenseTensor &condition,
DenseTensor *out) {
const T *cond_data = condition.data<T>();
const int64_t numel = condition.numel();
auto dims = condition.dims();
const int rank = dims.size();
auto d_array_mem =
paddle::memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
auto h_array_mem =
paddle::memory::Alloc(phi::CPUPlace(), (rank + 1) * sizeof(int64_t));
// "stride_array" is an array and len(stride_array)==rank,
// each element is the stride of each dimension -- the length from i to i+1.
int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
// "true_num_array" is an array and len(stride_array)==numel,
// at the beginning,
// "true_num_array" will set 1 if condition[i] == true else 0,
// then it will be calculated by cub::InclusiveSum,
// so that we can get the true number before i as the out index
int64_t *d_true_num_array = d_stride_array + rank;
// the total_true_num is the total number of condition[i] == true
int64_t *h_total_true_num = h_stride_array + rank;
// alloce cub memory
size_t cub_size = 0;
cub::DeviceScan::InclusiveSum(nullptr,
cub_size,
d_true_num_array,
d_true_num_array,
numel,
dev_ctx.stream());
auto cub_mem = paddle::memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
void *cub_data = cub_mem->ptr();
// set d_true_num_array[i]=1 if cond_data[i]==true else 0
const int threads = std::min(numel, static_cast<int64_t>(128));
const int64_t need_grids = (numel + threads - 1) / threads;
const int grids = std::min(need_grids, static_cast<int64_t>(256));
GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(
cond_data, numel, d_true_num_array);
// calculate the inclusive prefix sum of "true_num_array"
// to get the index of "out" tensor,
// and the total number of cond_data[i]==true.
// Example:
// condition: F T T F F F T T
// before: 0 1 1 0 0 0 1 1
// after: 0 1 2 2 2 2 3 4
// out: 1 2 6 7
cub::DeviceScan::InclusiveSum(cub_data,
cub_size,
d_true_num_array,
d_true_num_array,
numel,
dev_ctx.stream());
// calculate each dimension's stride
h_stride_array[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--) {
h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
}
paddle::memory::Copy(dev_ctx.GetPlace(),
d_stride_array,
phi::CPUPlace(),
h_stride_array,
rank * sizeof(int64_t),
dev_ctx.stream());
// get total ture number and set output size
// the last element of cub::InclusiveSum is the total number
paddle::memory::Copy(phi::CPUPlace(),
h_total_true_num,
dev_ctx.GetPlace(),
d_true_num_array + numel - 1,
sizeof(int64_t),
dev_ctx.stream());
dev_ctx.Wait();
int64_t true_num = *h_total_true_num;
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num), rank}));
auto *out_data = dev_ctx.template Alloc<int64_t>(out);
if (true_num == 0) {
return;
}
// using true_num_array and stride_array to calculate the output index
SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
}
} // namespace phi
PD_REGISTER_KERNEL(where_index,
GPU,
ALL_LAYOUT,
phi::WhereIndexKernel,
int64_t,
int,
int16_t,
bool,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void WhereIndexKernel(const Context& dev_ctx,
const DenseTensor& condition,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/where_index_kernel.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void WhereIndexKernel(const Context& dev_ctx,
const DenseTensor& condition,
DenseTensor* out) {
const T* cond_data = condition.data<T>();
auto numel = condition.numel();
auto dims = condition.dims();
const int rank = dims.size();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* true_num = RAII_GUARD.alloc_l3_or_gm<int32_t>(1);
int true_num_cpu;
int ret = xpu::nonzero_count(dev_ctx.x_context(), cond_data, true_num, numel);
PADDLE_ENFORCE_EQ(
ret,
XPU_SUCCESS,
phi::errors::External(
"XPU nonzero_count kernel return wrong value[%d %s] in WhereIndex",
ret,
XPUAPIErrorMsg[ret]));
paddle::memory::Copy(phi::CPUPlace(),
static_cast<void*>(&true_num_cpu),
dev_ctx.GetPlace(),
static_cast<void*>(true_num),
sizeof(int32_t));
out->Resize(phi::make_ddim({static_cast<int64_t>(true_num_cpu), rank}));
auto* out_data = dev_ctx.template Alloc<int64_t>(out);
if (true_num_cpu == 0) {
return;
}
auto condition_shape = phi::vectorize<int>(dims);
ret = xpu::where(
dev_ctx.x_context(), cond_data, out_data, condition_shape, true_num_cpu);
PADDLE_ENFORCE_EQ(ret,
XPU_SUCCESS,
phi::errors::External(
"XPU masked_select kernel return wrong value[%d %s]",
ret,
XPUAPIErrorMsg[ret]));
}
} // namespace phi
PD_REGISTER_KERNEL(
where_index, XPU, ALL_LAYOUT, phi::WhereIndexKernel, int, bool, float) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册