From 78f563917ca95d718ce3fe70fe147d67cf8d4f20 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 26 Nov 2018 16:28:21 +0800 Subject: [PATCH] revert interpolate_op to bilinear_interp_op & nearest_interp_op. test=develop --- paddle/fluid/operators/interpolate_op.cc | 16 +- paddle/fluid/operators/interpolate_op.cu | 10 +- python/paddle/fluid/layers/nn.py | 17 +- ...olate_op.py => test_bilinear_interp_op.py} | 165 ++------------- .../tests/unittests/test_nearest_interp_op.py | 197 ++++++++++++++++++ 5 files changed, 242 insertions(+), 163 deletions(-) rename python/paddle/fluid/tests/unittests/{test_interpolate_op.py => test_bilinear_interp_op.py} (53%) create mode 100644 python/paddle/fluid/tests/unittests/test_nearest_interp_op.py diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 8f979e05d31..203291ed61f 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -132,11 +132,19 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker, +REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad); -REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel, +REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad); +REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad); +REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel, + ops::InterpolateKernel, + ops::InterpolateKernel); +REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel, + ops::InterpolateGradKernel); +REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel, ops::InterpolateKernel, ops::InterpolateKernel); -REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel, +REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel, ops::InterpolateGradKernel); diff --git a/paddle/fluid/operators/interpolate_op.cu b/paddle/fluid/operators/interpolate_op.cu index 190afbdac43..99ac725f73b 100644 --- a/paddle/fluid/operators/interpolate_op.cu +++ b/paddle/fluid/operators/interpolate_op.cu @@ -284,9 +284,15 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel, +REGISTER_OP_CUDA_KERNEL(bilinear_interp, ops::InterpolateOpCUDAKernel, ops::InterpolateOpCUDAKernel, ops::InterpolateOpCUDAKernel); -REGISTER_OP_CUDA_KERNEL(interpolate_grad, +REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad, + ops::InterpolateGradOpCUDAKernel, + ops::InterpolateGradOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel, + ops::InterpolateOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(nearest_interp_grad, ops::InterpolateGradOpCUDAKernel, ops::InterpolateGradOpCUDAKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 7af1f380e70..f0c5e67ccad 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5870,9 +5870,10 @@ def image_resize(input, raise ValueError( "The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently." ) + resample_type = resample_methods[resample] if out_shape is None and scale is None: raise ValueError("One of out_shape and scale must not be None.") - helper = LayerHelper('interpolate', **locals()) + helper = LayerHelper('{}_interp'.format(resample_type), **locals()) dtype = helper.input_dtype() def _is_list_or_turple_(data): @@ -5906,18 +5907,16 @@ def image_resize(input, out = helper.create_variable_for_type_inference(dtype) helper.append_op( - type='interpolate', + type='{}_interp'.format(resample_type), inputs=inputs, outputs={"Out": out}, - attrs={ - "out_h": out_h, - "out_w": out_w, - "interp_method": resample_methods[resample] - }) + attrs={"out_h": out_h, + "out_w": out_w, + "interp_method": resample_type}) return out -@templatedoc(op_type="interpolate") +@templatedoc(op_type="bilinear_interp") def resize_bilinear(input, out_shape=None, scale=None, @@ -5973,7 +5972,7 @@ def resize_bilinear(input, return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape) -@templatedoc(op_type="interpolate") +@templatedoc(op_type="nearest_interp") def resize_nearest(input, out_shape=None, scale=None, diff --git a/python/paddle/fluid/tests/unittests/test_interpolate_op.py b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py similarity index 53% rename from python/paddle/fluid/tests/unittests/test_interpolate_op.py rename to python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py index 9748d094cda..c8a7063dc1c 100644 --- a/python/paddle/fluid/tests/unittests/test_interpolate_op.py +++ b/python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py @@ -20,36 +20,6 @@ from op_test import OpTest import paddle.fluid.core as core -def nearest_neighbor_interp_np(X, - out_h, - out_w, - out_size=None, - actual_shape=None): - """nearest neighbor interpolation implement in shape [N, C, H, W]""" - if out_size is not None: - out_h = out_size[0] - out_w = out_size[1] - if actual_shape is not None: - out_h = actual_shape[0] - out_w = actual_shape[1] - n, c, in_h, in_w = X.shape - - ratio_h = ratio_w = 0.0 - if out_h > 1: - ratio_h = (in_h - 1.0) / (out_h - 1.0) - if out_w > 1: - ratio_w = (in_w - 1.0) / (out_w - 1.0) - - out = np.zeros((n, c, out_h, out_w)) - for i in range(out_h): - in_i = int(ratio_h * i + 0.5) - for j in range(out_w): - in_j = int(ratio_w * j + 0.5) - out[:, :, i, j] = X[:, :, in_i, in_j] - - return out.astype(X.dtype) - - def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None): """bilinear interpolation implement in shape [N, C, H, W]""" if out_size is not None: @@ -87,22 +57,16 @@ def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None): return out.astype(input.dtype) -INTERPOLATE_FUNCS = { - 'bilinear': bilinear_interp_np, - 'nearest': nearest_neighbor_interp_np, -} - - -class TestInterpolateOp(OpTest): +class TestBilinearInterpOp(OpTest): def setUp(self): self.out_size = None self.actual_shape = None self.init_test_case() - self.op_type = "interpolate" + self.op_type = "bilinear_interp" input_np = np.random.random(self.input_shape).astype("float32") - output_np = INTERPOLATE_FUNCS[self.interp_method]( - input_np, self.out_h, self.out_w, self.out_size, self.actual_shape) + output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, + self.out_size, self.actual_shape) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -129,7 +93,7 @@ class TestInterpolateOp(OpTest): self.out_size = np.array([3, 3]).astype("int32") -class TestBilinearInterpCase1(TestInterpolateOp): +class TestBilinearInterpCase1(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [4, 1, 7, 8] @@ -137,7 +101,7 @@ class TestBilinearInterpCase1(TestInterpolateOp): self.out_w = 1 -class TestBilinearInterpCase2(TestInterpolateOp): +class TestBilinearInterpCase2(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [3, 3, 9, 6] @@ -145,7 +109,7 @@ class TestBilinearInterpCase2(TestInterpolateOp): self.out_w = 12 -class TestBilinearInterpCase3(TestInterpolateOp): +class TestBilinearInterpCase3(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [1, 1, 128, 64] @@ -153,7 +117,7 @@ class TestBilinearInterpCase3(TestInterpolateOp): self.out_w = 128 -class TestBilinearInterpCase4(TestInterpolateOp): +class TestBilinearInterpCase4(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [4, 1, 7, 8] @@ -162,7 +126,7 @@ class TestBilinearInterpCase4(TestInterpolateOp): self.out_size = np.array([2, 2]).astype("int32") -class TestBilinearInterpCase5(TestInterpolateOp): +class TestBilinearInterpCase5(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [3, 3, 9, 6] @@ -171,7 +135,7 @@ class TestBilinearInterpCase5(TestInterpolateOp): self.out_size = np.array([11, 11]).astype("int32") -class TestBilinearInterpCase6(TestInterpolateOp): +class TestBilinearInterpCase6(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [1, 1, 128, 64] @@ -180,7 +144,7 @@ class TestBilinearInterpCase6(TestInterpolateOp): self.out_size = np.array([65, 129]).astype("int32") -class TestBilinearInterpActualShape(TestInterpolateOp): +class TestBilinearInterpActualShape(TestBilinearInterpOp): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [3, 2, 32, 16] @@ -189,25 +153,16 @@ class TestBilinearInterpActualShape(TestInterpolateOp): self.out_size = np.array([66, 40]).astype("int32") -class TestBilinearInterpBigScale(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'bilinear' - self.input_shape = [4, 4, 64, 32] - self.out_h = 100 - self.out_w = 50 - self.out_size = np.array([101, 51]).astype('int32') - - -class TestInterpolateOpUint8(OpTest): +class TestBilinearInterpOpUint8(OpTest): def setUp(self): self.out_size = None self.actual_shape = None self.init_test_case() - self.op_type = "interpolate" + self.op_type = "bilinear_interp" input_np = np.random.randint( low=0, high=256, size=self.input_shape).astype("uint8") - output_np = INTERPOLATE_FUNCS[self.interp_method]( - input_np, self.out_h, self.out_w, self.out_size, self.actual_shape) + output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, + self.out_size, self.actual_shape) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -228,7 +183,7 @@ class TestInterpolateOpUint8(OpTest): self.out_w = 9 -class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8): +class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [2, 3, 128, 64] @@ -236,7 +191,7 @@ class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8): self.out_w = 50 -class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8): +class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8): def init_test_case(self): self.interp_method = 'bilinear' self.input_shape = [4, 1, 7, 8] @@ -245,91 +200,5 @@ class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8): self.out_size = np.array([6, 15]).astype("int32") -class TestNearestNeighborInterpCase1(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - - -class TestNearestNeighborInterpCase2(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - - -class TestNearestNeighborInterpCase3(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - - -class TestNearestNeighborInterpCase4(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [4, 1, 7, 8] - self.out_h = 1 - self.out_w = 1 - self.out_size = np.array([2, 2]).astype("int32") - - -class TestNearestNeighborInterpCase5(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [3, 3, 9, 6] - self.out_h = 12 - self.out_w = 12 - self.out_size = np.array([11, 11]).astype("int32") - - -class TestNearestNeighborInterpCase6(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [1, 1, 128, 64] - self.out_h = 64 - self.out_w = 128 - self.out_size = np.array([65, 129]).astype("int32") - - -class TestNearestNeighborInterpActualShape(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [3, 2, 32, 16] - self.out_h = 64 - self.out_w = 32 - self.out_size = np.array([66, 40]).astype("int32") - - -class TestNearestNeighborInterpBigScale(TestInterpolateOp): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [4, 4, 64, 32] - self.out_h = 100 - self.out_w = 50 - self.out_size = np.array([101, 51]).astype('int32') - - -class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [2, 3, 128, 64] - self.out_h = 120 - self.out_w = 50 - - -class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8): - def init_test_case(self): - self.interp_method = 'nearest' - self.input_shape = [4, 1, 7, 8] - self.out_h = 5 - self.out_w = 13 - self.out_size = np.array([6, 15]).astype("int32") - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py new file mode 100644 index 00000000000..242709425f2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_op.py @@ -0,0 +1,197 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +def nearest_neighbor_interp_np(X, + out_h, + out_w, + out_size=None, + actual_shape=None): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + n, c, in_h, in_w = X.shape + + ratio_h = ratio_w = 0.0 + if out_h > 1: + ratio_h = (in_h - 1.0) / (out_h - 1.0) + if out_w > 1: + ratio_w = (in_w - 1.0) / (out_w - 1.0) + + out = np.zeros((n, c, out_h, out_w)) + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, i, j] = X[:, :, in_i, in_j] + + return out.astype(X.dtype) + + +class TestNearestInterpOp(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp" + input_np = np.random.random(self.input_shape).astype("float32") + + output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, + self.out_size, self.actual_shape) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 4, 4] + self.out_h = 2 + self.out_w = 2 + self.out_size = np.array([3, 3]).astype("int32") + + +class TestNearestNeighborInterpCase1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + + +class TestNearestNeighborInterpCase2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + + +class TestNearestNeighborInterpCase3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + + +class TestNearestNeighborInterpCase4(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.out_size = np.array([2, 2]).astype("int32") + + +class TestNearestNeighborInterpCase5(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.out_size = np.array([11, 11]).astype("int32") + + +class TestNearestNeighborInterpCase6(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 128, 64] + self.out_h = 64 + self.out_w = 128 + self.out_size = np.array([65, 129]).astype("int32") + + +class TestNearestNeighborInterpActualShape(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.out_size = np.array([66, 40]).astype("int32") + + +class TestNearestInterpOpUint8(OpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp" + input_np = np.random.randint( + low=0, high=256, size=self.input_shape).astype("uint8") + output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, + self.out_size, self.actual_shape) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + self.check_output_with_place(place=core.CPUPlace(), atol=1) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 3, 9, 6] + self.out_h = 10 + self.out_w = 9 + + +class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 128, 64] + self.out_h = 120 + self.out_w = 50 + + +class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 5 + self.out_w = 13 + self.out_size = np.array([6, 15]).astype("int32") + + +if __name__ == "__main__": + unittest.main() -- GitLab