diff --git a/paddle/fluid/operators/activation_op_mlu.cc b/paddle/fluid/operators/activation_op_mlu.cc index d1087965f044e35e8c2f7b79bd7fc082cd47b770..72e0e9ceacf48e90919877c9f2667efef6d92fbf 100644 --- a/paddle/fluid/operators/activation_op_mlu.cc +++ b/paddle/fluid/operators/activation_op_mlu.cc @@ -399,6 +399,25 @@ class HardSigmoidGradMLUKernel : public framework::OpKernel { } }; +template +class FloorMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + output->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc input_desc(*input); + MLUCnnlTensorDesc output_desc(*output); + + MLUCnnl::Floor(ctx, + input_desc.get(), + GetBasePtr(input), + output_desc.get(), + GetBasePtr(output)); + } +}; + template class ReciprocalMLUKernel : public framework::OpKernel { public: @@ -589,3 +608,7 @@ REGISTER_OP_MLU_KERNEL( hard_sigmoid_grad, ops::HardSigmoidGradMLUKernel, ops::HardSigmoidGradMLUKernel); + +REGISTER_OP_MLU_KERNEL(floor, + ops::FloorMLUKernel, + ops::FloorMLUKernel); diff --git a/paddle/fluid/operators/grid_sampler_op_mlu.cc b/paddle/fluid/operators/grid_sampler_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..8327eaad144257b871e0f28d3204d3a1f8934563 --- /dev/null +++ b/paddle/fluid/operators/grid_sampler_op_mlu.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class GridSamplerMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_mlu_place(ctx.GetPlace()), + true, + platform::errors::Unavailable("This kernel only runs on MLU.")); + + // input and output data + const Tensor* input = ctx.Input("X"); + const Tensor* grid = ctx.Input("Grid"); + Tensor* output = ctx.Output("Output"); + + int n = input->dims()[0]; + int c = input->dims()[1]; + int out_h = grid->dims()[1]; + int out_w = grid->dims()[2]; + + output->mutable_data({n, c, out_h, out_w}, ctx.GetPlace()); + + // attrs + // paddle.nn.functional.grid_sample(x, grid, mode='bilinear', + // padding_mode='zeros', align_corners=True, name=None) + const std::string mode = ctx.Attr("mode"); + const std::string padding_mode = ctx.Attr("padding_mode"); + bool align_corners = ctx.Attr("align_corners"); + const std::string data_format = + paddle::framework::DataLayoutToString(input->layout()); + + PADDLE_ENFORCE_EQ( + mode == "bilinear", + true, + platform::errors::Unavailable( + "Only support bilinear mode in mlu grid_sample kernel.")); + PADDLE_ENFORCE_EQ( + padding_mode == "zeros", + true, + platform::errors::Unavailable( + "Only support zeros padding_mode in mlu grid_sample kernel.")); + + Tensor trans_input(input->dtype()); + // transpose input from NCHW to NHWC + const std::vector perm_to_nhwc = {0, 2, 3, 1}; + TransposeFromMLUTensor( + ctx, perm_to_nhwc, input, &trans_input, true /*need_reshape_or_alloc*/); + + Tensor tmp_output(output->dtype()); + tmp_output.mutable_data({n, out_h, out_w, c}, ctx.GetPlace()); + + MLUCnnlGridSampleDesc grid_sample_desc(mode, padding_mode, align_corners); + MLUCnnlTensorDesc input_desc( + trans_input, CNNL_LAYOUT_NHWC, ToCnnlDataType()); + MLUCnnlTensorDesc grid_desc(*grid, CNNL_LAYOUT_NHWC, ToCnnlDataType()); + MLUCnnlTensorDesc tmp_output_desc( + tmp_output, CNNL_LAYOUT_NHWC, ToCnnlDataType()); + + MLUCnnl::GridSample(ctx, + grid_sample_desc.get(), + input_desc.get(), + GetBasePtr(&trans_input), + grid_desc.get(), + GetBasePtr(grid), + tmp_output_desc.get(), + GetBasePtr(&tmp_output)); + + // transpose output from NHWC to NCHW + const std::vector perm_to_nchw = { + 0, + 3, + 1, + 2, + }; + TransposeFromMLUTensor(ctx, + perm_to_nchw, + &tmp_output, + output, + false /*need_reshape_or_alloc*/); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(grid_sampler, + ops::GridSamplerMLUKernel, + ops::GridSamplerMLUKernel); diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 95a365f459f18033b9712ed156efe9ef5e6a9faf..4cd754775d9c0ab3a3ff3f1a5109807de81f3169 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -622,6 +622,29 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() { } } +MLUCnnlGridSampleDesc::MLUCnnlGridSampleDesc( + const std::string& interp_mode_str, + const std::string& padding_mode_str, + bool align_corners) { + cnnlInterpMode_t interp_mode = CNNL_INTERP_BILINEAR; + cnnlGridSamplePaddingMode_t padding_mode = CNNL_GRIDSAMPLE_PADDING_ZEROS; + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlCreateGridSampleDescriptor(&grid_sample_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetGridSampleDescriptor( + grid_sample_desc_, interp_mode, padding_mode, align_corners)); +} + +const cnnlGridSampleDescriptor_t MLUCnnlGridSampleDesc::get() const { + return grid_sample_desc_; +} + +MLUCnnlGridSampleDesc::~MLUCnnlGridSampleDesc() { + if (grid_sample_desc_) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlDestroyGridSampleDescriptor(grid_sample_desc_)); + } +} + MLUSeqDataDesc::MLUSeqDataDesc(cnnlSeqDataLayout_t layout, cnnlDataType_t dtype, int dimNb, @@ -4918,6 +4941,38 @@ MLURNNDesc::~MLURNNDesc() { grads_image)); } +/* static */ void MLUCnnl::GridSample( + const ExecutionContext& ctx, + const cnnlGridSampleDescriptor_t grid_sample_desc, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t grid_desc, + const void* grid, + const cnnlTensorDescriptor_t output_desc, + void* output) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetGridSampleForwardWorkspaceSize( + handle, input_desc, grid_desc, output_desc, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGridSampleForward(handle, + grid_sample_desc, + input_desc, + input, + grid_desc, + grid, + output_desc, + output, + workspace_ptr, + workspace_size)); +} + /* static */ void MLUCnnl::SyncBatchNormStats( const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 72446f56a18dc89c4d0abdd4c21532431969e4a6..e56331b2728c4353dcc5aec833ba263db66d58e1 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -495,6 +495,20 @@ class MLUCnnlDCNDesc { cnnlDCNDescriptor_t dcn_desc_ = nullptr; }; +class MLUCnnlGridSampleDesc { + public: + MLUCnnlGridSampleDesc(const std::string& interp_mode_str, + const std::string& padding_mode_str, + bool align_corners); + + const cnnlGridSampleDescriptor_t get() const; + + ~MLUCnnlGridSampleDesc(); + + private: + cnnlGridSampleDescriptor_t grid_sample_desc_ = nullptr; +}; + class MLUSeqDataDesc { public: MLUSeqDataDesc(const MLUSeqDataDesc& desc) = delete; @@ -2040,6 +2054,15 @@ class MLUCnnl { const cnnlTensorDescriptor_t grads_image_desc, void* grads_image); + static void GridSample(const ExecutionContext& ctx, + const cnnlGridSampleDescriptor_t grid_sample_desc, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t grid_desc, + const void* grid, + const cnnlTensorDescriptor_t output_desc, + void* output); + static void SyncBatchNormStats(const ExecutionContext& ctx, const cnnlTensorDescriptor_t x_desc, const void* x, diff --git a/paddle/fluid/platform/device/mlu/mlu_info.h b/paddle/fluid/platform/device/mlu/mlu_info.h index 14f37879ef070db871605d57194d20f62f73bd25..032606dd1c5b01efd9649b6d9cd894d4842a5073 100644 --- a/paddle/fluid/platform/device/mlu/mlu_info.h +++ b/paddle/fluid/platform/device/mlu/mlu_info.h @@ -16,9 +16,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_MLU #include -#include #include #include +#include #include #ifdef PADDLE_WITH_CNCL #include diff --git a/python/paddle/fluid/tests/unittests/mlu/test_floor_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_floor_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..275e3bd14a7bc201846a8eda7835c6a5a59d6c68 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_floor_op_mlu.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest + +import numpy as np +import sys + +sys.path.append('..') +from op_test import OpTest +import paddle + +paddle.enable_static() + + +class TestFloor(OpTest): + + def setUp(self): + self.op_type = "floor" + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.init_dtype() + self.__class__.no_need_check_grad = True + self.python_api = paddle.floor + + np.random.seed(1024) + x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) + out = np.floor(x) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output_with_place(self.place, check_eager=False) + + def init_dtype(self): + self.dtype = np.float32 + + +class TestFloorFP16(TestFloor): + + def init_dtype(self): + self.dtype = np.float16 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..96dbaab9ee157567a0896dea6ea5217d213407cc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_grid_sampler_op_mlu.py @@ -0,0 +1,223 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy as np +import paddle.fluid.core as core +import sys + +sys.path.append('..') +from op_test import OpTest + +paddle.enable_static() + + +def AffineGrid(theta, grid_shape): + n = grid_shape[0] + h = grid_shape[1] + w = grid_shape[2] + h_idx = np.repeat(np.linspace(-1, 1, h)[np.newaxis, :], w, + axis=0).T[:, :, np.newaxis] + w_idx = np.repeat(np.linspace(-1, 1, w)[np.newaxis, :], h, + axis=0)[:, :, np.newaxis] + grid = np.concatenate([w_idx, h_idx, np.ones([h, w, 1])], + axis=2) # h * w * 3 + grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * h * w *3 + + ret = np.zeros([n, h * w, 2]) + theta = theta.transpose([0, 2, 1]) + for i in range(len(theta)): + ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i]) + + return ret.reshape([n, h, w, 2]).astype("float32") + + +def getGridPointValue(data, x, y): + data_shape = data.shape + N = data_shape[0] + C = data_shape[1] + in_H = data_shape[2] + in_W = data_shape[3] + out_H = x.shape[1] + out_W = x.shape[2] + + #out = np.zeros(data_shape, dtype='float32') + out = np.zeros([N, C, out_H, out_W], dtype='float32') + for i in range(N): + for j in range(out_H): + for k in range(out_W): + if y[i, j, k] < 0 or y[i, j, k] > in_H - 1 or x[ + i, j, k] < 0 or x[i, j, k] > in_W - 1: + out[i, :, j, k] = 0 + else: + out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]] + + return out + + +def clip(x, min_n, max_n): + return np.maximum(np.minimum(x, max_n), min_n) + + +def unnormalizeAndClip(grid_slice, max_val, align_corners, padding_mode): + if align_corners: + grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * max_val) + else: + grid_slice = 0.5 * ((grid_slice.astype('float32') + 1.0) * + (max_val + 1)) - 0.5 + + if padding_mode == "border": + grid_slice = clip(grid_slice, 0, max_val) + elif padding_mode == "reflection": + double_range = 2 * max_val if align_corners else (max_val + 1) * 2 + grid_abs = np.abs(grid_slice) if align_corners else np.abs(grid_slice + + 0.5) + extra = grid_abs - np.floor(grid_abs / double_range) * double_range + grid_slice = np.minimum(extra, double_range - extra) + grid_slice = grid_slice if align_corners else clip( + grid_slice - 0.5, 0, max_val) + return grid_slice + + +def GridSampler(data, + grid, + align_corners=True, + mode="bilinear", + padding_mode="zeros"): + dims = data.shape + N = dims[0] + in_C = dims[1] + in_H = dims[2] + in_W = dims[3] + + out_H = grid.shape[1] + out_W = grid.shape[2] + + x = grid[:, :, :, 0] + y = grid[:, :, :, 1] + y_max = in_H - 1 + x_max = in_W - 1 + + x = unnormalizeAndClip(x, x_max, align_corners, padding_mode) + y = unnormalizeAndClip(y, y_max, align_corners, padding_mode) + + if mode == "bilinear": + x0 = np.floor(x).astype('int32') + x1 = x0 + 1 + y0 = np.floor(y).astype('int32') + y1 = y0 + 1 + + wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, out_H, out_W)), + (1, in_C, 1, 1)) + + va = getGridPointValue(data, x0, y0) + vb = getGridPointValue(data, x0, y1) + vc = getGridPointValue(data, x1, y0) + vd = getGridPointValue(data, x1, y1) + + out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32') + elif mode == "nearest": + x = np.round(x).astype('int32') + y = np.round(y).astype('int32') + out = getGridPointValue(data, x, y) + return out + + +class TestGridSamplerOp(OpTest): + + def setUp(self): + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.__class__.no_need_check_grad = True + self.op_type = 'grid_sampler' + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + self.initTestCase() + x = np.random.randint(0, 255, self.x_shape).astype('float32') + + theta = np.zeros(self.theta_shape).astype('float32') + for i in range(self.theta_shape[0]): + for j in range(2): + for k in range(3): + theta[i, j, k] = np.random.rand(1)[0] + grid = AffineGrid(theta, self.grid_shape) + + self.inputs = {'X': x, 'Grid': grid} + self.attrs = { + 'use_cudnn': False, + "align_corners": self.align_corners, + "padding_mode": self.padding_mode, + "mode": self.mode + } + self.outputs = { + 'Output': + GridSampler(x, grid, self.align_corners, self.mode, + self.padding_mode) + } + + def test_check_output(self): + self.check_output_with_place(self.place) + + def initTestCase(self): + self.x_shape = (2, 3, 8, 8) + self.grid_shape = (2, 7, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "zeros" + self.mode = "bilinear" + + +class Case1(TestGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 5, 6) + self.grid_shape = (2, 8, 9, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + + +class LargeInputCase(TestGridSamplerOp): + + def initTestCase(self): + self.x_shape = (2, 3, 128, 128) + self.grid_shape = (2, 130, 130, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = False + self.padding_mode = "zeros" + self.mode = "bilinear" + + +class Case2(LargeInputCase): + + def initTestCase(self): + self.x_shape = (2, 3, 128, 128) + self.grid_shape = (2, 130, 130, 2) + self.theta_shape = (2, 2, 3) + self.align_corners = True + self.padding_mode = "zeros" + self.mode = "bilinear" + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/dockerfile/Dockerfile.mlu b/tools/dockerfile/Dockerfile.mlu index 07535a637431e49ddef65a648b601697e05c1162..3fa563ff65e1debd4b3b499d004b73b8ea561c4b 100644 --- a/tools/dockerfile/Dockerfile.mlu +++ b/tools/dockerfile/Dockerfile.mlu @@ -2,14 +2,14 @@ # Update CNTOOLKIT_VERSION, CNNL_VERSION and CNCL_VERSION if using other versions # # Build: -# - CNTOOLKIT_VERSION 2.8.1-1 -# - CNNL_VERSION 1.9.3-1 -# - CNCL_VERSION 1.0.4-1 +# - CNTOOLKIT_VERSION 3.0.0-1 +# - CNNL_VERSION 1.11.0-1 +# - CNCL_VERSION 1.2.0-1 # # Download three packages from FTP (need to connect cambricon AE to get FTP url) -# - cntoolkit_2.6.5-1.ubuntu18.04_amd64.deb -# - cnnl_1.8.3-1.ubuntu18.04_amd64.deb -# - cncl_1.0.2-1.ubuntu18.04_amd64.deb +# - cntoolkit_3.0.0-1.ubuntu18.04_amd64.deb +# - cnnl_1.11.0-1.ubuntu18.04_amd64.deb +# - cncl_1.2.0-1.ubuntu18.04_amd64.deb # copy them to current directory first, then run build commands # # For example: @@ -21,9 +21,9 @@ # (get cncl pkg) # # docker build -f Dockerfile.mlu \ -# --build-arg CNTOOLKIT_VERSION=2.8.1-1 \ -# --build-arg CNNL_VERSION=1.9.3-1 \ -# --build-arg CNCL_VERSION=1.0.4-1 \ +# --build-arg CNTOOLKIT_VERSION=3.0.0-1 \ +# --build-arg CNNL_VERSION=1.11.0-1 \ +# --build-arg CNCL_VERSION=1.2.0-1 \ # -t paddlepaddle/paddle:latest-dev-mlu . # # without mlu device: @@ -40,9 +40,9 @@ MAINTAINER PaddlePaddle Authors ENV WITH_GPU=OFF -ARG CNTOOLKIT_VERSION=2.8.1-1 -ARG CNNL_VERSION=1.9.3-1 -ARG CNCL_VERSION=1.0.4-1 +ARG CNTOOLKIT_VERSION=3.0.0-1 +ARG CNNL_VERSION=1.11.0-1 +ARG CNCL_VERSION=1.2.0-1 ARG CNTOOLKIT_PKG=cntoolkit_$CNTOOLKIT_VERSION.ubuntu18.04_amd64.deb ARG CNNL_PKG=cnnl_$CNNL_VERSION.ubuntu18.04_amd64.deb ARG CNCL_PKG=cncl_$CNCL_VERSION.ubuntu18.04_amd64.deb