diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 3228f5a556c2e0b7db2d04624e15d5a49b5b3d60..25d01912f141911ca26610dca232468c9e1b6993 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") if(NOT DEFINED XPU_BASE_URL) set(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220707") + set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220708") else() set(XPU_BASE_URL "${XPU_BASE_URL}") endif() @@ -19,7 +19,7 @@ endif() if(NOT DEFINED XPU_XDNN_BASE_URL) set(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") - set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220707") + set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220708") else() set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") endif() diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc index 67a1d70ebad4458033605d08047f012c1f4abf8f..613eea90a6500dab4db55d6d44fb15cd1ed50e39 100644 --- a/paddle/fluid/operators/activation_op_xpu.cc +++ b/paddle/fluid/operators/activation_op_xpu.cc @@ -157,15 +157,6 @@ struct XPUReciprocalGradFunctor : public BaseActivationFunctor { } }; -template -struct XPUReluFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::relu); - } -}; - template struct XPUReluGradFunctor : public BaseActivationFunctor { using XPUType = typename XPUTypeTrait::Type; @@ -416,6 +407,24 @@ struct XPUPowGradFunctor : public BaseActivationFunctor { } }; +template +struct XPUReluFunctor : public BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + void operator()(const framework::ExecutionContext &ctx) const { + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Out"); + const XPUType *x_data = reinterpret_cast(x->data()); + XPUType *y_data = + reinterpret_cast(y->mutable_data(ctx.GetPlace())); + + auto xpu_context = + ctx.device_context().x_context(); + int r = + xpu::relu(xpu_context, x_data, y_data, x->numel(), nullptr, nullptr); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu"); + } +}; + template struct XPUSoftPlusFunctor : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { diff --git a/paddle/fluid/operators/grid_sampler_op_xpu.cc b/paddle/fluid/operators/grid_sampler_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..2843a90492cec5baeb6848c619107f22e6cfcc39 --- /dev/null +++ b/paddle/fluid/operators/grid_sampler_op_xpu.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_XPU + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/device_wrapper.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GridSamplerXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(context.GetPlace()), + true, + platform::errors::Unavailable("This kernel only runs on XPU.")); + + // input and output data + const Tensor* input = context.Input("X"); + const Tensor* grid = context.Input("Grid"); + Tensor* output = context.Output("Output"); + + int n = input->dims()[0]; + int c = input->dims()[1]; + int h = input->dims()[2]; + int w = input->dims()[3]; + int out_h = grid->dims()[1]; + int out_w = grid->dims()[2]; + + // attrs + // paddle.nn.functional.grid_sample(x, grid, mode='bilinear', + // padding_mode='zeros', align_corners=True, name=None) + const std::string mode = context.Attr("mode"); + const std::string padding_mode = context.Attr("padding_mode"); + bool align_corners_bool = context.Attr("align_corners"); + const std::string data_format = + paddle::framework::DataLayoutToString(input->layout()); + + // attr to real param + bool is_nearest_bool; + if (mode == "bilinear") { + is_nearest_bool = false; + } else if (mode == "nearest") { + is_nearest_bool = true; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "should not reach here: mode should be either 'bilinear' or " + "'nearest', bot got %s.", + mode)); + } + + // attention: 0: zeros, 2: reflection, 1: border according to XDNN api. + int padding_mode_int; + if (padding_mode == "zeros") { + padding_mode_int = 0; + } else if (padding_mode == "reflection") { + padding_mode_int = 2; + } else if (padding_mode == "border") { + padding_mode_int = 1; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "should not reach here: padding_mode should be either 'zeros' or " + "'reflection' or 'border', bot got %s.", + padding_mode)); + } + + bool is_nchw_bool; + if (data_format == "NCHW") { + is_nchw_bool = true; + } else if (data_format == "NHWC") { + is_nchw_bool = false; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "should not reach here: data_format should be either 'NCHW' or " + "'NHWC', bot got %s.", + data_format)); + } + + // data pointers + const T* input_data = input->data(); + const T* grid_data = grid->data(); + T* output_data = + output->mutable_data({n, c, out_h, out_w}, context.GetPlace()); + + auto& dev_ctx = context.template device_context(); + // int grid_sample(Context* ctx, const T* x, const T* grid, T* y, int n, int + // c, int xh, int xw, int yh, int yw, bool is_nearest, bool align_corners, + // int padding_mode, bool is_nchw); + int r = xpu::grid_sample(dev_ctx.x_context(), + input_data, + grid_data, + output_data, + n, + c, + h, + w, + out_h, + out_w, + is_nearest_bool, + align_corners_bool, + padding_mode_int, + is_nchw_bool); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "grid_sampler"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + grid_sampler, + ops::GridSamplerXPUKernel); + +#endif diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index c5a70b03cd3c86df311786a68742f6de579b19ad..7e9c61289b67f42e9f0b6d6dc3e537fd421b4cec 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -240,6 +240,8 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"grid_sampler", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"hard_swish_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_grid_sampler_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_grid_sampler_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..967815cc559ee40c3214c3d69401e8140f6e13f7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_grid_sampler_op_xpu.py @@ -0,0 +1,284 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys + +sys.path.append("..") + +import paddle + +from op_test import OpTest +from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper + +paddle.enable_static() + + +def AffineGrid(theta, grid_shape): + n = grid_shape[0] + h = grid_shape[1] + w = grid_shape[2] + h_idx = np.repeat(np.linspace(-1, 1, h)[np.newaxis, :], w, + axis=0).T[:, :, np.newaxis] + w_idx = np.repeat(np.linspace(-1, 1, w)[np.newaxis, :], h, + axis=0)[:, :, np.newaxis] + grid = np.concatenate([w_idx, h_idx, np.ones([h, w, 1])], + axis=2) # h * w * 3 + grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3 + + ret = np.zeros([n, h * w, 2]) + theta = theta.transpose([0, 2, 1]) + for i in range(len(theta)): + ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i]) + + return ret.reshape([n, h, w, 2]).astype("float64") + + +def getGridPointValue(data, x, y): + data_shape = data.shape + N = data_shape[0] + C = data_shape[1] + in_H = data_shape[2] + in_W = data_shape[3] + out_H = x.shape[1] + out_W = x.shape[2] + + #out = np.zeros(data_shape, dtype='float64') + out = np.zeros([N, C, out_H, out_W], dtype='float64') + for i in range(N): + for j in range(out_H): + for k in range(out_W): + if y[i, j, k] < 0 or y[i, j, k] > in_H - 1 or x[ + i, j, k] < 0 or x[i, j, k] > in_W - 1: + out[i, :, j, k] = 0 + else: + out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]] + + return out + + +def clip(x, min_n, max_n): + return np.maximum(np.minimum(x, max_n), min_n) + + +def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): + if align_corners: + grid_slice = 0.5 * ((grid_slice.astype('float64') + 1.0) * max_val) + else: + grid_slice = 0.5 * ((grid_slice.astype('float64') + 1.0) * + (max_val + 1)) - 0.5 + + if padding_mode == "border": + grid_slice = clip(grid_slice, 0, max_val) + elif padding_mode == "reflection": + double_range = 2 * max_val if align_corners else (max_val + 1) * 2 + grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + + 0.5) + extra = grid_abs - np.floor(grid_abs / double_range) * double_range + grid_slice = np.minimum(extra, double_range - extra) + grid_slice = grid_slice if align_corners else clip( + grid_slice - 0.5, 0, max_val) + return grid_slice + + +def GridSampler(data, + grid, + align_corners=True, + mode="bilinear", + padding_mode="zeros"): + dims = data.shape + N = dims[0] + in_C = dims[1] + in_H = dims[2] + in_W = dims[3] + + out_H = grid.shape[1] + out_W = grid.shape[2] + + x = grid[:, :, :, 0] + y = grid[:, :, :, 1] + y_max = in_H - 1 + x_max = in_W - 1 + + x = unnormalizeAndClip(x, x_max, align_corners, padding_mode) + y = unnormalizeAndClip(y, y_max, align_corners, padding_mode) + + if mode == "bilinear": + x0 = np.floor(x).astype('int32') + x1 = x0 + 1 + y0 = np.floor(y).astype('int32') + y1 = y0 + 1 + + wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + + va = getGridPointValue(data, x0, y0) + vb = getGridPointValue(data, x0, y1) + vc = getGridPointValue(data, x1, y0) + vd = getGridPointValue(data, x1, y1) + + out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float64') + elif mode == "nearest": + x = np.round(x).astype('int32') + y = np.round(y).astype('int32') + out = getGridPointValue(data, x, y) + return out + + +class XPUTestGridSamplerOP(XPUOpTestWrapper): + + def __init__(self): + self.op_name = 'grid_sampler' + self.use_dynamic_create_class = False + + class TestXPUGridSamplerOp(XPUOpTest): + + def setUp(self): + self.place = paddle.XPUPlace(0) + self.init_dtype() + self.op_type = 'grid_sampler' + + self.use_cudnn = False + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + + self.initTestCase() + + x = np.random.uniform(-10, 10, self.x_shape).astype(self.dtype) + + theta = np.zeros(self.theta_shape).astype(self.dtype) + for i in range(self.theta_shape[0]): + for j in range(2): + for k in range(3): + theta[i, j, k] = np.random.rand(1)[0] + grid = AffineGrid(theta, self.grid_shape).astype(self.dtype) + + self.inputs = {'X': x, 'Grid': grid} + self.attrs = { + 'use_cudnn': self.use_cudnn, + "align_corners": self.align_corners, + "padding_mode": self.padding_mode, + "mode": self.mode, + } + self.outputs = { + 'Output': + GridSampler(x, grid, self.align_corners, self.mode, + self.padding_mode) + } + + def initTestCase(self): + self.x_shape = (2, 3, 8, 8) + self.grid_shape = (2, 7, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + + def init_dtype(self): + self.dtype = self.in_type + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X', 'Grid'], 'Output') + + class TestGridSample1(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "zeros" + self.mode = "bilinear" + + class TestGridSample2(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "border" + self.mode = "bilinear" + + class TestGridSample3(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "reflection" + self.mode = "bilinear" + + class TestGridSample4(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "reflection" + self.mode = "bilinear" + + class TestGridSample5(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "reflection" + self.mode = "nearest" + + class TestGridSample6(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 128, 128) + self.grid_shape = (2, 130, 130, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "reflection" + self.mode = "bilinear" + + class TestGridSample7(TestXPUGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 128, 128) + self.grid_shape = (2, 130, 130, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + + +support_types = get_xpu_op_support_types('grid_sampler') +for stype in support_types: + create_test_class(globals(), XPUTestGridSamplerOP, stype) + +if __name__ == '__main__': + unittest.main()