未验证 提交 e167e879 编写于 作者: W wangchaochaohu 提交者: GitHub

【API2.0】add masked_select Op for API2.0 (#26374)

上级 c09de13e
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/masked_select_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class MaskedSelectOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Input", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Out", "MaskedSelect");
framework::DDim output_dims(ctx->GetInputDim("X"));
ctx->SetOutputDim("Y", output_dims);
ctx->ShareLoD("X", /*->*/ "Y");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class MaskedSelectOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("Mask",
"The mask of Input Tensor to be selected which is a bool Tensor.");
AddOutput(
"Y",
"The returned tensor, the data type "
"is same as input, will be on the same device with the input Tensor.");
AddComment(R"DOC(
Size Operator.
Return a new 0-D tensor which indexes the indexed tensor according
the mask which is a tensor withe data type bool.
)DOC");
}
};
class MaskedSelectOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Input",
"Input", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Y")),
ctx.device_context());
}
};
template <typename T>
class MaskedSelectGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("masked_select_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Mask", this->Input("Mask"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(MaskedSelectedGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(masked_select, ops::MaskedSelectOp, ops::MaskedSelectOpMaker,
ops::MaskedSelectGradOpMaker<paddle::framework::OpDesc>,
ops::MaskedSelectGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(masked_select_grad, ops::MaskedSelectOpGrad,
ops::MaskedSelectedGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
masked_select,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, float>,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, double>,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, int>,
ops::MaskedSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
masked_select_grad,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::MaskedSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>
#include "paddle/fluid/operators/masked_select_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx])
mask_array[idx] = 1;
else
mask_array[idx] = 0;
}
}
template <typename T>
__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask, const T* input, T* out,
int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[index] = input[idx];
}
}
}
template <typename T>
__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum,
const bool* mask, const T* input,
T* out, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
if (mask[idx]) {
int index = mask_prefix_sum[idx];
out[idx] = input[index];
} else {
out[idx] = 0;
}
}
}
template <typename DeviceContext, typename T>
class MaskedSelectCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto input = ctx.Input<framework::Tensor>("X");
auto mask = ctx.Input<framework::Tensor>("Mask");
auto out = ctx.Output<framework::Tensor>("Y");
auto* mask_data = mask->data<bool>();
auto input_data = input->data<T>();
auto mask_size = mask->numel();
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
input_dim, mask_dim,
platform::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim, mask_dim));
thrust::device_ptr<const bool> mask_dev_ptr =
thrust::device_pointer_cast(mask_data);
thrust::device_vector<T> mask_vec(mask_dev_ptr, mask_dev_ptr + mask_size);
auto out_size = thrust::count(mask_vec.begin(), mask_vec.end(), true);
framework::DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(ctx.GetPlace());
Tensor mask_array;
Tensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data = mask_array.mutable_data<int32_t>(ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
SetMaskArray<<<grid, threads, 0, stream>>>(mask_data, mask_array_data,
mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(
mask_array_dev_ptr, mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device, mask_array_vec.begin(),
mask_array_vec.end(), mask_prefix_sum_data);
SelectWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}
};
template <typename DeviceContext, typename T>
class MaskedSelectGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto input = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
auto mask = ctx.Input<framework::Tensor>("Mask");
auto out = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* mask_data = mask->data<bool>();
auto* input_data = input->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto input_size = input->numel();
auto mask_size = mask->numel();
auto mask_dim = mask->dims();
auto out_size = mask_size;
Tensor mask_array;
Tensor mask_prefix_sum;
mask_array.Resize(mask_dim);
mask_prefix_sum.Resize(mask_dim);
int32_t* mask_array_data = mask_array.mutable_data<int32_t>(ctx.GetPlace());
int32_t* mask_prefix_sum_data =
mask_prefix_sum.mutable_data<int32_t>(ctx.GetPlace());
int threads = 512;
int grid = (mask_size + threads - 1) / threads;
auto stream = ctx.cuda_device_context().stream();
SetMaskArray<<<grid, threads, 0, stream>>>(mask_data, mask_array_data,
mask_size);
thrust::device_ptr<int32_t> mask_array_dev_ptr =
thrust::device_pointer_cast(mask_array_data);
thrust::device_vector<int32_t> mask_array_vec(
mask_array_dev_ptr, mask_array_dev_ptr + mask_size);
thrust::exclusive_scan(thrust::device, mask_array_vec.begin(),
mask_array_vec.end(), mask_prefix_sum_data);
SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>(
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
masked_select,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::MaskedSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
masked_select_grad,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::MaskedSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class MaskedSelectKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<framework::Tensor>("X");
auto mask = context.Input<framework::Tensor>("Mask");
auto out = context.Output<framework::Tensor>("Y");
auto* mask_data = mask->data<bool>();
auto input_data = input->data<T>();
auto mask_size = mask->numel();
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
input_dim, mask_dim,
platform::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim, mask_dim));
int out_size = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) out_size++;
}
framework::DDim out_dim{out_size};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(context.GetPlace());
int index = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) {
out_data[index] = input_data[i];
index++;
}
}
}
};
template <typename DeviceContext, typename T>
class MaskedSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto mask = context.Input<framework::Tensor>("Mask");
auto input = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* mask_data = mask->data<bool>();
auto* input_data = input->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int mask_size = mask->numel();
int index = 0;
for (int i = 0; i < mask_size; i++) {
if (mask_data[i]) {
out_data[i] = input_data[index];
index++;
} else {
out_data[i] = 0;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -201,7 +201,7 @@ from .tensor.search import argmin #DEFINE_ALIAS
from .tensor.search import argsort #DEFINE_ALIAS
from .tensor.search import has_inf #DEFINE_ALIAS
from .tensor.search import has_nan #DEFINE_ALIAS
# from .tensor.search import masked_select #DEFINE_ALIAS
from .tensor.search import masked_select #DEFINE_ALIAS
from .tensor.search import topk #DEFINE_ALIAS
from .tensor.search import where #DEFINE_ALIAS
from .tensor.search import index_select #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle
def np_masked_select(x, mask):
result = np.empty(shape=(0), dtype=x.dtype)
for ele, ma in zip(np.nditer(x), np.nditer(mask)):
if ma:
result = np.append(result, ele)
return result.flatten()
class TestMaskedSelectOp(OpTest):
def setUp(self):
self.init()
self.op_type = "masked_select"
x = np.random.random(self.shape).astype("float64")
mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
out = np_masked_select(x, mask)
self.inputs = {'X': x, 'Mask': mask}
self.outputs = {'Y': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
def init(self):
self.shape = (50, 3)
class TestMaskedSelectOp1(TestMaskedSelectOp):
def init(self):
self.shape = (6, 8, 9, 18)
class TestMaskedSelectOp2(TestMaskedSelectOp):
def init(self):
self.shape = (168, )
class TestMaskedSelectAPI(unittest.TestCase):
def test_imperative_mode(self):
paddle.disable_static()
shape = (88, 6, 8)
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
x = paddle.to_tensor(np_x)
mask = paddle.to_tensor(np_mask)
out = paddle.masked_select(x, mask)
np_out = np_masked_select(np_x, np_mask)
self.assertEqual(np.allclose(out.numpy(), np_out), True)
paddle.enable_static()
def test_static_mode(self):
shape = [8, 9, 6]
x = paddle.data(shape=shape, dtype='float32', name='x')
mask = paddle.data(shape=shape, dtype='bool', name='mask')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
out = paddle.masked_select(x, mask)
np_out = np_masked_select(np_x, np_mask)
exe = paddle.static.Executor(place=paddle.CPUPlace())
res = exe.run(paddle.static.default_main_program(),
feed={"x": np_x,
"mask": np_mask},
fetch_list=[out])
self.assertEqual(np.allclose(res, np_out), True)
class TestMaskedSelectError(unittest.TestCase):
def test_error(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
shape = [8, 9, 6]
x = paddle.data(shape=shape, dtype='float32', name='x')
mask = paddle.data(shape=shape, dtype='bool', name='mask')
mask_float = paddle.data(
shape=shape, dtype='float32', name='mask_float')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
def test_x_type():
paddle.masked_select(np_x, mask)
self.assertRaises(TypeError, test_x_type)
def test_mask_type():
paddle.masked_select(x, np_mask)
self.assertRaises(TypeError, test_mask_type)
def test_mask_dtype():
paddle.masked_select(x, mask_float)
self.assertRaises(TypeError, test_mask_dtype)
if __name__ == '__main__':
unittest.main()
......@@ -177,6 +177,7 @@ from .search import index_select #DEFINE_ALIAS
from .search import nonzero #DEFINE_ALIAS
from .search import sort #DEFINE_ALIAS
from .search import index_sample #DEFINE_ALIAS
from .search import masked_select #DEFINE_ALIAS
from .stat import mean #DEFINE_ALIAS
from .stat import reduce_mean #DEFINE_ALIAS
from .stat import std #DEFINE_ALIAS
......
......@@ -29,13 +29,13 @@ __all__ = [
'argsort',
'has_inf',
'has_nan',
# 'masked_select',
'masked_select',
'topk',
'where',
'index_select',
'nonzero',
'sort',
'index_sample'
'index_sample',
]
from paddle.common_ops_import import *
......@@ -629,3 +629,57 @@ def index_sample(x, index):
'Index': index},
outputs={'Out': out})
return out
def masked_select(x, mask, name=None):
"""
This OP Returns a new 1-D tensor which indexes the input tensor according to the ``mask``
which is a tensor with data type of bool.
Args:
x (Tensor): The input Tensor, the data type can be int32, int64, float32, float64.
mask (Tensor): The Tensor containing the binary mask to index with, it's data type is bool.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns: A 1-D Tensor which is the same data type as ``x``.
Raises:
TypeError: ``x`` must be a Tensor and the data type of ``x`` must be one of float32, float64, int32 and int64.
TypeError: ``mask`` must be a Tensor and the data type of ``mask`` must be bool.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
data = np.array([[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]]).astype('float32')
mask_data = np.array([[True, False, False, False],
[True, True, False, False],
[True, False, False, False]]).astype('bool')
x = paddle.to_tensor(data)
mask = paddle.to_tensor(mask_data)
out = paddle.masked_select(x, mask)
#[1.0 5.0 6.0 9.0]
"""
if in_dygraph_mode():
return core.ops.masked_select(x, mask)
helper = LayerHelper("masked_select", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'paddle.tensor.search.mask_select')
check_variable_and_dtype(mask, 'mask', ['bool'],
'paddle.tensor.search.masked_select')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='masked_select', inputs={'X': x,
'Mask': mask}, outputs={'Y': out})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册