diff --git a/paddle/fluid/operators/interpolate_op_npu.cc b/paddle/fluid/operators/interpolate_op_npu.cc new file mode 100755 index 0000000000000000000000000000000000000000..8d4b1e00c5d89a49b5dc55c994c1419ecd686d93 --- /dev/null +++ b/paddle/fluid/operators/interpolate_op_npu.cc @@ -0,0 +1,214 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/interpolate_op.h" +#include +#include +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; + +inline static void CheckArgument(const framework::ExecutionContext& ctx) { + const std::string interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + PADDLE_ENFORCE_EQ( + align_corners, false, + platform::errors::InvalidArgument( + "NPU Interpolate Kernel has diff when align_corners is true.")); + PADDLE_ENFORCE_EQ( + interp_method, "nearest", + platform::errors::InvalidArgument( + "NPU Interpolate Kernel only support nearest interpolotion.")); +} + +inline static void ExtractNCHW(const framework::DDim& dims, + const DataLayout& data_layout, int32_t* n, + int32_t* c, int32_t* h, int32_t* w) { + *n = dims[0]; + if (data_layout == DataLayout::kNCHW) { + *c = dims[1]; + *h = dims[2]; + *w = dims[3]; + } else { // kNHWC + *h = dims[1]; + *w = dims[2]; + *c = dims[3]; + } +} + +static void CalcOutSize(const framework::ExecutionContext& ctx, int32_t in_h, + int32_t in_w, int32_t* out_h, int32_t* out_w) { + // Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w + *out_h = ctx.Attr("out_h"); + *out_w = ctx.Attr("out_w"); + + auto dev_ctx = platform::DeviceContextPool::Instance().Get(ctx.GetPlace()); + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + + if (list_new_size_tensor.size() > 0) { + std::vector new_size_h(1); + std::vector new_size_w(1); + framework::TensorToVector(*list_new_size_tensor[0], *dev_ctx, &new_size_h); + framework::TensorToVector(*list_new_size_tensor[1], *dev_ctx, &new_size_w); + *out_h = new_size_h[0]; + *out_w = new_size_w[0]; + } else { + float scale; + auto scale_tensor = ctx.Input("Scale"); + if (scale_tensor != nullptr) { + std::vector scale_data; + framework::TensorToVector(*scale_tensor, *dev_ctx, &scale_data); + scale = scale_data[0]; + } else { + scale = ctx.Attr("scale"); + } + + if (scale > 0) { + *out_h = static_cast(in_h * scale); + *out_w = static_cast(in_w * scale); + } + + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + std::vector out_size_data; + framework::TensorToVector(*out_size, *dev_ctx, &out_size_data); + *out_h = out_size_data[0]; + *out_w = out_size_data[1]; + } + } + + PADDLE_ENFORCE_GT(*out_h, 0, + platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(*out_w, 0, + platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of Op(interpolate) " + "should be greater than 0.")); +} + +template +class InterpolateNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // NOTE(Ruibiao): + // this kernel only support nearest interpolotion for 2D images + // the Ascend 'ResizeNearestNeighborV2' used in this kernle has diff + // when 'align_corners' is 'true' or data type is 'double' + CheckArgument(ctx); + + auto* input = ctx.Input("X"); + framework::DDim input_dims = input->dims(); + + const std::string data_layout_str = + ctx.Attr("data_layout"); // kNCHW or kNHWC + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + int32_t n, c, h, w, out_h, out_w; + ExtractNCHW(input_dims, data_layout, &n, &c, &h, &w); + CalcOutSize(ctx, h, w, &out_h, &out_w); + + // the 'input' tensor may has no set (or wrong set) of the layout + Tensor input_x(input->type()); + input_x.ShareDataWith(*input); + input_x.set_layout(data_layout); + + auto* output = ctx.Output("Out"); + framework::DDim output_dims; + if (data_layout == DataLayout::kNCHW) { + output_dims = {n, c, out_h, out_w}; + } else { + output_dims = {n, out_h, out_w, c}; + } + output->set_layout(data_layout); + output->mutable_data(output_dims, ctx.GetPlace()); + + NpuOpRunner npu_op_runner; + auto npu_stream = + ctx.template device_context() + .stream(); + npu_op_runner.SetType("ResizeNearestNeighborV2") + .AddInput(input_x) + .AddInput(std::vector{out_h, out_w}) + .AddOutput(*output) + .AddAttr("align_corners", false) + .AddAttr("half_pixel_centers", false) + .Run(npu_stream); + } +}; + +template +class InterpolateGradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // NOTE(Ruibiao): + // this kernel only support nearest interpolotion for 2D images + // the Ascend 'ResizeNearestNeighborV2' used in this kernle has diff + // when 'align_corners' is 'true' or data type is 'double' + CheckArgument(ctx); + + auto* input = ctx.Input("X"); + framework::DDim input_dims = input->dims(); + + const std::string data_layout_str = + ctx.Attr("data_layout"); // kNCHW or kNHWC + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + int32_t n, c, h, w, out_h, out_w; + ExtractNCHW(input_dims, data_layout, &n, &c, &h, &w); + CalcOutSize(ctx, h, w, &out_h, &out_w); + + // the 'output_grad' tensor may has no set (or wrong set) of the layout + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + Tensor output_grad_tmp(output_grad->type()); + output_grad_tmp.ShareDataWith(*output_grad); + output_grad_tmp.set_layout(data_layout); + + auto* input_grad = ctx.Output(framework::GradVarName("X")); + input_grad->set_layout(data_layout); + framework::DDim input_grad_dims; + if (data_layout == DataLayout::kNCHW) { + input_grad_dims = {n, c, h, w}; + } else { + input_grad_dims = {n, h, w, c}; + } + input_grad->mutable_data(input_grad_dims, ctx.GetPlace()); + + NpuOpRunner npu_op_runner; + auto npu_stream = + ctx.template device_context() + .stream(); + npu_op_runner.SetType("ResizeNearestNeighborV2Grad") + .AddInput(output_grad_tmp) + .AddInput(std::vector{h, w}) + .AddOutput(*input_grad) + .AddAttr("align_corners", false) + .AddAttr("half_pixel_centers", false) + .Run(npu_stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL(nearest_interp, ops::InterpolateNPUKernel, + ops::InterpolateNPUKernel); +REGISTER_OP_NPU_KERNEL(nearest_interp_grad, + ops::InterpolateGradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_nearest_interp_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_nearest_interp_op_npu.py new file mode 100755 index 0000000000000000000000000000000000000000..c6f85c8dee40ce57b1cca0008ed45cfa4b80c889 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_nearest_interp_op_npu.py @@ -0,0 +1,461 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from test_nearest_interp_op import nearest_neighbor_interp_np + +paddle.enable_static() + + +class TestNearestInterpOp(OpTest): + def setUp(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "nearest_interp" + input_np = np.random.random(self.input_shape).astype("float32") + if self.data_layout == "NCHW": + in_h = self.input_shape[2] + in_w = self.input_shape[3] + else: + in_h = self.input_shape[1] + in_w = self.input_shape[2] + + if self.scale > 0: + out_h = int(in_h * self.scale) + out_w = int(in_w * self.scale) + else: + out_h = self.out_h + out_w = self.out_w + + output_np = nearest_neighbor_interp_np( + input_np, out_h, out_w, self.out_size, self.actual_shape, + self.align_corners, self.data_layout) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.__class__.no_need_check_grad = True + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 4, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpCase1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = False + + +class TestNearestNeighborInterpCase2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.align_corners = False + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], 'Out', in_place=True, max_relative_error=0.006) + + +class TestNearestNeighborInterpCase3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.align_corners = False + + +class TestNearestNeighborInterpCase4(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpCase5(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpCase6(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([65, 129]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpSame(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 32, 64] + self.out_h = 32 + self.out_w = 64 + self.scale = 0. + self.align_corners = False + + +class TestNearestNeighborInterpActualShape(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpDataLayout(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 4, 4, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 8]).astype("int32") + self.align_corners = False + self.data_layout = "NHWC" + + +class TestNearestInterpOpUint8(OpTest): + def setUp(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp" + input_np = np.random.randint( + low=0, high=256, size=self.input_shape).astype("uint8") + + if self.scale > 0: + out_h = int(self.input_shape[2] * self.scale) + out_w = int(self.input_shape[3] * self.scale) + else: + out_h = self.out_h + out_w = self.out_w + + output_np = nearest_neighbor_interp_np(input_np, out_h, out_w, + self.out_size, self.actual_shape, + self.align_corners) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 3, 9, 6] + self.out_h = 10 + self.out_w = 9 + self.scale = 0. + self.align_corners = False + + +class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 32, 64] + self.out_h = 80 + self.out_w = 40 + self.scale = 0. + self.align_corners = False + + +class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.scale = 0. + self.out_size = np.array([6, 15]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpScale1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 7, 5] + self.out_h = 64 + self.out_w = 32 + self.scale = 2. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpScale2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 5, 7] + self.out_h = 64 + self.out_w = 32 + self.scale = 1.5 + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = False + + +class TestNearestNeighborInterpScale3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 7, 5] + self.out_h = 64 + self.out_w = 32 + self.scale = 1. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = False + + +class TestNearestInterpOp_attr_tensor(OpTest): + def setUp(self): + self.__class__.use_npu = True + self.place = paddle.NPUPlace(0) + + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp" + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + self.attrs = { + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + } + + input_np = np.random.random(self.input_shape).astype("float32") + self.inputs = {'X': input_np} + + if self.scale_by_1Dtensor: + self.inputs['Scale'] = np.array([self.scale]).astype("float64") + elif self.scale > 0: + out_h = int(self.input_shape[2] * self.scale) + out_w = int(self.input_shape[3] * self.scale) + self.attrs['scale'] = self.scale + else: + out_h = self.out_h + out_w = self.out_w + + if self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.out_size is not None: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs['out_h'] = self.out_h + self.attrs['out_w'] = self.out_w + output_np = nearest_neighbor_interp_np(input_np, out_h, out_w, + self.out_size, self.actual_shape, + self.align_corners) + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 5, 4, 4] + self.out_h = 3 + self.out_w = 3 + self.scale = 0. + self.out_size = [3, 3] + self.align_corners = False + + +# out_size is a tensor list +class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = [8, 12] + self.align_corners = False + + +# out_size is a 1-D tensor +class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = False + self.shape_by_1Dtensor = True + + +# scale is a 1-D tensor +class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 2.0 + self.out_size = None + self.align_corners = False + self.scale_by_1Dtensor = True + + +class TestNearestAPI(unittest.TestCase): + def test_case(self): + x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32") + y = fluid.data(name="y", shape=[2, 6, 6, 3], dtype="float32") + + dim = fluid.data(name="dim", shape=[1], dtype="int32") + shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32") + actual_size = fluid.data(name="actual_size", shape=[2], dtype="int32") + scale_tensor = fluid.data( + name="scale_tensor", shape=[1], dtype="float32") + + out1 = fluid.layers.resize_nearest( + y, out_shape=[12, 12], data_format='NHWC', align_corners=False) + out2 = fluid.layers.resize_nearest( + x, out_shape=[12, dim], align_corners=False) + out3 = fluid.layers.resize_nearest( + x, out_shape=shape_tensor, align_corners=False) + out4 = fluid.layers.resize_nearest( + x, out_shape=[4, 4], actual_shape=actual_size, align_corners=False) + out5 = fluid.layers.resize_nearest( + x, scale=scale_tensor, align_corners=False) + + x_data = np.random.random((2, 3, 6, 6)).astype("float32") + dim_data = np.array([12]).astype("int32") + shape_data = np.array([12, 12]).astype("int32") + actual_size_data = np.array([12, 12]).astype("int32") + scale_data = np.array([2.0]).astype("float32") + + place = paddle.NPUPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={ + "x": x_data, + "y": np.transpose(x_data, (0, 2, 3, 1)), + "dim": dim_data, + "shape_tensor": shape_data, + "actual_size": actual_size_data, + "scale_tensor": scale_data + }, + fetch_list=[out1, out2, out3, out4, out5], + return_numpy=True) + + expect_res = nearest_neighbor_interp_np( + x_data, out_h=12, out_w=12, align_corners=False) + self.assertTrue( + np.allclose(results[0], np.transpose(expect_res, (0, 2, 3, 1)))) + for i in range(len(results) - 1): + self.assertTrue(np.allclose(results[i + 1], expect_res)) + + +class TestNearestInterpException(unittest.TestCase): + def test_exception(self): + input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32") + + def attr_data_format(): + # for 4-D input, data_format can only be NCHW or NHWC + out = fluid.layers.resize_nearest( + input, out_shape=[4, 8], data_format='NDHWC') + + def attr_scale_type(): + out = fluid.layers.resize_nearest(input, scale='scale') + + def attr_scale_value(): + out = fluid.layers.resize_nearest(input, scale=-0.3) + + self.assertRaises(ValueError, attr_data_format) + self.assertRaises(TypeError, attr_scale_type) + self.assertRaises(ValueError, attr_scale_value) + + +if __name__ == "__main__": + unittest.main()