From 82aa01c3737c550addc24cd7f98ea26899d5b496 Mon Sep 17 00:00:00 2001 From: TTerror Date: Tue, 22 Dec 2020 14:31:51 +0800 Subject: [PATCH] add nearest_interp_v2 on kunlun (#29725) * add nearest_interp_v2 on kunlun * add nearest_interp_v2 on kunlun --- .../fluid/operators/interpolate_v2_op_xpu.cc | 294 +++++++++++++ .../xpu/test_nearest_interp_v2_op_xpu.py | 415 ++++++++++++++++++ 2 files changed, 709 insertions(+) create mode 100644 paddle/fluid/operators/interpolate_v2_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_v2_op_xpu.py diff --git a/paddle/fluid/operators/interpolate_v2_op_xpu.cc b/paddle/fluid/operators/interpolate_v2_op_xpu.cc new file mode 100644 index 00000000000..c960f9a58be --- /dev/null +++ b/paddle/fluid/operators/interpolate_v2_op_xpu.cc @@ -0,0 +1,294 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/interpolate_op.h" + +#ifdef PADDLE_WITH_XPU + +namespace paddle { +namespace operators { + +using framework::Tensor; +using DataLayout = framework::DataLayout; + +inline std::vector get_new_shape_xpu( + const std::vector& list_new_shape_tensor) { + // get tensor from + std::vector vec_new_shape; + for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) { + auto tensor = list_new_shape_tensor[i]; + PADDLE_ENFORCE_EQ( + tensor->dims(), framework::make_ddim({1}), + platform::errors::InvalidArgument("shape of dim tensor should be [1]")); + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_new_shape.push_back(static_cast(*temp.data())); + } + + return vec_new_shape; +} + +template +inline std::vector get_new_data_from_tensor_xpu( + const Tensor* new_data_tensor) { + std::vector vec_new_data; + framework::Tensor cpu_starts_tensor; + TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor); + auto* new_data = cpu_starts_tensor.data(); + vec_new_data = std::vector(new_data, new_data + new_data_tensor->numel()); + return vec_new_data; +} + +template +class InterpolateV2XPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + auto input_dims = input->dims(); + PADDLE_ENFORCE_EQ( + input_dims.size(), 4, + platform::errors::External("XPU Interpolate kernel only support 2d")); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale_h = -1; + float scale_w = -1; + + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape_xpu(list_new_size_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } else { + auto scale_tensor = ctx.Input("Scale"); + auto scale = ctx.Attr>("scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor_xpu(scale_tensor); + if (scale_data.size() > 1) { + scale_h = scale_data[0]; + scale_w = scale_data[1]; + } else { + scale_h = scale_data[0]; + scale_w = scale_data[0]; + } + PADDLE_ENFORCE_EQ( + scale_w > 0 && scale_h > 0, true, + platform::errors::InvalidArgument("scale of Op(interpolate) " + "should be greater than 0.")); + } else { + if (scale.size() > 1) { + scale_h = scale[0]; + scale_w = scale[1]; + + PADDLE_ENFORCE_EQ( + scale_w > 0 && scale_h > 0, true, + platform::errors::InvalidArgument("scale of Op(interpolate) " + "should be greater than 0.")); + } + } + if (scale_h > 0. && scale_w > 0.) { + out_h = static_cast(in_h * scale_h); + out_w = static_cast(in_w * scale_w); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + } + PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument( + "out_h in Attr(out_shape) of " + "Op(interpolate) " + "should be greater than 0.")); + PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument( + "out_w in Attr(out_shape) of " + "Op(interpolate) " + "should be greater than 0.")); + framework::DDim dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {n, c, out_h, out_w}; + } else { + dim_out = {n, out_h, out_w, c}; + } + output->mutable_data(dim_out, ctx.GetPlace()); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*input, ctx.GetPlace(), output); + return; + } + bool nearest = "nearest" == interp_method; + int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); + auto& dev_ctx = ctx.template device_context(); + if (nearest) { + PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true, + platform::errors::InvalidArgument( + "XPU nearest is only support NCHW")); + } + int r = xpu::interpolate2d(dev_ctx.x_context(), input->data(), + output->data(), n, c, in_h, in_w, out_h, + out_w, nearest, trans_mode, + (data_layout == DataLayout::kNCHW)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU interpolate2d kernel " + "return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +template +class InterpolateV2GradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + + auto output_grad_dims = output_grad->dims(); + + PADDLE_ENFORCE_EQ(output_grad_dims.size(), 4, + platform::errors::External( + "XPU Interpolategrad kernel only support 2d")); + + auto* input = ctx.Input("X"); + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + int n, c, in_d, in_h, in_w; + ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w); + + auto interp_method = ctx.Attr("interp_method"); + bool align_corners = ctx.Attr("align_corners"); + int align_mode = ctx.Attr("align_mode"); + + int out_h = ctx.Attr("out_h"); + int out_w = ctx.Attr("out_w"); + float scale_h = -1; + float scale_w = -1; + + auto list_new_size_tensor = ctx.MultiInput("SizeTensor"); + if (list_new_size_tensor.size() > 0) { + // have size tensor + auto new_size = get_new_shape_xpu(list_new_size_tensor); + out_h = new_size[0]; + out_w = new_size[1]; + } else { + auto scale_tensor = ctx.Input("Scale"); + auto scale = ctx.Attr>("scale"); + if (scale_tensor != nullptr) { + auto scale_data = get_new_data_from_tensor_xpu(scale_tensor); + if (scale_data.size() > 1) { + scale_h = scale_data[0]; + scale_w = scale_data[1]; + } else { + scale_h = scale_data[0]; + scale_w = scale_data[0]; + } + PADDLE_ENFORCE_EQ( + scale_w > 0 && scale_h > 0, true, + platform::errors::InvalidArgument("scale of Op(interpolate) " + "should be greater than 0.")); + } else { + if (scale.size() > 1) { + scale_h = scale[0]; + scale_w = scale[1]; + + PADDLE_ENFORCE_EQ( + scale_w > 0 && scale_h > 0, true, + platform::errors::InvalidArgument("scale of Op(interpolate) " + "should be greater than 0.")); + } + } + if (scale_h > 0. && scale_w > 0.) { + out_h = static_cast(in_h * scale_h); + out_w = static_cast(in_w * scale_w); + } + auto out_size = ctx.Input("OutSize"); + if (out_size != nullptr) { + auto out_size_data = get_new_data_from_tensor(out_size); + out_h = out_size_data[0]; + out_w = out_size_data[1]; + } + } + + framework::DDim dim_grad; + if (data_layout == DataLayout::kNCHW) { + dim_grad = {n, c, in_h, in_w}; + } else { + dim_grad = {n, in_h, in_w, c}; + } + input_grad->mutable_data(dim_grad, ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + int r = XPU_SUCCESS; + r = xpu::constant(dev_ctx.x_context(), input_grad->data(), + input_grad->numel(), static_cast(0.0)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU constant in interpolate2d_grad kernel return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + + if (in_h == out_h && in_w == out_w) { + framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad); + return; + } + + bool nearest = "nearest" == interp_method; + int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); + + if (nearest) { + trans_mode = (align_corners) ? (0) : (2); + } + + r = xpu::interpolate2d_grad(dev_ctx.x_context(), output_grad->data(), + input_grad->data(), n, c, in_h, in_w, + out_h, out_w, nearest, trans_mode, + (data_layout == DataLayout::kNCHW)); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU interpolate2d_grad kernel return " + "wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL(bilinear_interp_v2, ops::InterpolateV2XPUKernel); +REGISTER_OP_XPU_KERNEL(nearest_interp_v2, ops::InterpolateV2XPUKernel); + +REGISTER_OP_XPU_KERNEL(bilinear_interp_v2_grad, + ops::InterpolateV2GradXPUKernel); +REGISTER_OP_XPU_KERNEL(nearest_interp_v2_grad, + ops::InterpolateV2GradXPUKernel); +#endif diff --git a/python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_v2_op_xpu.py new file mode 100644 index 00000000000..8de8125166f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_v2_op_xpu.py @@ -0,0 +1,415 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append("..") +from op_test_xpu import XPUOpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +def nearest_neighbor_interp_np(X, + out_h, + out_w, + scale_h=0, + scale_w=0, + out_size=None, + actual_shape=None, + align_corners=True, + data_layout='NCHW'): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + X = np.transpose(X, (0, 3, 1, 2)) # NHWC => NCHW + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + n, c, in_h, in_w = X.shape + + ratio_h = ratio_w = 0.0 + if (out_h > 1): + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + if scale_h > 0: + ratio_h = 1.0 / scale_h + else: + ratio_h = 1.0 * in_h / out_h + if (out_w > 1): + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + if scale_w > 0: + ratio_w = 1.0 / scale_w + else: + ratio_w = 1.0 * in_w / out_w + out = np.zeros((n, c, out_h, out_w)) + + if align_corners: + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, i, j] = X[:, :, in_i, in_j] + else: + for i in range(out_h): + in_i = int(ratio_h * i) + for j in range(out_w): + in_j = int(ratio_w * j) + out[:, :, i, j] = X[:, :, in_i, in_j] + + if data_layout == "NHWC": + out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC + + return out.astype(X.dtype) + + +class TestNearestInterpOp(XPUOpTest): + def setUp(self): + self.use_xpu = True + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp_v2" + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + self.attrs = { + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + } + + input_np = np.random.random(self.input_shape).astype("float32") + self.inputs = {'X': input_np} + + if self.scale_by_1Dtensor: + self.inputs['Scale'] = np.array([self.scale]).astype("float32") + elif self.scale: + if isinstance(self.scale, float) or isinstance(self.scale, int): + if self.scale > 0: + scale_h = scale_w = float(self.scale) + if isinstance(self.scale, list) and len(self.scale) == 1: + scale_w = scale_h = self.scale[0] + elif isinstance(self.scale, list) and len(self.scale) > 1: + scale_w = self.scale[1] + scale_h = self.scale[0] + out_h = int(self.input_shape[2] * scale_h) + out_w = int(self.input_shape[3] * scale_w) + else: + out_h = self.out_h + out_w = self.out_w + + if self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.out_size is not None: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs['out_h'] = self.out_h + self.attrs['out_w'] = self.out_w + if self.scale: + if isinstance(self.scale, float) or isinstance(self.scale, int): + if self.scale > 0: + self.scale = [self.scale] + if isinstance(self.scale, list) and len(self.scale) == 1: + self.scale = [self.scale[0], self.scale[0]] + self.attrs['scale'] = self.scale + output_np = nearest_neighbor_interp_np(input_np, out_h, out_w, 0, 0, + self.out_size, self.actual_shape, + self.align_corners) + self.outputs = {'Out': output_np} + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 5, 4, 4] + self.out_h = 3 + self.out_w = 3 + self.scale = 0. + self.out_size = [3, 3] + self.align_corners = True + + +class TestNearestNeighborInterpCase1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = True + + +class TestNearestNeighborInterpCase2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.align_corners = True + + +class TestNearestNeighborInterpCase3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.align_corners = True + + +class TestNearestNeighborInterpCase4(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = True + + +class TestNearestNeighborInterpCase5(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = True + + +class TestNearestNeighborInterpCase6(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([65, 129]).astype("int32") + self.align_corners = True + + +class TestNearestNeighborInterpSame(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 32, 64] + self.out_h = 32 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + + +class TestNearestNeighborInterpActualShape(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +class TestNearestNeighborInterpDataLayout(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 4, 4, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 8]).astype("int32") + self.align_corners = True + self.data_layout = "NHWC" + + +class TestNearestInterpWithoutCorners(TestNearestInterpOp): + def set_align_corners(self): + self.align_corners = False + + +class TestNearestNeighborInterpScale1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 7, 5] + self.out_h = 64 + self.out_w = 32 + self.scale = 2. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +class TestNearestNeighborInterpScale2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 5, 7] + self.out_h = 64 + self.out_w = 32 + self.scale = 1.5 + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +class TestNearestNeighborInterpScale3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 7, 5] + self.out_h = 64 + self.out_w = 32 + self.scale = [2.0, 3.0] + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +class TestNearestInterpOp_attr_tensor(XPUOpTest): + def setUp(self): + self.use_xpu = True + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp_v2" + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + self.attrs = { + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + } + + input_np = np.random.random(self.input_shape).astype("float32") + self.inputs = {'X': input_np} + + if self.scale_by_1Dtensor: + self.inputs['Scale'] = np.array([self.scale]).astype("float32") + elif self.scale: + if isinstance(self.scale, float) or isinstance(self.scale, int): + if self.scale > 0: + scale_h = scale_w = float(self.scale) + if isinstance(self.scale, list) and len(self.scale) == 1: + scale_w = scale_h = self.scale[0] + elif isinstance(self.scale, list) and len(self.scale) > 1: + scale_w = self.scale[1] + scale_h = self.scale[0] + out_h = int(self.input_shape[2] * scale_h) + out_w = int(self.input_shape[3] * scale_w) + else: + out_h = self.out_h + out_w = self.out_w + + if self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.out_size is not None: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs['out_h'] = self.out_h + self.attrs['out_w'] = self.out_w + if self.scale: + if isinstance(self.scale, float) or isinstance(self.scale, int): + if self.scale > 0: + self.scale = [self.scale] + if isinstance(self.scale, list) and len(self.scale) == 1: + self.scale = [self.scale[0], self.scale[0]] + self.attrs['scale'] = self.scale + output_np = nearest_neighbor_interp_np(input_np, out_h, out_w, 0, 0, + self.out_size, self.actual_shape, + self.align_corners) + self.outputs = {'Out': output_np} + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 5, 4, 4] + self.out_h = 3 + self.out_w = 3 + self.scale = 0. + self.out_size = [3, 3] + self.align_corners = True + + +# out_size is a tensor list +class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = [8, 12] + self.align_corners = True + + +# out_size is a 1-D tensor +class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + self.shape_by_1Dtensor = True + + +# scale is a 1-D tensor +class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 2.0 + self.out_size = None + self.align_corners = True + self.scale_by_1Dtensor = True + + +if __name__ == "__main__": + unittest.main() -- GitLab