diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index ff8a3b9838a46cdb4ca634b3bd180c96d5b93d48..a709616314b174bf204d17becdb60b99d2eab895 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -4,7 +4,7 @@ endif() INCLUDE(ExternalProject) SET(XPU_PROJECT "extern_xpu") -SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_11_30.tar.gz" CACHE STRING "" FORCE) +SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_04.tar.gz" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc index 4f254a530746b7f91e159fa16b8e59882bb0898d..da676a7244fb3851a6b2b31a684ae07329e7f6f1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op_xpu.cc @@ -28,9 +28,10 @@ class ElementwiseDivXPUKernel : public framework::OpKernel { }; template -class ElementwiseDivGradXPUKernel : public framework::OpKernel { +class ElementwiseDivGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::div_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc index 411ddb266032a5f7d51fc8b9e125bd6515c234d2..c87db69c57d78cdf6f5ecd289909267991098b70 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op_xpu.cc @@ -29,9 +29,10 @@ class ElementwiseMaxXPUKernel : public framework::OpKernel { }; template -class ElementwiseMaxGradXPUKernel : public framework::OpKernel { +class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::max_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc index 0b1e13122644e90f155af523315bfff4c2dfc49e..f1401369ec69aff873f1f822bb25b1d212596eec 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op_xpu.cc @@ -29,9 +29,10 @@ class ElementwiseMinXPUKernel : public framework::OpKernel { }; template -class ElementwiseMinGradXPUKernel : public framework::OpKernel { +class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::min_grad, true); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc index 02c6900c7c19b8b3968e7891783a47e07ec6e8e0..23bb04f60a842bd8ad4005770f42d357c4f48930 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op_xpu.cc @@ -27,9 +27,10 @@ class ElementwiseMulXPUKernel : public framework::OpKernel { }; // DEFINE_XPU_GRAD_KERNEL(Mul, mul, true); template -class ElementwiseMulGradXPUKernel : public framework::OpKernel { +class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); XPUElementwiseGrad(ctx, xpu::mul_grad, true); } }; diff --git a/paddle/fluid/operators/interpolate_op_xpu.cc b/paddle/fluid/operators/interpolate_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..6dc42525469e1cf44f7eb267891d8fd60425f1d9 --- /dev/null +++ b/paddle/fluid/operators/interpolate_op_xpu.cc @@ -0,0 +1,258 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/interpolate_op.h" + +#ifdef PADDLE_WITH_XPU + +namespace paddle { +namespace operators { + +using framework::Tensor; +using DataLayout = framework::DataLayout; + +inline std::vector get_new_shape_xpu( + const std::vector& list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + PADDLE_ENFORCE_EQ( + tensor->dims(), framework::make_ddim({1}), + platform::errors::InvalidArgument("shape of dim tensor should be [1]")); + if (platform::is_xpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_new_shape.push_back(static_cast(*temp.data())); + } else { + vec_new_shape.push_back(static_cast(*tensor->data())); + } + } + + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor_xpu( + const Tensor* new_data_tensor) { + std::vector vec_new_data; + auto* new_data = new_data_tensor->data(); + framework::Tensor cpu_starts_tensor; + if (platform::is_xpu_place(new_data_tensor->place())) { + TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor); + new_data = cpu_starts_tensor.data(); + } + vec_new_data = std::vector(new_data, new_data + new_data_tensor->numel()); + return vec_new_data; +} + +template +class InterpolateXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + auto input_dims = input->dims(); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::External("XPU Interpolate kernel only support 2d")); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape_xpu(list_new_size_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } else { + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor_xpu(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor_xpu(out_size); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + } + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of " + "Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of " + "Op(interpolate) " + "should be greater than 0.")); + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {n, c, out_h, out_w}; + } else { + dim_out = {n, out_h, out_w, c}; + } + output->mutable_data(dim_out, ctx.GetPlace()); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + bool nearest = "nearest" == interp_method; + int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); + auto& dev_ctx = ctx.template device_context(); + if (nearest) { + PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true, + platform::errors::InvalidArgument( + "XPU nearest is only support NCHW")); + } + int r = xpu::interpolate2d(dev_ctx.x_context(), input->data(), + output->data(), n, c, in_h, in_w, + out_h, out_w, nearest, trans_mode, + (data_layout == DataLayout::kNCHW)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU interpolate2d kernel " + "return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +template +class InterpolateGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + + auto output_grad_dims = output_grad->dims(); + + PADDLE_ENFORCE_EQ(output_grad_dims.size(), 4, + platform::errors::External( + "XPU Interpolategrad kernel only support 2d")); + + auto* input = ctx.Input("X"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor_xpu(scale_tensor); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + if (scale > 0) { + out_h = static_cast(in_h * scale); + out_w = static_cast(in_w * scale); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor_xpu(out_size); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape_xpu(list_new_size_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } + + framework::DDim dim_grad; + if (data_layout == DataLayout::kNCHW) { + dim_grad = {n, c, in_h, in_w}; + } else { + dim_grad = {n, in_h, in_w, c}; + } + input_grad->mutable_data(dim_grad, ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int r = XPU_SUCCESS; + r = xpu::constant(dev_ctx.x_context(), input_grad->data(), + input_grad->numel(), static_cast(0.0)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant in interpolate2d_grad kernel return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + bool nearest = "nearest" == interp_method; + int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); + + if (nearest) { + PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true, + platform::errors::InvalidArgument( + "XPU nearest is only support NCHW")); + } + + r = xpu::interpolate2d_grad(dev_ctx.x_context(), output_grad->data(), + input_grad->data(), n, c, in_h, in_w, + out_h, out_w, nearest, trans_mode, + (data_layout == DataLayout::kNCHW)); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU interpolate2d_grad kernel return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(bilinear_interp, ops::InterpolateXPUKernel); + +REGISTER_OP_XPU_KERNEL(bilinear_interp_grad, + ops::InterpolateGradXPUKernel); +#endif diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index 368a12057c8993ae5140216c437023e5ff4dd34a..346ed965d06f285f3630f1c49254da9688432333 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -70,7 +70,8 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { r)); } else { Tensor labels_int32; - labels_int32.mutable_data(context.GetPlace(), labels->numel()); + labels_int32.mutable_data(context.GetPlace(), + labels->numel() * sizeof(int32_t)); r = xpu::cast_v2( dev_ctx.x_context(), labels->data(), labels_int32.data(), labels->numel()); diff --git a/python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py new file mode 100755 index 0000000000000000000000000000000000000000..f8ae945b6ebe5d0394fa57bb8739901ef23a049b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_bilinear_interp_op_xpu.py @@ -0,0 +1,519 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append("..") +from op_test_xpu import XPUOpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +import time + +paddle.enable_static() + + +def bilinear_interp_np(input, + out_h, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + align_mode=0, + data_layout='NCHW'): + """bilinear interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + batch_size, channel, in_h, in_w = input.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if out_w > 1: + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((batch_size, channel, out_h, out_w)) + + for i in range(out_h): + if (align_mode == 0 and not align_corners): + h = int(ratio_h * (i + 0.5) - 0.5) + else: + h = int(ratio_h * i) + + h = max(0, h) + hid = 1 if h < in_h - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_h = max(ratio_h * (i + 0.5) - 0.5, 0) + h1lambda = idx_src_h - h + else: + h1lambda = ratio_h * i - h + h2lambda = 1.0 - h1lambda + for j in range(out_w): + if (align_mode == 0 and not align_corners): + w = int(ratio_w * (j + 0.5) - 0.5) + else: + w = int(ratio_w * j) + w = max(0, w) + wid = 1 if w < in_w - 1 else 0 + if (align_mode == 0 and not align_corners): + idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0) + w1lambda = idx_src_w - w + else: + w1lambda = ratio_w * j - w + w2lambda = 1.0 - w1lambda + + out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] + + w1lambda*input[:, :, h, w+wid]) + \ + h1lambda*(w2lambda*input[:, :, h+hid, w] + + w1lambda*input[:, :, h+hid, w+wid]) + + if data_layout == "NHWC": + out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC + + return out.astype(input.dtype) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOp(XPUOpTest): + def setUp(self): + self.use_xpu = True + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "bilinear_interp" + input_np = np.random.random(self.input_shape).astype("float32") + + if self.data_layout == "NCHW": + in_h = self.input_shape[2] + in_w = self.input_shape[3] + else: + in_h = self.input_shape[1] + in_w = self.input_shape[2] + + if self.scale > 0: + out_h = int(in_h * self.scale) + out_w = int(in_w * self.scale) + else: + out_h = self.out_h + out_w = self.out_w + + output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size, + self.actual_shape, self.align_corners, + self.align_mode, self.data_layout) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'align_mode': self.align_mode, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase1(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase2(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase3(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase4(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase5(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpCase6(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([65, 33]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpSame(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 32, 64] + self.out_h = 32 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpActualShape(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpDataLayout(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 5, 5, 3] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + self.align_mode = 1 + self.data_layout = "NHWC" + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOtherMethod1(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = False + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpWithMethod2(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = False + self.align_mode = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpWithMethod3(TestBilinearInterpOp): + def set_align_mode(self): + self.align_corners = True + self.align_mode = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpScale1(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 2. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpScale2(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 1. + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpScale3(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 1.5 + self.align_corners = True + self.align_mode = 1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpZero(TestBilinearInterpOp): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 7] + self.out_h = 60 + self.out_w = 25 + self.scale = 0.2 + self.align_corners = False + self.align_mode = 0 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOp_attr_tensor(XPUOpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "bilinear_interp" + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + self.attrs = { + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + } + + input_np = np.random.random(self.input_shape).astype("float32") + self.inputs = {'X': input_np} + + if self.scale_by_1Dtensor: + self.inputs['Scale'] = np.array([self.scale]).astype("float32") + elif self.scale > 0: + out_h = int(self.input_shape[2] * self.scale) + out_w = int(self.input_shape[3] * self.scale) + self.attrs['scale'] = self.scale + else: + out_h = self.out_h + out_w = self.out_w + + if self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.out_size is not None: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs['out_h'] = self.out_h + self.attrs['out_w'] = self.out_w + output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size, + self.actual_shape, self.align_corners) + self.outputs = {'Out': output_np} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [2, 3, 5, 5] + self.out_h = 3 + self.out_w = 3 + self.scale = 0. + self.out_size = [3, 3] + self.align_corners = True + + +# out_size is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterp_attr_tensor_Case1(TestBilinearInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = [8, 12] + self.align_corners = True + + +# scale is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterp_attr_tensor_Case2(TestBilinearInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + self.shape_by_1Dtensor = True + + +# scale is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterp_attr_tensor_Case3(TestBilinearInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'bilinear' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 2.0 + self.out_size = None + self.align_corners = True + self.scale_by_1Dtensor = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestBilinearInterpOpAPI(unittest.TestCase): + def test_case(self): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + + dim = fluid.data(name="dim", shape=[1], dtype="int32") + shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32") + actual_size = fluid.data(name="actual_size", shape=[2], dtype="int32") + scale_tensor = fluid.data( + name="scale_tensor", shape=[1], dtype="float32") + + out1 = fluid.layers.resize_bilinear(x, out_shape=[12, 12]) + out2 = fluid.layers.resize_bilinear(x, out_shape=[12, dim]) + out3 = fluid.layers.resize_bilinear(x, out_shape=shape_tensor) + out4 = fluid.layers.resize_bilinear( + x, out_shape=[4, 4], actual_shape=actual_size) + out5 = fluid.layers.resize_bilinear(x, scale=scale_tensor) + + x_data = np.random.random((2, 3, 6, 6)).astype("float32") + dim_data = np.array([12]).astype("int32") + shape_data = np.array([12, 12]).astype("int32") + actual_size_data = np.array([12, 12]).astype("int32") + scale_data = np.array([2.0]).astype("float32") + + place = core.XPUPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "dim": dim_data, + "shape_tensor": shape_data, + "actual_size": actual_size_data, + "scale_tensor": scale_data + }, + fetch_list=[out1, out2, out3, out4, out5], + return_numpy=True) + + expect_res = bilinear_interp_np( + x_data, out_h=12, out_w=12, align_corners=True) + for res in results: + self.assertTrue(np.allclose(res, expect_res)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py index f734d3c25a06993d61e3d706c674c3425718395c..454ef0db0526266461369359b157c47b1b21b675 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py @@ -17,7 +17,7 @@ import sys sys.path.append("..") from test_softmax_op import stable_softmax -from op_test import OpTest +from op_test_xpu import XPUOpTest import paddle.fluid.core as core import paddle @@ -45,7 +45,7 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): return result.reshape(label.shape) -class TestSoftmaxWithCrossEntropyOp(OpTest): +class TestSoftmaxWithCrossEntropyOp(XPUOpTest): """ Test softmax with cross entropy operator with discreate one-hot labels. """