From 338f9e05c9a3524e868bf4cc9e42e79d06bffd56 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Sat, 7 Aug 2021 22:21:12 -0500 Subject: [PATCH] add sequence_mask_op_npu and tests (#34455) --- .../sequence_ops/sequence_mask_op_npu.cc | 138 +++++++++++++ .../unittests/npu/test_sequence_mask_npu.py | 182 ++++++++++++++++++ 2 files changed, 320 insertions(+) create mode 100644 paddle/fluid/operators/sequence_ops/sequence_mask_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_sequence_mask_npu.py diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op_npu.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op_npu.cc new file mode 100644 index 00000000000..aa84da10ad6 --- /dev/null +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op_npu.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sequence_ops/sequence_mask_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class SequenceMaskNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Output("Y"); + int maxlen = ctx.Attr("maxlen"); + + if (ctx.HasInput("MaxLenTensor")) { + auto max_len_tensor = ctx.Input("MaxLenTensor"); + PADDLE_ENFORCE_NOT_NULL(max_len_tensor, + platform::errors::InvalidArgument( + "Input(MaxLenTensor) should not be NULL." + "But received Input(MaxLenTensor) is NULL")); + framework::Tensor temp; + TensorCopySync(*max_len_tensor, platform::CPUPlace(), &temp); + maxlen = *temp.data(); + PADDLE_ENFORCE_GT( + maxlen, 0, + platform::errors::InvalidArgument( + "Input(MaxLenTensor) value should be greater than 0. But " + "received Input(MaxLenTensor) value = %d.", + maxlen)); + } + + if (maxlen < 0) { + auto x_numel = x->numel(); + std::vector x_vec; + framework::TensorToVector(*x, dev_ctx, &x_vec); + auto x_data = x_vec.data(); + maxlen = static_cast(*std::max_element(x_data, x_data + x_numel)); + } + auto y_dim = framework::vectorize(x->dims()); + y_dim.push_back(maxlen); + + Tensor cast_x; + cast_x.mutable_data(x->dims(), ctx.GetPlace()); + const auto& cast1_runner = + NpuOpRunner("Cast", {*x}, {cast_x}, + {{"dst_type", ConvertToNpuDtype(cast_x.type())}}); + cast1_runner.Run(dev_ctx.stream()); + + Tensor tmp; + tmp.mutable_data(framework::make_ddim({maxlen}), ctx.GetPlace()); + NpuOpRunner range_runner; + range_runner.SetType("Range"); + range_runner.AddInput(std::vector({0})); + range_runner.AddInput(std::vector({maxlen})); + range_runner.AddInput(std::vector({1})); + range_runner.AddOutput(tmp); + range_runner.Run(dev_ctx.stream()); + + Tensor expand_tmp; + expand_tmp.mutable_data(framework::make_ddim(y_dim), + ctx.GetPlace()); + const auto& expand_runner = + NpuOpRunner("ExpandD", {tmp}, {expand_tmp}, {{"shape", y_dim}}); + expand_runner.Run(dev_ctx.stream()); + + auto x_dims = framework::vectorize(x->dims()); + x_dims.push_back(1); + cast_x.Resize(framework::make_ddim({x_dims})); + Tensor x_tmp; + x_tmp.mutable_data(framework::make_ddim(y_dim), ctx.GetPlace()); + const auto& tile_runner = + NpuOpRunner("TileWithAxis", {cast_x}, {x_tmp}, + {{"axis", x->dims().size()}, {"tiles", maxlen}}); + tile_runner.Run(dev_ctx.stream()); + + Tensor y_tmp; + y_tmp.mutable_data(framework::make_ddim(y_dim), ctx.GetPlace()); + const auto& less_runner = + NpuOpRunner("Less", {expand_tmp, x_tmp}, {y_tmp}, {}); + less_runner.Run(dev_ctx.stream()); + + y->Resize(framework::make_ddim(y_dim)); + auto out_dtype = static_cast( + ctx.Attr("out_dtype")); + if (out_dtype == framework::proto::VarType::INT32) { + y->mutable_data(ctx.GetPlace()); + } else if (out_dtype == framework::proto::VarType::INT64) { + y->mutable_data(ctx.GetPlace()); + } else if (out_dtype == framework::proto::VarType::FP32) { + y->mutable_data(ctx.GetPlace()); + } else if (out_dtype == framework::proto::VarType::FP64) { + y->mutable_data(ctx.GetPlace()); + } else if (out_dtype == framework::proto::VarType::BOOL) { + y->mutable_data(ctx.GetPlace()); + } else if (out_dtype == framework::proto::VarType::UINT8) { + y->mutable_data(ctx.GetPlace()); + } else { + PADDLE_ENFORCE(false, + platform::errors::InvalidArgument( + "out_dtype only supporing int32, int64, fp32, fp64, " + "bool, uint8, but receive out_dtype is %d", + out_dtype)); + } + + const auto& cast2_runner = NpuOpRunner( + "Cast", {y_tmp}, {*y}, {{"dst_type", ConvertToNpuDtype(out_dtype)}}); + cast2_runner.Run(dev_ctx.stream()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_NPU_KERNEL( + sequence_mask, ops::SequenceMaskNPUKernel, + ops::SequenceMaskNPUKernel, + ops::SequenceMaskNPUKernel, + ops::SequenceMaskNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_sequence_mask_npu.py b/python/paddle/fluid/tests/unittests/npu/test_sequence_mask_npu.py new file mode 100644 index 00000000000..21440de9fdd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_sequence_mask_npu.py @@ -0,0 +1,182 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.framework import convert_np_dtype_to_dtype_, Program, program_guard + +paddle.enable_static() + + +class SequenceMaskTestBase(OpTest): + def set_npu(self): + self.__class__.use_npu = True + + def initDefaultParameters(self): + self.op_type = 'sequence_mask' + self.maxlen = 10 + self.mask_dtype = 'int64' + self.x = [[0, 3, 4], [5, 7, 9]] + + def initParameters(self): + pass + + def setUp(self): + self.set_npu() + self.initDefaultParameters() + self.initParameters() + if not isinstance(self.x, np.ndarray): + self.x = np.array(self.x) + + self.inputs = {'X': self.x} + self.outputs = {'Y': self.calc_ground_truth_mask()} + self.attrs = { + 'maxlen': self.maxlen, + 'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype) + } + + def calc_ground_truth_mask(self): + maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen + shape = self.x.shape + (maxlen, ) + index_broadcast = np.broadcast_to( + np.reshape( + range(maxlen), newshape=[1] * self.x.ndim + [-1]), + shape=shape) + x_broadcast = np.broadcast_to( + np.reshape( + self.x, newshape=self.x.shape + (-1, )), shape=shape) + return (index_broadcast < x_broadcast).astype(self.mask_dtype) + + def test_check_output(self): + self.check_output_with_place(paddle.NPUPlace(0)) + + +class SequenceMaskTest1(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'bool' + + +class SequenceMaskTest2(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'uint8' + + +class SequenceMaskTest3(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'int32' + + +class SequenceMaskTest4(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'float32' + + +class SequenceMaskTest5(SequenceMaskTestBase): + def initParameters(self): + self.mask_dtype = 'float64' + + +class SequenceMaskTest6(SequenceMaskTestBase): + def initParameters(self): + self.maxlen = -1 + + +class SequenceMaskTestBase_tensor_attr(OpTest): + def set_npu(self): + self.__class__.use_npu = True + + def initDefaultParameters(self): + self.op_type = 'sequence_mask' + self.maxlen = 10 + self.maxlen_tensor = np.ones((1), 'int32') * 10 + self.mask_dtype = 'int64' + self.x = [[0, 3, 4], [5, 7, 9]] + + def initParameters(self): + pass + + def setUp(self): + self.set_npu() + self.initDefaultParameters() + self.initParameters() + if not isinstance(self.x, np.ndarray): + self.x = np.array(self.x) + + self.inputs = {'X': self.x, 'MaxLenTensor': self.maxlen_tensor} + self.outputs = {'Y': self.calc_ground_truth_mask()} + self.attrs = {'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)} + + def calc_ground_truth_mask(self): + maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen + shape = self.x.shape + (maxlen, ) + index_broadcast = np.broadcast_to( + np.reshape( + range(maxlen), newshape=[1] * self.x.ndim + [-1]), + shape=shape) + x_broadcast = np.broadcast_to( + np.reshape( + self.x, newshape=self.x.shape + (-1, )), shape=shape) + return (index_broadcast < x_broadcast).astype(self.mask_dtype) + + def test_check_output(self): + self.check_output() + + +class SequenceMaskTest1_tensor_attr(SequenceMaskTestBase_tensor_attr): + def initParameters(self): + self.mask_dtype = 'bool' + + +class SequenceMaskTest2_tensor_attr(SequenceMaskTestBase_tensor_attr): + def initParameters(self): + self.mask_dtype = 'uint8' + + +class SequenceMaskTest3_tensor_attr(SequenceMaskTestBase_tensor_attr): + def initParameters(self): + self.mask_dtype = 'int32' + + +class SequenceMaskTest4_tensor_attr(SequenceMaskTestBase_tensor_attr): + def initParameters(self): + self.mask_dtype = 'float32' + + +class SequenceMaskTest5_tensor_attr(SequenceMaskTestBase_tensor_attr): + def initParameters(self): + self.mask_dtype = 'float64' + + +class TestSequenceMaskOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + input_data = np.random.uniform(1, 5, [4]).astype("float32") + + def test_Variable(): + # the input must be Variable + fluid.layers.sequence_mask(input_data, maxlen=4) + + self.assertRaises(TypeError, test_Variable) + + +if __name__ == '__main__': + unittest.main() -- GitLab