From ed0e95a804115f81452cfa5d1e281b92ecf29ca8 Mon Sep 17 00:00:00 2001
From: helen88
Date: Wed, 3 Aug 2022 16:45:07 +0800
Subject: [PATCH] add sequence_unpad for xpu (#44808)

* add sequence_unpad for xpu, *test=kunlun

* add sequence_unpad, *test=kunlun

* fix bug in testcase, should not be sequence_pad, *test=kunlun
---
 .../fluid/operators/math/sequence_padding.cc  |  49 ++++++
 .../sequence_ops/sequence_unpad_op.h          |   3 +-
 .../sequence_ops/sequence_unpad_op_xpu.cc     |  24 +++
 .../fluid/platform/device/xpu/xpu2_op_list.h  |   3 +-
 .../xpu/test_sequence_unpad_op_xpu.py         | 161 ++++++++++++++++++
 5 files changed, 238 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/sequence_ops/sequence_unpad_op_xpu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_sequence_unpad_op_xpu.py

diff --git a/paddle/fluid/operators/math/sequence_padding.cc b/paddle/fluid/operators/math/sequence_padding.cc
index 826eda5559a..273f99a5f96 100644
--- a/paddle/fluid/operators/math/sequence_padding.cc
+++ b/paddle/fluid/operators/math/sequence_padding.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/sequence_padding.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 
@@ -190,6 +191,50 @@ class UnpaddingLoDTensorFunctor<phi::CPUContext, T> {
   }
 };
 
+#ifdef PADDLE_WITH_XPU
+template <typename T>
+class UnpaddingLoDTensorFunctor<platform::XPUDeviceContext, T> {
+ public:
+  void operator()(const platform::XPUDeviceContext& context,
+                  const framework::LoDTensor& pad_tensor,
+                  framework::LoDTensor* seq_tensor,
+                  int pad_seq_len = -1,
+                  int lod_level = 0,
+                  bool norm_by_times = false,
+                  const PadLayout layout = kBatchLengthWidth) {
+    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
+    const auto& seq_tensor_dims = seq_tensor->dims();
+    const auto& pad_tensor_dims = pad_tensor.dims();
+    if (pad_seq_len == -1) {
+      pad_seq_len = MaximumSequenceLength(seq_offsets);
+    }
+    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
+
+    CheckDims(seq_tensor_dims,
+              pad_tensor_dims,
+              seq_offsets,
+              pad_seq_len,
+              step_width,
+              layout);
+
+    const T* pad_data = pad_tensor.data<T>();  // padded input x
+    T* seq_data = seq_tensor->data<T>();       // unpadded output y
+
+    xpu::VectorParam<int64_t> seq_offsets_param{
+        reinterpret_cast<int64_t*>(seq_offsets.data()),
+        static_cast<int>(seq_offsets.size()),
+        nullptr};
+    int r = xpu::sequence_unpad<T, int64_t>(context.x_context(),
+                                            pad_data,
+                                            seq_data,
+                                            seq_offsets_param,
+                                            pad_seq_len /*max_seqlen*/,
+                                            step_width /*dim*/);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "sequence_unpad");
+  }
+};
+#endif
+
 template class PaddingLoDTensorFunctor<phi::CPUContext, int>;
 template class PaddingLoDTensorFunctor<phi::CPUContext, int64_t>;
 template class PaddingLoDTensorFunctor<phi::CPUContext, float>;
@@ -200,6 +245,10 @@ template class UnpaddingLoDTensorFunctor<phi::CPUContext, int64_t>;
 template class UnpaddingLoDTensorFunctor<phi::CPUContext, float>;
 template class UnpaddingLoDTensorFunctor<phi::CPUContext, double>;
 
+#ifdef PADDLE_WITH_XPU
+template class UnpaddingLoDTensorFunctor<platform::XPUDeviceContext, float>;
+#endif
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
index f6a8eebf7ce..747549eed51 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.h
@@ -38,7 +38,8 @@ class SequenceUnpadOpKernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     framework::Tensor seq_len_cpu =
         ctx.AllocateTmpTensor<int64_t, DeviceContext>(len_t->dims(), dev_ctx);
-    if (platform::is_gpu_place(ctx.GetPlace())) {
+    if (platform::is_gpu_place(ctx.GetPlace()) ||
+        platform::is_xpu_place(ctx.GetPlace())) {
       seq_len_cpu.mutable_data<int64_t>(platform::CPUPlace());
       framework::TensorCopySync(*len_t, platform::CPUPlace(), &seq_len_cpu);
     } else {
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op_xpu.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op_xpu.cc
new file mode 100644
index 00000000000..caa2ec24570
--- /dev/null
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op_xpu.cc
@@ -0,0 +1,24 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_XPU_KERNEL(
+    sequence_unpad,
+    ops::SequenceUnpadOpKernel<paddle::platform::XPUDeviceContext, float>);
+
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 22f2d9a60a5..c36dd6425c8 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -545,7 +545,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"sequence_conv_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-
+    {"sequence_unpad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     // Fused op
     {"resnet_basic_block_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_sequence_unpad_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_sequence_unpad_op_xpu.py
new file mode 100644
index 00000000000..50da2be25bd
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_sequence_unpad_op_xpu.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import unittest
+import six
+import numpy as np
+import sys
+
+sys.path.append("..")
+
+import paddle
+import paddle.fluid as fluid
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
+
+paddle.enable_static()
+
+
+class XPUTestSequenceUnpadOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'sequence_unpad'
+        self.use_dynamic_create_class = False
+
+    class TestSequenceUnpadOp(XPUOpTest):
+
+        def setUp(self):
+            self.init_dtype()
+            self.initTestCase()
+            self.set_xpu()
+            self.op_type = 'sequence_unpad'
+            self.place = paddle.XPUPlace(0)
+            self.compute()
+
+        def init_dtype(self):
+            self.dtype = self.in_type
+
+        def set_xpu(self):
+            self.__class__.use_xpu = True
+            self.__class__.no_need_check_grad = True
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def initTestCase(self):
+            self.length = [2, 3, 4]
+            self.x_shape = (3, 40)
+
+        def compute(self):
+            assert len(self.length) == self.x_shape[0]
+            x = np.random.random(self.x_shape).astype(self.dtype)
+            out_lod = [self.length]
+
+            # Reference result: concatenate the valid prefix of each row.
+            out = x[0, 0:self.length[0]]
+            for i in six.moves.xrange(1, x.shape[0]):
+                out = np.append(out, x[i, 0:self.length[i]], axis=0)
+
+            out_shape = (sum(self.length), )
+            if len(self.x_shape) == 2:
+                out_shape = out_shape + (1, )
+            else:
+                out_shape = out_shape + self.x_shape[2:]
+
+            self.inputs = {
+                'X': x,
+                'Length': np.array(self.length).astype('int64')
+            }
+            self.outputs = {'Out': (out.reshape(out_shape), out_lod)}
+
+    class TestSequenceUnpadOp2(TestSequenceUnpadOp):
+
+        def initTestCase(self):
+            self.length = [2, 3, 4]
+            self.x_shape = (3, 5, 4, 3)
+
+    class TestSequenceUnpadOp3(TestSequenceUnpadOp):
+
+        def initTestCase(self):
+            self.length = [5, 2, 3, 4]
+            self.x_shape = (4, 5, 3, 3, 6)
+
+    class TestSequenceUnpadOp4(TestSequenceUnpadOp):
+
+        def initTestCase(self):
+            self.length = [5, 5, 5, 5]
+            self.x_shape = (4, 5, 3, 3, 6)
+
+    class TestSequenceUnpadOp5(TestSequenceUnpadOp):
+
+        def initTestCase(self):
+            self.length = [1, 4, 3, 1]
+            self.x_shape = (4, 5, 3, 3, 6)
+
+
+class TestSequenceUnpadOpError(unittest.TestCase):
+
+    def test_error(self):
+        """
+        The type of 'x' in fluid.layers.sequence_unpad must be Variable,
+        but received numpy.ndarray.
+        """
+
+        def test_x_variable():
+            x = np.random.random((10, 5)).astype("float64")
+            length = fluid.data(name='length2', shape=[10], dtype='int64')
+            fluid.layers.sequence_unpad(x=x, length=length)
+
+        self.assertRaises(TypeError, test_x_variable)
+        """
+        The type of 'length' in fluid.layers.sequence_unpad must be Variable,
+        but received numpy.ndarray.
+ """ + + def test_length_variable(): + x1 = fluid.data(name='x1', shape=[10, 5], dtype='float32') + len1 = np.random.random((10)).astype("int64") + fluid.layers.sequence_unpad(x=x1, length=len1) + + self.assertRaises(TypeError, test_length_variable) + """ + The data type of 'x' in fluid.layers.sequence_unpad must be ['float32', 'float64', 'int32', 'int64'], but received float16 + """ + + def test_x_dtype(): + x2 = fluid.data(name='x2', shape=[10, 5], dtype='float16') + len2 = fluid.data(name='length2', shape=[10], dtype='int64') + fluid.layers.sequence_unpad(x=x2, length=len2) + + self.assertRaises(TypeError, test_x_dtype) + """ + The data type of 'length' in fluid.layers.sequence_unpad must be ['int64'], but received int32 + """ + + def test_length_dtype(): + x3 = fluid.data(name='x3', shape=[10, 5], dtype='float64') + len3 = fluid.data(name='length3', shape=[10], dtype='int32') + fluid.layers.sequence_unpad(x=x3, length=len3) + + self.assertRaises(TypeError, test_length_dtype) + + +support_types = get_xpu_op_support_types('sequence_unpad') +for stype in support_types: + create_test_class(globals(), XPUTestSequenceUnpadOp, stype) + +if __name__ == '__main__': + unittest.main() -- GitLab