未验证 提交 d29a1214 编写于 作者: 光明和真理's avatar 光明和真理 提交者: GitHub

[MLU] add mlu kernel for masked_select (#43816)

上级 59d50468
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
template <typename T>
class MaskedSelectedMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input = ctx.Input<framework::Tensor>("X");
auto mask = ctx.Input<framework::Tensor>("Mask");
auto out = ctx.Output<framework::Tensor>("Y");
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
input_dim,
mask_dim,
platform::errors::InvalidArgument(
"The dim size of input and mask in OP(masked_selected) "
"must be equal, but got input dim:(%ld), mask dim: "
"(%ld). Please check input "
"value.",
input_dim,
mask_dim));
Tensor number(framework::TransToPhiDataType(VT::INT32));
void* number_ptr = number.mutable_data<int32_t>({1}, ctx.GetPlace());
out->Resize(mask->dims());
out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc input_desc(*input);
MLUCnnlTensorDesc mask_desc(*mask);
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::Mask(ctx,
CNNL_MASKED_SELECT,
input_desc.get(),
GetBasePtr(input),
mask_desc.get(),
GetBasePtr(mask),
nullptr,
nullptr,
out_desc.get(),
GetBasePtr(out),
static_cast<uint32_t*>(number_ptr));
}
};
template <typename T>
class MaskedSelectedGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto mask = ctx.Input<framework::Tensor>("Mask");
auto y_grad = ctx.Input<framework::Tensor>(framework::GradVarName("Y"));
auto x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto& dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
Tensor mask_int32, out_size;
std::vector<int32_t> out_size_vec;
mask_int32.mutable_data<int32_t>(mask->dims(), ctx.GetPlace());
out_size.mutable_data<int32_t>({1}, ctx.GetPlace());
MLUCnnlTensorDesc mask_desc(*mask);
MLUCnnlTensorDesc mask_int32_desc(mask_int32);
MLUCnnlTensorDesc out_size_desc(out_size);
auto cast_type = GetCastDataType(mask->dtype(), DataType::INT32);
MLUCnnl::Cast(ctx,
cast_type,
mask_desc.get(),
GetBasePtr(mask),
mask_int32_desc.get(),
GetBasePtr(&mask_int32));
auto mask_int32_dim = phi::vectorize(mask_int32.dims());
std::vector<int32_t> reduce_dims;
for (size_t i = 0; i < mask_int32_dim.size(); i++) {
reduce_dims.push_back(static_cast<int>(i));
}
std::string reduce_name = "reduce_sum";
cnnlReduceOp_t reduce_op = GetMLUCnnlReduceOp(reduce_name);
MLUCnnlReduceDesc reduce_desc(reduce_dims,
reduce_op,
ToCnnlDataType<int32_t>(),
CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES,
CNNL_32BIT_INDICES);
MLUCnnl::Reduce(ctx,
true,
reduce_desc.get(),
nullptr,
mask_int32_desc.get(),
GetBasePtr(&mask_int32),
0,
nullptr,
nullptr,
out_size_desc.get(),
GetBasePtr(&out_size));
paddle::framework::TensorToVector(out_size, dev_ctx, &out_size_vec);
dev_ctx.Wait();
Tensor mask_int32_tmp;
mask_int32_tmp.ShareDataWith(mask_int32);
mask_int32_tmp.Resize({mask_int32.numel()});
Tensor topk_v2_out(framework::TransToPhiDataType(VT::INT32)),
indices_int32(framework::TransToPhiDataType(VT::INT32));
topk_v2_out.mutable_data<int32_t>({mask_int32.numel()}, ctx.GetPlace());
indices_int32.mutable_data<int32_t>({mask_int32.numel()}, ctx.GetPlace());
MLUCnnlTensorDesc topk_v2_out_desc(topk_v2_out);
MLUCnnlTensorDesc indices_int32_desc(indices_int32);
MLUCnnlTensorDesc mask_int32_tmp_desc(mask_int32_tmp);
const int dim = 0;
MLUCnnl::TopK(ctx,
mask_int32.numel(),
dim,
true,
false,
mask_int32_tmp_desc.get(),
GetBasePtr(&mask_int32_tmp),
topk_v2_out_desc.get(),
GetBasePtr(&topk_v2_out),
indices_int32_desc.get(),
GetBasePtr(&indices_int32));
auto stream = ctx.template device_context<MLUDeviceContext>().stream();
Tensor indices_int32_out;
indices_int32_out.mutable_data<int32_t>({out_size_vec[0]}, ctx.GetPlace());
memory::Copy(ctx.GetPlace(),
GetBasePtr(&indices_int32_out),
ctx.GetPlace(),
GetBasePtr(&indices_int32),
out_size_vec[0] * sizeof(int32_t),
stream);
Tensor y_grad_tmp_out;
y_grad_tmp_out.mutable_data<T>({out_size_vec[0]}, ctx.GetPlace());
MLUCnnlTensorDesc y_grad_tmp_out_desc(y_grad_tmp_out);
memory::Copy(ctx.GetPlace(),
GetBasePtr(&y_grad_tmp_out),
ctx.GetPlace(),
GetBasePtr(y_grad),
out_size_vec[0] * sizeof(T),
stream);
Tensor indices_int32_tmp;
indices_int32_tmp.ShareDataWith(indices_int32_out);
indices_int32_tmp.Resize({out_size_vec[0], 1});
MLUCnnlTensorDesc indices_int32_tmp_desc(indices_int32_tmp);
const cnnlScatterNdMode_t mode = CNNL_SCATTERND_UPDATE;
x_grad->Resize({x_grad->numel()});
x_grad->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_grad_desc(*x_grad);
MLUCnnl::ScatterNd(ctx,
mode,
indices_int32_tmp_desc.get(),
GetBasePtr(&indices_int32_tmp),
y_grad_tmp_out_desc.get(),
GetBasePtr(&y_grad_tmp_out),
nullptr,
nullptr,
x_grad_desc.get(),
GetBasePtr(x_grad));
x_grad->Resize(mask->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(masked_select,
ops::MaskedSelectedMLUKernel<float>,
ops::MaskedSelectedMLUKernel<int>,
ops::MaskedSelectedMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(masked_select_grad,
ops::MaskedSelectedGradMLUKernel<float>,
ops::MaskedSelectedGradMLUKernel<int>,
ops::MaskedSelectedGradMLUKernel<plat::float16>);
......@@ -2597,6 +2597,19 @@ MLURNNDesc::~MLURNNDesc() {
cnnlSign(handle, input_desc, input, output_desc, output));
}
/* static */ void MLUCnnl::IndexSelect(const ExecutionContext& ctx,
const int dim,
cnnlTensorDescriptor_t input_desc,
const void* input,
const cnnlTensorDescriptor_t index_desc,
const void* index,
const cnnlTensorDescriptor_t output_desc,
void* output) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlIndexSelect(
handle, dim, input_desc, input, index_desc, index, output_desc, output));
}
/* static */ void MLUCnnl::IsFinite(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t input_desc,
const void* input,
......
......@@ -1391,6 +1391,15 @@ class MLUCnnl {
const cnnlTensorDescriptor_t output_desc,
void* output);
static void IndexSelect(const ExecutionContext& ctx,
const int dim,
cnnlTensorDescriptor_t input_desc,
const void* input,
const cnnlTensorDescriptor_t index_desc,
const void* index,
const cnnlTensorDescriptor_t output_desc,
void* output);
static void IsFinite(const ExecutionContext& ctx,
const cnnlTensorDescriptor_t input_desc,
const void* input,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
import paddle
paddle.enable_static()
def np_masked_select(shape, x, mask):
result = np.empty(shape=(0), dtype=x.dtype)
sum = 0
for index, (ele, ma) in enumerate(zip(np.nditer(x), np.nditer(mask))):
if ma:
sum = sum + 1
result = np.append(result, ele)
for index, (ele, ma) in enumerate(zip(np.nditer(x), np.nditer(mask))):
if index >= sum:
result = np.append(result, 0)
result = np.reshape(result, shape)
return result
class TestMaskedSelectOp(OpTest):
def setUp(self):
self.init()
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
self.op_type = "masked_select"
self.python_api = paddle.masked_select
x = np.random.random(self.shape).astype('float32')
mask = np.array(np.random.randint(2, size=self.shape, dtype=bool))
out = np_masked_select(self.shape, x, mask)
self.inputs = {'X': x, 'Mask': mask}
self.outputs = {'Y': out}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Y')
def init(self):
self.shape = (50, 3)
class TestMaskedSelectOp1(TestMaskedSelectOp):
def init(self):
self.shape = (6, 8, 9, 18)
class TestMaskedSelectOp2(TestMaskedSelectOp):
def init(self):
self.shape = (168, )
@skip_check_grad_ci(reason="get_numeric_gradient not support int32")
class TestMaskedSelectOpInt32(TestMaskedSelectOp):
def init_dtype(self):
self.dtype = np.int32
def test_check_grad(self):
pass
class TestMaskedSelectOpFp16(TestMaskedSelectOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad(self):
x_grad = self.inputs['Mask'].astype(self.dtype)
x_grad = x_grad * (1 / x_grad.size)
self.check_grad_with_place(self.place, ['X'],
'Y',
user_defined_grads=[x_grad])
class TestMaskedSelectAPI(unittest.TestCase):
def test_imperative_mode(self):
paddle.disable_static()
shape = (88, 6, 8)
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
x = paddle.to_tensor(np_x)
mask = paddle.to_tensor(np_mask)
out = paddle.masked_select(x, mask)
np_out = np_masked_select(shape, np_x, np_mask)
self.assertEqual(np.allclose(out.numpy(), np_out), True)
paddle.enable_static()
def test_static_mode(self):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
out = paddle.masked_select(x, mask)
np_out = np_masked_select(shape, np_x, np_mask)
exe = paddle.static.Executor(place=paddle.device.MLUPlace(0))
res = exe.run(paddle.static.default_main_program(),
feed={
"x": np_x,
"mask": np_mask
},
fetch_list=[out])
self.assertEqual(np.allclose(res, np_out), True)
class TestMaskedSelectError(unittest.TestCase):
def test_error(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
shape = [8, 9, 6]
x = paddle.fluid.data(shape=shape, dtype='float32', name='x')
mask = paddle.fluid.data(shape=shape, dtype='bool', name='mask')
mask_float = paddle.fluid.data(shape=shape,
dtype='float32',
name='mask_float')
np_x = np.random.random(shape).astype('float32')
np_mask = np.array(np.random.randint(2, size=shape, dtype=bool))
def test_x_type():
paddle.masked_select(np_x, mask)
self.assertRaises(TypeError, test_x_type)
def test_mask_type():
paddle.masked_select(x, np_mask)
self.assertRaises(TypeError, test_mask_type)
def test_mask_dtype():
paddle.masked_select(x, mask_float)
self.assertRaises(TypeError, test_mask_dtype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册