未验证 提交 e2ad433b 编写于 作者: H hong 提交者: GitHub

move Masked select to pten (#39193)

* move masked select cpu kernel

* add masked selected gpu kernel; test=develop

* fix bugs; test=develop

* bug fix; test=develop

* bug fix; test=develop

* add namespace to set mask array; test=develop

* fix bug; test=develop

* fix bugs; test=develop

* fix ddim bug; test=develop

* fix npu op bug; test=develop

* fix xpu dependecy bug; test=develop

* move kernel args to sig.cc; test=develop
上级 8b58862a
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/masked_select_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -106,16 +105,3 @@ REGISTER_OPERATOR(masked_select, ops::MaskedSelectOp, ops::MaskedSelectOpMaker,
ops::MaskedSelectGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(masked_select_grad, ops::MaskedSelectOpGrad,
ops::MaskedSelectedGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
masked_select,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, float>,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, double>,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, int>,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
masked_select_grad,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/fluid/operators/masked_select_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx])
mask_array[idx] = 1;
else
mask_array[idx] = 0;
}
}
template <typename T>
__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask, const T* input, T* out,
int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[index] = input[idx];
}
}
}
template <typename T>
__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask, const T* input,
T* out, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[idx] = input[index];
} else {
out[idx] = 0;
}
}
}
template <typename DeviceContext, typename T>
class MaskedSelectCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto input = ctx.Input<framework::Tensor>("X");
auto mask = ctx.Input<framework::Tensor>("Mask");
auto out = ctx.Output<framework::Tensor>("Y");
auto* mask_data = mask->data<bool>();
auto input_data = input->data<T>();
auto mask_size = mask->numel();
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
input_dim, mask_dim,
platform::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim, mask_dim));
thrust::device_ptr<const bool> mask_dev_ptr =
thrust::device_pointer_cast(mask_data);
thrust::device_vector<T> mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size);
auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true);
framework::DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(ctx.GetPlace());
Tensor mask_array;
Tensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data = mask_array.mutable_data<int32_t>(ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
SetMaskArray<<<grid, threads, 0, stream>>>(mask_data, mask_array_data,
mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(
mask_array_dev_ptr, mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device, mask_array_vec.begin(),
mask_array_vec.end(), mask_prefix_sum_data);
SelectWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}
};
template <typename DeviceContext, typename T>
class MaskedSelectGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto input = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
auto mask = ctx.Input<framework::Tensor>("Mask");
auto out = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* mask_data = mask->data<bool>();
auto* input_data = input->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto input_size = input->numel();
auto mask_size = mask->numel();
auto mask_dim = mask->dims();
auto out_size = mask_size;
Tensor mask_array;
Tensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data = mask_array.mutable_data<int32_t>(ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
SetMaskArray<<<grid, threads, 0, stream>>>(mask_data, mask_array_data,
mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(
mask_array_dev_ptr, mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device, mask_array_vec.begin(),
mask_array_vec.end(), mask_prefix_sum_data);
SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
masked_select,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
masked_select_grad,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class MaskedSelectKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<framework::Tensor>("X");
auto mask = context.Input<framework::Tensor>("Mask");
auto out = context.Output<framework::Tensor>("Y");
auto* mask_data = mask->data<bool>();
auto input_data = input->data<T>();
auto mask_size = mask->numel();
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
input_dim, mask_dim,
platform::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim, mask_dim));
int out_size = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) out_size++;
}
framework::DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(context.GetPlace());
int index = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) {
out_data[index] = input_data[i];
index++;
}
}
}
};
template <typename DeviceContext, typename T>
class MaskedSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto mask = context.Input<framework::Tensor>("Mask");
auto input = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* mask_data = mask->data<bool>();
auto* input_data = input->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int mask_size = mask->numel();
int index = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) {
out_data[i] = input_data[index];
index++;
} else {
out_data[i] = 0;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/masked_select_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -11,7 +11,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/masked_select_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/masked_select_grad_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* x_grad) {
auto* mask_data = mask.data<bool>();
auto* input_data = out_grad.data<T>();
auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
int mask_size = mask.numel();
int index = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) {
out_data[i] = input_data[index];
index++;
} else {
out_data[i] = 0;
}
}
}
} // namespace pten
PT_REGISTER_KERNEL(masked_select_grad,
CPU,
ALL_LAYOUT,
pten::MaskedSelectGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/masked_select_kernel.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
template <typename T, typename Context>
void MaskedSelectKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out) {
auto* mask_data = mask.data<bool>();
auto input_data = x.data<T>();
auto mask_size = mask.numel();
auto input_dim = x.dims();
auto mask_dim = mask.dims();
PADDLE_ENFORCE_EQ(input_dim,
mask_dim,
pten::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim,
mask_dim));
int out_size = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) out_size++;
}
framework::DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(paddle::platform::CPUPlace());
int index = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) {
out_data[index] = input_data[i];
index++;
}
}
}
} // namespace pten
PT_REGISTER_KERNEL(masked_select,
CPU,
ALL_LAYOUT,
pten::MaskedSelectKernel,
float,
double,
int,
int64_t) {
kernel->InputAt(1).SetDataType(pten::DataType::BOOL);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/masked_select_grad_kernel.h"
namespace pten {
__global__ void SetMaskArrayT(const bool* mask, int32_t* mask_array, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx])
mask_array[idx] = 1;
else
mask_array[idx] = 0;
}
}
template <typename T>
__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask,
const T* input,
T* out,
int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[idx] = input[index];
} else {
out[idx] = 0;
}
}
}
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* x_grad) {
auto* mask_data = mask.data<bool>();
auto* input_data = out_grad.data<T>();
auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace());
auto input_size = out_grad.numel();
auto mask_size = mask.numel();
auto mask_dim = mask.dims();
auto out_size = mask_size;
DenseTensor mask_array;
DenseTensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data =
mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = dev_ctx.stream();
SetMaskArrayT<<<grid, threads, 0, stream>>>(
mask_data, mask_array_data, mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(mask_array_dev_ptr,
mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device,
mask_array_vec.begin(),
mask_array_vec.end(),
mask_prefix_sum_data);
SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}
} // namespace pten
PT_REGISTER_KERNEL(masked_select_grad,
GPU,
ALL_LAYOUT,
pten::MaskedSelectGradKernel,
float,
double,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/masked_select_kernel.h"
namespace pten {
__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx])
mask_array[idx] = 1;
else
mask_array[idx] = 0;
}
}
template <typename T>
__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask,
const T* input,
T* out,
int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[index] = input[idx];
}
}
}
template <typename T, typename Context>
void MaskedSelectKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out) {
auto* mask_data = mask.data<bool>();
auto input_data = x.data<T>();
auto mask_size = mask.numel();
auto input_dim = x.dims();
auto mask_dim = mask.dims();
PADDLE_ENFORCE_EQ(input_dim,
mask_dim,
pten::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim,
mask_dim));
thrust::device_ptr<const bool> mask_dev_ptr =
thrust::device_pointer_cast(mask_data);
thrust::device_vector<T> mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size);
auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true);
framework::DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(dev_ctx.GetPlace());
DenseTensor mask_array;
DenseTensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data =
mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = dev_ctx.stream();
SetMaskArray<<<grid, threads, 0, stream>>>(
mask_data, mask_array_data, mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(mask_array_dev_ptr,
mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device,
mask_array_vec.begin(),
mask_array_vec.end(),
mask_prefix_sum_data);
SelectWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}
} // namespace pten
PT_REGISTER_KERNEL(masked_select,
GPU,
ALL_LAYOUT,
pten::MaskedSelectKernel,
float,
double,
int,
int64_t) {
kernel->InputAt(1).SetDataType(pten::DataType::BOOL);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* x_grad);
} // namspace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void MaskedSelectKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out);
} // namspace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature MaskedSelectOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("masked_select", {"X", "Mask"}, {}, {"Y"});
}
KernelSignature MaskedSelectGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("masked_select_grad",
{GradVarName("Y"), "X", "Mask"},
{},
{GradVarName("X")});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(masked_select, pten::MaskedSelectOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(masked_select_grad,
pten::MaskedSelectGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册