From d5323dab41fb03fec26f2c5d9326bbd188707ba8 Mon Sep 17 00:00:00 2001 From: TTerror Date: Sat, 20 Feb 2021 17:05:02 +0800 Subject: [PATCH] add squeeze_op/unsqueeze_op on kunlun;fix conv op and parallel executor;optimize lookup_table op (#31056) * add squeeze_op/unsqueeze_op on kunlun; fix conv op and parallel executor on kunlun; optimize lookup_table op on kunlun * update squeeze/unsqueeze op --- cmake/external/xpu.cmake | 2 +- paddle/fluid/framework/parallel_executor.cc | 7 +- paddle/fluid/operators/conv_op_xpu.cc | 48 +++- .../fluid/operators/lookup_table_v2_op_xpu.cc | 33 +-- paddle/fluid/operators/squeeze_op_xpu.cc | 60 +++++ paddle/fluid/operators/unsqueeze_op_xpu.cc | 60 +++++ .../unittests/xpu/test_squeeze2_op_xpu.py | 87 +++++++ .../unittests/xpu/test_squeeze_op_xpu.py | 110 +++++++++ .../unittests/xpu/test_unsqueeze2_op_xpu.py | 231 ++++++++++++++++++ .../unittests/xpu/test_unsqueeze_op_xpu.py | 93 +++++++ 10 files changed, 693 insertions(+), 38 deletions(-) create mode 100644 paddle/fluid/operators/squeeze_op_xpu.cc create mode 100644 paddle/fluid/operators/unsqueeze_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_squeeze2_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_squeeze_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_unsqueeze_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 3015265d48d..41b2907bbae 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -11,7 +11,7 @@ if(NOT XPU_SDK_ROOT) elseif(WITH_SUNWAY) SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE) else() - SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_03.tar.gz" CACHE STRING "" FORCE) + SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_19.tar.gz" CACHE STRING "" FORCE) endif() SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3ddd7cc9182..6e1e7146ba6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -1072,12 +1072,13 @@ void ParallelExecutor::BCastParamsToDevices( platform::errors::Unavailable("bkcl_group_start failed")); for (size_t i = 0; i < member_->places_.size(); ++i) { auto &bkcl_ctx = bkcl_ctxs->at(member_->places_[i]); + auto broadcast_numel = numel; if (main_tensor.type() == framework::proto::VarType::INT64) { - numel *= 2; + broadcast_numel *= 2; } PADDLE_ENFORCE_EQ( - bkcl_broadcast(bkcl_ctx.comm(), buffers[i], buffers[i], numel, - data_type, 0, NULL), + bkcl_broadcast(bkcl_ctx.comm(), buffers[i], buffers[i], + broadcast_numel, data_type, 0, NULL), BKCL_SUCCESS, platform::errors::Unavailable("bkcl_broadcast failed")); } diff --git a/paddle/fluid/operators/conv_op_xpu.cc b/paddle/fluid/operators/conv_op_xpu.cc index 46af4d30500..da8cb68ce54 100644 --- a/paddle/fluid/operators/conv_op_xpu.cc +++ b/paddle/fluid/operators/conv_op_xpu.cc @@ -32,20 +32,31 @@ class GemmConvXPUKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + const std::string data_format = context.Attr("data_format"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + + PADDLE_ENFORCE_EQ(data_format == "NHWC" || data_format == "NDHWC", false, + platform::errors::InvalidArgument( + ("XPU do support data_format is NCHW in conv op."))); + + framework::DDim in_data_dims = + framework::slice_ddim(input->dims(), 2, input->dims().size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter.dims(), 2, filter.dims().size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + const int batch_size = static_cast(input->dims()[0]); const int img_c = static_cast(input->dims()[1]); const int img_h = static_cast(input->dims()[2]); const int img_w = static_cast(input->dims()[3]); const int f = static_cast(filter.dims()[0]); - const int win_h = static_cast(filter.dims()[2]); - const int win_w = static_cast(filter.dims()[3]); auto& dev_ctx = context.template device_context(); - std::vector k_size; - k_size.push_back(win_h); - k_size.push_back(win_w); int r = xpu::conv2d( dev_ctx.x_context(), input->data(), filter.data(), - output->data(), batch_size, img_c, img_h, img_w, f, k_size, + output->data(), batch_size, img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, @@ -53,6 +64,7 @@ class GemmConvXPUKernel : public framework::OpKernel { r, XPUAPIErrorMsg[r])); } }; + template class GemmConvGradXPUKernel : public framework::OpKernel { public: @@ -73,13 +85,28 @@ class GemmConvGradXPUKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); std::vector dilations = context.Attr>("dilations"); + const std::string data_format = context.Attr("data_format"); + const std::string padding_algorithm = + context.Attr("padding_algorithm"); + + PADDLE_ENFORCE_EQ( + data_format == "NHWC" || data_format == "NDHWC", false, + platform::errors::InvalidArgument( + ("XPU do support data_format is NCHW in conv grad op."))); + + framework::DDim in_data_dims = + framework::slice_ddim(input->dims(), 2, input->dims().size()); + framework::DDim filter_data_dims = + framework::slice_ddim(filter.dims(), 2, filter.dims().size()); + std::vector ksize = framework::vectorize(filter_data_dims); + UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, + in_data_dims, strides, ksize); + const int batch_size = static_cast(input->dims()[0]); const int img_c = static_cast(input->dims()[1]); const int img_h = static_cast(input->dims()[2]); const int img_w = static_cast(input->dims()[3]); const int f = static_cast(filter.dims()[0]); - const int win_h = static_cast(filter.dims()[2]); - const int win_w = static_cast(filter.dims()[3]); if (input_grad) { input_grad->mutable_data(context.GetPlace()); } @@ -87,14 +114,11 @@ class GemmConvGradXPUKernel : public framework::OpKernel { filter_grad->mutable_data(context.GetPlace()); } auto& dev_ctx = context.template device_context(); - std::vector k_size; - k_size.push_back(win_h); - k_size.push_back(win_w); int r = xpu::conv2d_grad( dev_ctx.x_context(), input->data(), filter.data(), output_grad->data(), input_grad ? input_grad->data() : nullptr, filter_grad ? filter_grad->data() : nullptr, batch_size, img_c, - img_h, img_w, f, k_size, strides, paddings, dilations, groups, nullptr, + img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr, nullptr, nullptr, true); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, diff --git a/paddle/fluid/operators/lookup_table_v2_op_xpu.cc b/paddle/fluid/operators/lookup_table_v2_op_xpu.cc index 2284401ba1b..eec957fb8e5 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_xpu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_xpu.cc @@ -17,11 +17,10 @@ limitations under the License. */ #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/var_type_inference.h" - +#ifdef PADDLE_WITH_XPU namespace paddle { namespace operators { -#ifdef PADDLE_WITH_XPU template class LookupTableV2XPUKernel : public framework::OpKernel { public: @@ -96,26 +95,19 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel { platform::errors::OutOfRange( "Number of ids greater than int32_t::max , please check " "number of ids in LookupTableV2GradXPUKernel.")); - int ids_numel_int32 = static_cast(ids_numel); - const int64_t *ids_data = ids_t->data(); - int D = d_table_t->dims()[1]; + auto &dev_ctx = context.template device_context(); + const int64_t *ids_data = ids_t->data(); const T *d_output_data = d_output_t->data(); T *d_table_data = d_table_t->mutable_data(context.GetPlace()); - auto &dev_ctx = context.template device_context(); - // set zeros for d_table_data - const int zero = 0; - int r = xpu::memset(dev_ctx.x_context(), d_table_data, zero, - d_table_t->numel() * sizeof(T)); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External( - "XPU API return wrong value[%d], please check where " - "Baidu Kunlun Card is properly installed.", - r)); - - r = xpu::embedding_backward(dev_ctx.x_context(), - ids_numel_int32, ids_data, D, - d_output_data, d_table_data); + int xm = d_table_t->dims()[0]; + int ym = static_cast(ids_numel); + int n = d_table_t->dims()[1]; + int padding_idx = context.Attr("padding_idx"); + + int r = xpu::embedding_grad(dev_ctx.x_context(), d_output_data, + ids_data, d_table_data, xm, n, ym, + padding_idx); PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, platform::errors::External( "XPU API return wrong value[%d] , please check where " @@ -123,13 +115,10 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel { r)); } }; -#endif - } // namespace operators } // namespace paddle namespace ops = paddle::operators; -#ifdef PADDLE_WITH_XPU REGISTER_OP_XPU_KERNEL( lookup_table_v2, ops::LookupTableV2XPUKernel); diff --git a/paddle/fluid/operators/squeeze_op_xpu.cc b/paddle/fluid/operators/squeeze_op_xpu.cc new file mode 100644 index 00000000000..72be8bdfb43 --- /dev/null +++ b/paddle/fluid/operators/squeeze_op_xpu.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/squeeze_op.h" +#ifdef PADDLE_WITH_XPU + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL( + squeeze, ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel); +REGISTER_OP_XPU_KERNEL( + squeeze_grad, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel); +REGISTER_OP_XPU_KERNEL( + squeeze2, ops::Squeeze2Kernel, + ops::Squeeze2Kernel, + ops::Squeeze2Kernel, + ops::Squeeze2Kernel, + ops::Squeeze2Kernel, + ops::Squeeze2Kernel, + ops::Squeeze2Kernel, + ops::Squeeze2Kernel); +REGISTER_OP_XPU_KERNEL( + squeeze2_grad, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel, + ops::Squeeze2GradKernel); + +#endif diff --git a/paddle/fluid/operators/unsqueeze_op_xpu.cc b/paddle/fluid/operators/unsqueeze_op_xpu.cc new file mode 100644 index 00000000000..9ce3190f45f --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op_xpu.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unsqueeze_op.h" +#ifdef PADDLE_WITH_XPU +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL( + unsqueeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_XPU_KERNEL( + unsqueeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); +REGISTER_OP_XPU_KERNEL( + unsqueeze2, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_XPU_KERNEL( + unsqueeze2_grad, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel, + ops::Unsqueeze2GradKernel); + +#endif diff --git a/python/paddle/fluid/tests/unittests/xpu/test_squeeze2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_squeeze2_op_xpu.py new file mode 100644 index 00000000000..a6269f43daa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_squeeze2_op_xpu.py @@ -0,0 +1,87 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import sys +sys.path.append("..") + +import numpy as np + +from op_test import OpTest +from op_test_xpu import XPUOpTest +import paddle + +paddle.enable_static() + + +# Correct: General. +class TestSqueezeOp(XPUOpTest): + def setUp(self): + self.op_type = "squeeze2" + self.use_xpu = True + self.use_mkldnn = False + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, no_check_set=['XShape']) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.ori_shape = (1, 3, 1, 40) + self.axes = (0, 2) + self.new_shape = (3, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + +# Correct: There is mins axis. +class TestSqueezeOp1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 20, 1, 5) + self.axes = (0, -2) + self.new_shape = (20, 5) + + +# Correct: No axes input. +class TestSqueezeOp2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 20, 1, 5) + self.axes = () + self.new_shape = (20, 5) + + +# Correct: Just part of axes be squeezed. +class TestSqueezeOp3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (6, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (6, 5, 1, 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_squeeze_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_squeeze_op_xpu.py new file mode 100644 index 00000000000..de701bfc513 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_squeeze_op_xpu.py @@ -0,0 +1,110 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import sys +sys.path.append("..") + +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from op_test import OpTest +from op_test_xpu import XPUOpTest + +paddle.enable_static() + + +# Correct: General. +class TestSqueezeOp(XPUOpTest): + def setUp(self): + self.op_type = "squeeze" + self.use_xpu = True + self.use_mkldnn = False + self.init_test_case() + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape), } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.ori_shape = (1, 3, 1, 40) + self.axes = (0, 2) + self.new_shape = (3, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + +# Correct: There is mins axis. +class TestSqueezeOp1(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 3, 1, 40) + self.axes = (0, -2) + self.new_shape = (3, 40) + + +# Correct: No axes input. +class TestSqueezeOp2(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (1, 20, 1, 5) + self.axes = () + self.new_shape = (20, 5) + + +# Correct: Just part of axes be squeezed. +class TestSqueezeOp3(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (6, 1, 5, 1, 4, 1) + self.axes = (1, -1) + self.new_shape = (6, 5, 1, 4) + + +# Correct: The demension of axis is not of size 1 remains unchanged. +class TestSqueezeOp4(TestSqueezeOp): + def init_test_case(self): + self.ori_shape = (6, 1, 5, 1, 4, 1) + self.axes = (1, 2) + self.new_shape = (6, 5, 1, 4, 1) + + +class TestSqueezeOpError(unittest.TestCase): + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + # The input type of softmax_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], paddle.XPUPlace(0)) + self.assertRaises(TypeError, paddle.squeeze, x1) + # The input axes of squeeze must be list. + x2 = paddle.static.data(name='x2', shape=[4], dtype="int32") + self.assertRaises(TypeError, paddle.squeeze, x2, axes=0) + # The input dtype of squeeze not support float16. + x3 = paddle.static.data(name='x3', shape=[4], dtype="float16") + self.assertRaises(TypeError, paddle.squeeze, x3, axes=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py new file mode 100644 index 00000000000..606053832ea --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py @@ -0,0 +1,231 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import sys +sys.path.append("..") + +import numpy as np + +import paddle +import paddle.fluid as fluid +from op_test import OpTest +from op_test_xpu import XPUOpTest + +paddle.enable_static() + + +# Correct: General. +class TestUnsqueezeOp(XPUOpTest): + def setUp(self): + self.init_test_case() + self.use_xpu = True + self.use_mkldnn = False + self.op_type = "unsqueeze2" + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, no_check_set=['XShape']) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.ori_shape = (3, 40) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (-1, ) + self.new_shape = (20, 5, 1) + + +# Correct: Mixed input axis. +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (0, -1) + self.new_shape = (1, 20, 5, 1) + + +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 10, 2, 1, 1, 5) + + +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (10, 1, 1, 2, 5, 1) + + +# axes is a list(with tensor) +class TestUnsqueezeOp_AxesTensorList(XPUOpTest): + def setUp(self): + self.init_test_case() + self.use_xpu = True + self.use_mkldnn = False + self.op_type = "unsqueeze2" + + axes_tensor_list = [] + for index, ele in enumerate(self.axes): + axes_tensor_list.append(("axes" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "AxesTensorList": axes_tensor_list + } + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, no_check_set=['XShape']) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (1, 2) + self.new_shape = (20, 1, 1, 5) + + def init_attrs(self): + self.attrs = {} + + +class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (-1, ) + self.new_shape = (20, 5, 1) + + +class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (0, -1) + self.new_shape = (1, 20, 5, 1) + + +class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 10, 2, 1, 1, 5) + + +class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (10, 1, 1, 2, 5, 1) + + +# axes is a Tensor +class TestUnsqueezeOp_AxesTensor(XPUOpTest): + def setUp(self): + self.init_test_case() + self.use_xpu = True + self.use_mkldnn = False + self.op_type = "unsqueeze2" + + self.inputs = { + "X": np.random.random(self.ori_shape).astype("float32"), + "AxesTensor": np.array(self.axes).astype("int32") + } + self.init_attrs() + self.outputs = { + "Out": self.inputs["X"].reshape(self.new_shape), + "XShape": np.random.random(self.ori_shape).astype("float32") + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, no_check_set=['XShape']) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (1, 2) + self.new_shape = (20, 1, 1, 5) + + def init_attrs(self): + self.attrs = {} + + +class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (-1, ) + self.new_shape = (20, 5, 1) + + +class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (0, -1) + self.new_shape = (1, 20, 5, 1) + + +class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 10, 2, 1, 1, 5) + + +class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (10, 1, 1, 2, 5, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze_op_xpu.py new file mode 100644 index 00000000000..5e40073e731 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze_op_xpu.py @@ -0,0 +1,93 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import sys +sys.path.append("..") + +import numpy as np + +import paddle +import paddle.fluid as fluid +from op_test import OpTest +from op_test_xpu import XPUOpTest + +paddle.enable_static() + + +# Correct: General. +class TestUnsqueezeOp(XPUOpTest): + def setUp(self): + self.init_test_case() + self.op_type = "unsqueeze" + self.use_xpu = True + self.use_mkldnn = False + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.init_attrs() + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_test_case(self): + self.ori_shape = (3, 40) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 40) + + def init_attrs(self): + self.attrs = {"axes": self.axes} + + +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (-1, ) + self.new_shape = (20, 5, 1) + + +# Correct: Mixed input axis. +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (20, 5) + self.axes = (0, -1) + self.new_shape = (1, 20, 5, 1) + + +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 10, 2, 1, 1, 5) + + +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (10, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (10, 1, 1, 2, 5, 1) + + +if __name__ == "__main__": + unittest.main() -- GitLab