diff --git a/paddle/fluid/operators/masked_select_op.cc b/paddle/fluid/operators/masked_select_op.cc index 17bf5df18adc543ea487160a31d05d3c802b95a7..7a8b5bf3a194ac159cd6a3089568951da7437a09 100644 --- a/paddle/fluid/operators/masked_select_op.cc +++ b/paddle/fluid/operators/masked_select_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/masked_select_op.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -106,16 +105,3 @@ REGISTER_OPERATOR(masked_select, ops::MaskedSelectOp, ops::MaskedSelectOpMaker, ops::MaskedSelectGradOpMaker); REGISTER_OPERATOR(masked_select_grad, ops::MaskedSelectOpGrad, ops::MaskedSelectedGradNoNeedBufferVarsInferer); - -REGISTER_OP_CPU_KERNEL( - masked_select, - ops::MaskedSelectKernel, - ops::MaskedSelectKernel, - ops::MaskedSelectKernel, - ops::MaskedSelectKernel); -REGISTER_OP_CPU_KERNEL( - masked_select_grad, - ops::MaskedSelectGradKernel, - ops::MaskedSelectGradKernel, - ops::MaskedSelectGradKernel, - ops::MaskedSelectGradKernel); diff --git a/paddle/fluid/operators/masked_select_op.cu b/paddle/fluid/operators/masked_select_op.cu deleted file mode 100644 index 7dc0516800c483d1d82a2390a64130e77b1efb01..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/masked_select_op.cu +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include -#include -#include -#include -#include "paddle/fluid/operators/masked_select_op.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DDim = framework::DDim; - -__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - if (mask[idx]) - mask_array[idx] = 1; - else - mask_array[idx] = 0; - } -} - -template -__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum, - const bool* mask, const T* input, T* out, - int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - if (mask[idx]) { - int index = mask_prefix_sum[idx]; - out[index] = input[idx]; - } - } -} - -template -__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum, - const bool* mask, const T* input, - T* out, int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - if (mask[idx]) { - int index = mask_prefix_sum[idx]; - out[idx] = input[index]; - } else { - out[idx] = 0; - } - } -} - -template -class MaskedSelectCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto input = ctx.Input("X"); - auto mask = ctx.Input("Mask"); - auto out = ctx.Output("Y"); - auto* mask_data = mask->data(); - auto input_data = input->data(); - - auto mask_size = mask->numel(); - auto input_dim = input->dims(); - auto mask_dim = mask->dims(); - PADDLE_ENFORCE_EQ( - input_dim, mask_dim, - platform::errors::InvalidArgument( - "The dim size of input and mask in OP(masked_selected) " - "must be equal, but got input dim:(%ld), mask dim: " - "(%ld). Please check input " - "value.", - input_dim, mask_dim)); - - thrust::device_ptr mask_dev_ptr = - thrust::device_pointer_cast(mask_data); - thrust::device_vector mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size); - auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true); - - framework::DDim out_dim{out_size}; - out->Resize(out_dim); - auto out_data = out->mutable_data(ctx.GetPlace()); - - Tensor mask_array; - Tensor mask_prefix_sum; - mask_array.Resize(mask_dim); - mask_prefix_sum.Resize(mask_dim); - - int32_t* mask_array_data = mask_array.mutable_data(ctx.GetPlace()); - int32_t* mask_prefix_sum_data = - mask_prefix_sum.mutable_data(ctx.GetPlace()); - int threads = 512; - int grid = (mask_size + threads - 1) / threads; - auto stream = ctx.cuda_device_context().stream(); - SetMaskArray<<>>(mask_data, mask_array_data, - mask_size); - - thrust::device_ptr mask_array_dev_ptr = - thrust::device_pointer_cast(mask_array_data); - thrust::device_vector mask_array_vec( - mask_array_dev_ptr, mask_array_dev_ptr + mask_size); - thrust::exclusive_scan(thrust::device, mask_array_vec.begin(), - mask_array_vec.end(), mask_prefix_sum_data); - - SelectWithPrefixMask<<>>( - mask_prefix_sum_data, mask_data, input_data, out_data, mask_size); - } -}; - -template -class MaskedSelectGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - auto input = ctx.Input(framework::GradVarName("Y")); - auto mask = ctx.Input("Mask"); - auto out = ctx.Output(framework::GradVarName("X")); - auto* mask_data = mask->data(); - auto* input_data = input->data(); - auto* out_data = out->mutable_data(ctx.GetPlace()); - - auto input_size = input->numel(); - auto mask_size = mask->numel(); - auto mask_dim = mask->dims(); - - auto out_size = mask_size; - - Tensor mask_array; - Tensor mask_prefix_sum; - mask_array.Resize(mask_dim); - mask_prefix_sum.Resize(mask_dim); - - int32_t* mask_array_data = mask_array.mutable_data(ctx.GetPlace()); - int32_t* mask_prefix_sum_data = - mask_prefix_sum.mutable_data(ctx.GetPlace()); - int threads = 512; - int grid = (mask_size + threads - 1) / threads; - auto stream = ctx.cuda_device_context().stream(); - SetMaskArray<<>>(mask_data, mask_array_data, - mask_size); - - thrust::device_ptr mask_array_dev_ptr = - thrust::device_pointer_cast(mask_array_data); - thrust::device_vector mask_array_vec( - mask_array_dev_ptr, mask_array_dev_ptr + mask_size); - thrust::exclusive_scan(thrust::device, mask_array_vec.begin(), - mask_array_vec.end(), mask_prefix_sum_data); - - SelectGradWithPrefixMask<<>>( - mask_prefix_sum_data, mask_data, input_data, out_data, mask_size); - } -}; -} // namespace operators -} // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - masked_select, - ops::MaskedSelectCUDAKernel, - ops::MaskedSelectCUDAKernel, - ops::MaskedSelectCUDAKernel, - ops::MaskedSelectCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - masked_select_grad, - ops::MaskedSelectGradCUDAKernel, - ops::MaskedSelectGradCUDAKernel, - ops::MaskedSelectGradCUDAKernel, - ops::MaskedSelectGradCUDAKernel); diff --git a/paddle/fluid/operators/masked_select_op.h b/paddle/fluid/operators/masked_select_op.h deleted file mode 100644 index ce8371556c82fe105b6719e845d4fd6232f3a95e..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/masked_select_op.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using DDim = framework::DDim; - -template -class MaskedSelectKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto input = context.Input("X"); - auto mask = context.Input("Mask"); - auto out = context.Output("Y"); - auto* mask_data = mask->data(); - auto input_data = input->data(); - - auto mask_size = mask->numel(); - - auto input_dim = input->dims(); - auto mask_dim = mask->dims(); - PADDLE_ENFORCE_EQ( - input_dim, mask_dim, - platform::errors::InvalidArgument( - "The dim size of input and mask in OP(masked_selected) " - "must be equal, but got input dim:(%ld), mask dim: " - "(%ld). Please check input " - "value.", - input_dim, mask_dim)); - - int out_size = 0; - for (int i = 0; i < mask_size; i++) { - if (mask_data[i]) out_size++; - } - - framework::DDim out_dim{out_size}; - out->Resize(out_dim); - auto out_data = out->mutable_data(context.GetPlace()); - - int index = 0; - for (int i = 0; i < mask_size; i++) { - if (mask_data[i]) { - out_data[index] = input_data[i]; - index++; - } - } - } -}; - -template -class MaskedSelectGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto out = context.Output(framework::GradVarName("X")); - auto mask = context.Input("Mask"); - auto input = context.Input(framework::GradVarName("Y")); - - auto* mask_data = mask->data(); - auto* input_data = input->data(); - auto* out_data = out->mutable_data(context.GetPlace()); - int mask_size = mask->numel(); - - int index = 0; - for (int i = 0; i < mask_size; i++) { - if (mask_data[i]) { - out_data[i] = input_data[index]; - index++; - } else { - out_data[i] = 0; - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/masked_select_op_npu.cc b/paddle/fluid/operators/masked_select_op_npu.cc index 828a3b002c20d1f7e4d9ec0c43f93a403ca262db..5b2f93c9752290f1a0f6d610170907e63b726c91 100644 --- a/paddle/fluid/operators/masked_select_op_npu.cc +++ b/paddle/fluid/operators/masked_select_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/masked_select_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/masked_select_op_xpu.cc b/paddle/fluid/operators/masked_select_op_xpu.cc index 8dbc5bcfc347abd1401d577b66e7a2da9600a30b..f0b61d91f78404dd0e082ca52a70b613e62d2628 100644 --- a/paddle/fluid/operators/masked_select_op_xpu.cc +++ b/paddle/fluid/operators/masked_select_op_xpu.cc @@ -11,7 +11,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU -#include "paddle/fluid/operators/masked_select_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { diff --git a/paddle/pten/kernels/cpu/masked_select_grad_kernel.cc b/paddle/pten/kernels/cpu/masked_select_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..06bd17cac9bb314682f7171d2a1a595023ad7b02 --- /dev/null +++ b/paddle/pten/kernels/cpu/masked_select_grad_kernel.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/masked_select_grad_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { + +template +void MaskedSelectGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* x_grad) { + auto* mask_data = mask.data(); + auto* input_data = out_grad.data(); + auto* out_data = x_grad->mutable_data(dev_ctx.GetPlace()); + int mask_size = mask.numel(); + + int index = 0; + for (int i = 0; i < mask_size; i++) { + if (mask_data[i]) { + out_data[i] = input_data[index]; + index++; + } else { + out_data[i] = 0; + } + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(masked_select_grad, + CPU, + ALL_LAYOUT, + pten::MaskedSelectGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/cpu/masked_select_kernel.cc b/paddle/pten/kernels/cpu/masked_select_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a604ea36fbbe7b1dd790584b165eeaa3ed19b345 --- /dev/null +++ b/paddle/pten/kernels/cpu/masked_select_kernel.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/masked_select_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { + +template +void MaskedSelectKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* out) { + auto* mask_data = mask.data(); + auto input_data = x.data(); + + auto mask_size = mask.numel(); + + auto input_dim = x.dims(); + auto mask_dim = mask.dims(); + PADDLE_ENFORCE_EQ(input_dim, + mask_dim, + pten::errors::InvalidArgument( + "The dim size of input and mask in OP(masked_selected) " + "must be equal, but got input dim:(%ld), mask dim: " + "(%ld). Please check input " + "value.", + input_dim, + mask_dim)); + + int out_size = 0; + for (int i = 0; i < mask_size; i++) { + if (mask_data[i]) out_size++; + } + + framework::DDim out_dim{out_size}; + out->Resize(out_dim); + auto out_data = out->mutable_data(paddle::platform::CPUPlace()); + + int index = 0; + for (int i = 0; i < mask_size; i++) { + if (mask_data[i]) { + out_data[index] = input_data[i]; + index++; + } + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(masked_select, + CPU, + ALL_LAYOUT, + pten::MaskedSelectKernel, + float, + double, + int, + int64_t) { + kernel->InputAt(1).SetDataType(pten::DataType::BOOL); +} diff --git a/paddle/pten/kernels/gpu/masked_select_grad_kernel.cu b/paddle/pten/kernels/gpu/masked_select_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..aa638653a8e1f52b44114349e1aa02d6d0753323 --- /dev/null +++ b/paddle/pten/kernels/gpu/masked_select_grad_kernel.cu @@ -0,0 +1,106 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/masked_select_grad_kernel.h" + +namespace pten { + +__global__ void SetMaskArrayT(const bool* mask, int32_t* mask_array, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + if (mask[idx]) + mask_array[idx] = 1; + else + mask_array[idx] = 0; + } +} + +template +__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum, + const bool* mask, + const T* input, + T* out, + int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + if (mask[idx]) { + int index = mask_prefix_sum[idx]; + out[idx] = input[index]; + } else { + out[idx] = 0; + } + } +} + +template +void MaskedSelectGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* x_grad) { + auto* mask_data = mask.data(); + auto* input_data = out_grad.data(); + auto* out_data = x_grad->mutable_data(dev_ctx.GetPlace()); + + auto input_size = out_grad.numel(); + auto mask_size = mask.numel(); + auto mask_dim = mask.dims(); + + auto out_size = mask_size; + + DenseTensor mask_array; + DenseTensor mask_prefix_sum; + mask_array.Resize(mask_dim); + mask_prefix_sum.Resize(mask_dim); + + int32_t* mask_array_data = + mask_array.mutable_data(dev_ctx.GetPlace()); + int32_t* mask_prefix_sum_data = + mask_prefix_sum.mutable_data(dev_ctx.GetPlace()); + int threads = 512; + int grid = (mask_size + threads - 1) / threads; + auto stream = dev_ctx.stream(); + SetMaskArrayT<<>>( + mask_data, mask_array_data, mask_size); + + thrust::device_ptr mask_array_dev_ptr = + thrust::device_pointer_cast(mask_array_data); + thrust::device_vector mask_array_vec(mask_array_dev_ptr, + mask_array_dev_ptr + mask_size); + thrust::exclusive_scan(thrust::device, + mask_array_vec.begin(), + mask_array_vec.end(), + mask_prefix_sum_data); + + SelectGradWithPrefixMask<<>>( + mask_prefix_sum_data, mask_data, input_data, out_data, mask_size); +} + +} // namespace pten + +PT_REGISTER_KERNEL(masked_select_grad, + GPU, + ALL_LAYOUT, + pten::MaskedSelectGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/pten/kernels/gpu/masked_select_kernel.cu b/paddle/pten/kernels/gpu/masked_select_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..502522e2e8c99840f1fa1ef74f65ca300b50fde4 --- /dev/null +++ b/paddle/pten/kernels/gpu/masked_select_kernel.cu @@ -0,0 +1,120 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/masked_select_kernel.h" + +namespace pten { + +__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + if (mask[idx]) + mask_array[idx] = 1; + else + mask_array[idx] = 0; + } +} + +template +__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum, + const bool* mask, + const T* input, + T* out, + int size) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + for (; idx < size; idx += blockDim.x * gridDim.x) { + if (mask[idx]) { + int index = mask_prefix_sum[idx]; + out[index] = input[idx]; + } + } +} + +template +void MaskedSelectKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* out) { + auto* mask_data = mask.data(); + auto input_data = x.data(); + + auto mask_size = mask.numel(); + auto input_dim = x.dims(); + auto mask_dim = mask.dims(); + PADDLE_ENFORCE_EQ(input_dim, + mask_dim, + pten::errors::InvalidArgument( + "The dim size of input and mask in OP(masked_selected) " + "must be equal, but got input dim:(%ld), mask dim: " + "(%ld). Please check input " + "value.", + input_dim, + mask_dim)); + + thrust::device_ptr mask_dev_ptr = + thrust::device_pointer_cast(mask_data); + thrust::device_vector mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size); + auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true); + + framework::DDim out_dim{out_size}; + out->Resize(out_dim); + auto out_data = out->mutable_data(dev_ctx.GetPlace()); + + DenseTensor mask_array; + DenseTensor mask_prefix_sum; + mask_array.Resize(mask_dim); + mask_prefix_sum.Resize(mask_dim); + + int32_t* mask_array_data = + mask_array.mutable_data(dev_ctx.GetPlace()); + int32_t* mask_prefix_sum_data = + mask_prefix_sum.mutable_data(dev_ctx.GetPlace()); + int threads = 512; + int grid = (mask_size + threads - 1) / threads; + auto stream = dev_ctx.stream(); + SetMaskArray<<>>( + mask_data, mask_array_data, mask_size); + + thrust::device_ptr mask_array_dev_ptr = + thrust::device_pointer_cast(mask_array_data); + thrust::device_vector mask_array_vec(mask_array_dev_ptr, + mask_array_dev_ptr + mask_size); + thrust::exclusive_scan(thrust::device, + mask_array_vec.begin(), + mask_array_vec.end(), + mask_prefix_sum_data); + + SelectWithPrefixMask<<>>( + mask_prefix_sum_data, mask_data, input_data, out_data, mask_size); +} + +} // namespace pten + +PT_REGISTER_KERNEL(masked_select, + GPU, + ALL_LAYOUT, + pten::MaskedSelectKernel, + float, + double, + int, + int64_t) { + kernel->InputAt(1).SetDataType(pten::DataType::BOOL); +} diff --git a/paddle/pten/kernels/masked_select_grad_kernel.h b/paddle/pten/kernels/masked_select_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..05175b58ef4552e5e0850cdb575045b1fa737d96 --- /dev/null +++ b/paddle/pten/kernels/masked_select_grad_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +namespace pten { + +template +void MaskedSelectGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* x_grad); + +} // namspace pten diff --git a/paddle/pten/kernels/masked_select_kernel.h b/paddle/pten/kernels/masked_select_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ecf4e557e6d1afdd44378b50096b6cb9320c4ad0 --- /dev/null +++ b/paddle/pten/kernels/masked_select_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +namespace pten { + +template +void MaskedSelectKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + DenseTensor* out); + +} // namspace pten diff --git a/paddle/pten/ops/compat/masked_select_sig.cc b/paddle/pten/ops/compat/masked_select_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4b7260a1aa29da77dbd0c001b0e6cfb7e1e93b6 --- /dev/null +++ b/paddle/pten/ops/compat/masked_select_sig.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature MaskedSelectOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("masked_select", {"X", "Mask"}, {}, {"Y"}); +} + +KernelSignature MaskedSelectGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("masked_select_grad", + {GradVarName("Y"), "X", "Mask"}, + {}, + {GradVarName("X")}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(masked_select, pten::MaskedSelectOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(masked_select_grad, + pten::MaskedSelectGradOpArgumentMapping);