Unverified · Commit ec422ea5 authored by R ronnywang, committed by GitHub

[NPU] add masked_select_op_npu (#35649)

Parent 5fa9cf7c
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/masked_select_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
template <typename T>
class MaskedSelectedNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input = ctx.Input<framework::Tensor>("X");
auto mask = ctx.Input<framework::Tensor>("Mask");
auto out = ctx.Output<framework::Tensor>("Y");
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
input_dim, mask_dim,
platform::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim, mask_dim));
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto stream = dev_ctx.stream();
Tensor mask_int32, out_size;
std::vector<int32_t> out_size_vec;
mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
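// Count the selected elements: cast the bool mask to int32, reduce-sum it
// into out_size, then copy the count back to the host.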
{
const auto& cast_runner =
NpuOpRunner("Cast", {*mask}, {mask_int32},
{{"dst_type", static_cast<int32_t>(ConvertToNpuDtype(
framework::proto::VarType::INT32))}});
cast_runner.Run(stream);
mask_int32.Resize({mask_int32.numel()});
NpuOpRunner sum_runner;
sum_runner.SetType("ReduceSum");
sum_runner.AddInput(mask_int32);
sum_runner.AddInput(std::vector<int32_t>({0}));
sum_runner.AddOutput(out_size);
sum_runner.AddAttr("keep_dims", false);
sum_runner.Run(stream);
TensorToVector(out_size, dev_ctx, &out_size_vec);
}
out->Resize({out_size_vec[0]});
out->mutable_data<T>(ctx.GetPlace());
Tensor topkv2_out, indices;
topkv2_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
indices.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
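// The k largest entries of the flattened int32 mask are exactly the ones,
// so TopKV2 returns the positions of the selected elements.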
{
NpuOpRunner topkv2_runner;
topkv2_runner.SetType("TopKV2")
.AddInput(mask_int32)
.AddInput(out_size)
.AddOutput(topkv2_out)
.AddOutput(indices)
.AddAttr("sorted", false)
.AddAttr("dim", 0)
.AddAttr("largest", true)
.Run(stream);
// TopKV2 with sorted=false may return the indices in arbitrary order;
// a second TopKV2 with largest=false sorts them ascending so the gathered
// elements keep the input order.
NpuOpRunner topkv2_runner2;
topkv2_runner2.SetType("TopKV2")
.AddInput(indices)
.AddInput(out_size)
.AddOutput(topkv2_out)
.AddOutput(indices)
.AddAttr("sorted", true)
.AddAttr("dim", 0)
.AddAttr("largest", false)
.Run(stream);
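// Flatten the input and gather the selected elements at the sorted indices.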
Tensor input_tmp;
input_tmp.ShareDataWith(*input);
input_tmp.Resize({input->numel()});
const auto& gather_runner = NpuOpRunner(
"GatherV2D", {input_tmp, topkv2_out}, {*out}, {{"axis", 0}});
gather_runner.Run(stream);
}
}
};
template <typename T>
class MaskedSelectedGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto mask = ctx.Input<framework::Tensor>("Mask");
auto y_grad = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
auto x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto stream = dev_ctx.stream();
Tensor mask_int32, out_size;
std::vector<int32_t> out_size_vec;
mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
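// Recompute the number of selected elements from the mask, as in the
// forward kernel.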
{
const auto& cast_runner =
NpuOpRunner("Cast", {*mask}, {mask_int32},
{{"dst_type", static_cast<int32_t>(ConvertToNpuDtype(
framework::proto::VarType::INT32))}});
cast_runner.Run(stream);
mask_int32.Resize({mask_int32.numel()});
NpuOpRunner sum_runner;
sum_runner.SetType("ReduceSum");
sum_runner.AddInput(mask_int32);
sum_runner.AddInput(std::vector<int32_t>({0}));
sum_runner.AddOutput(out_size);
sum_runner.AddAttr("keep_dims", false);
sum_runner.Run(stream);
TensorToVector(out_size, dev_ctx, &out_size_vec);
}
Tensor topkv2_out, indices;
topkv2_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
indices.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
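// Recover the ascending indices of the selected elements, as in the
// forward kernel.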
{
NpuOpRunner topkv2_runner;
topkv2_runner.SetType("TopKV2")
.AddInput(mask_int32)
.AddInput(out_size)
.AddOutput(topkv2_out)
.AddOutput(indices)
.AddAttr("sorted", false)
.AddAttr("dim", 0)
.AddAttr("largest", true)
.Run(stream);
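// Sort the indices ascending (TopKV2 with sorted=false is order-unstable).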
NpuOpRunner topkv2_runner2;
topkv2_runner2.SetType("TopKV2")
.AddInput(indices)
.AddInput(out_size)
.AddOutput(topkv2_out)
.AddOutput(indices)
.AddAttr("sorted", true)
.AddAttr("dim", 0)
.AddAttr("largest", false)
.Run(stream);
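// ScatterNd writes y_grad into the selected positions of the flattened
// x_grad and zeros everything else; topkv2_out is reshaped to (k, 1) to
// serve as the scatter indices, and x_grad is restored to the mask's shape
// afterwards.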
topkv2_out.Resize({out_size_vec[0], 1});
x_grad->Resize({x_grad->numel()});
NpuOpRunner scatter_runner;
scatter_runner.SetType("ScatterNd");
scatter_runner.AddInput(topkv2_out);
scatter_runner.AddInput(*y_grad);
scatter_runner.AddInput(
std::vector<int32_t>({static_cast<int32_t>(x_grad->numel())}));
scatter_runner.AddOutput(*x_grad);
scatter_runner.Run(stream);
x_grad->Resize(mask->dims());
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(masked_select,
ops::MaskedSelectedNPUKernel<plat::float16>,
ops::MaskedSelectedNPUKernel<float>,
ops::MaskedSelectedNPUKernel<int>,
ops::MaskedSelectedNPUKernel<int64_t>);
REGISTER_OP_NPU_KERNEL(masked_select_grad,
ops::MaskedSelectedGradNPUKernel<plat::float16>,
ops::MaskedSelectedGradNPUKernel<float>,
ops::MaskedSelectedGradNPUKernel<int>,
ops::MaskedSelectedGradNPUKernel<int64_t>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
paddle.enable_static()
def np_masked_select(x, mask):
result = np.empty(shape=(0), dtype=x.dtype)
for ele, ma in zip(np.nditer(x), np.nditer(mask)):
if ma:
result = np.append(result, ele)
return result.flatten()
class TestMaskedSelectOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True
def setUp(self):
self.set_npu()
self.init()
self.init_dtype()
self.place = paddle.NPUPlace(0)
self.op_type = "masked_select"
x = np.random.random(self.shape).astype(self.dtype)
mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
out = np_masked_select(x, mask)
self.inputs = {'X': x, 'Mask': mask}
self.outputs = {'Y': out}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Y')
def init(self):
self.shape = (50, 3)
def init_dtype(self):
self.dtype = np.float32
class TestMaskedSelectOp1(TestMaskedSelectOp):
def init(self):
self.shape = (6, 8, 9, 18)
class TestMaskedSelectOp2(TestMaskedSelectOp):
def init(self):
self.shape = (168, )
class TestMaskedSelectOpFp16(TestMaskedSelectOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad(self):
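# With check_grad's default mean-style output gradient (dY = 1/numel(Y)),
# the expected dX is the mask scaled by 1/k, where k = mask.sum() is the
# number of selected elements.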
x_grad = self.inputs['Mask'].astype(self.dtype)
x_grad = x_grad * (1 / x_grad.sum())
self.check_grad_with_place(
self.place, ['X'], 'Y', user_defined_grads=[x_grad])
@skip_check_grad_ci(reason="get_numeric_gradient not support int32")
class TestMaskedSelectOpInt32(TestMaskedSelectOp):
def init_dtype(self):
self.dtype = np.int32
def test_check_grad(self):
pass
@skip_check_grad_ci(reason="get_numeric_gradient not support int64")
class TestMaskedSelectOpInt64(TestMaskedSelectOp):
def init_dtype(self):
self.dtype = np.int64
def test_check_grad(self):
pass
class TestMaskedSelectAPI(unittest.TestCase):
def test_imperative_mode(self):
paddle.disable_static(paddle.NPUPlace(0))
shape = (88, 6, 8)
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
x = paddle.to_tensor(np_x)
mask = paddle.to_tensor(np_mask)
out = paddle.masked_select(x, mask)
np_out = np_masked_select(np_x, np_mask)
self.assertEqual(np.allclose(out.numpy(), np_out), True)
paddle.enable_static()
def test_static_mode(self):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
out = paddle.masked_select(x, mask)
np_out = np_masked_select(np_x, np_mask)
exe = paddle.static.Executor(place=paddle.NPUPlace(0))
res = exe.run(paddle.static.default_main_program(),
feed={"x": np_x,
"mask": np_mask},
fetch_list=[out])
self.assertEqual(np.allclose(res, np_out), True)
class TestMaskedSelectError(unittest.TestCase):
def test_error(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
mask_float = paddle.fluid.data(
shape=shape, dtype='float32', name='mask_float')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
def test_x_type():
paddle.masked_select(np_x, mask)
self.assertRaises(TypeError, test_x_type)
def test_mask_type():
paddle.masked_select(x, np_mask)
self.assertRaises(TypeError, test_mask_type)
def test_mask_dtype():
paddle.masked_select(x, mask_float)
self.assertRaises(TypeError, test_mask_dtype)
if __name__ == '__main__':
unittest.main()