提交 fe54adf7 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-ctr-reader

......@@ -76,11 +76,12 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("out_h", "output height of interpolate op.");
AddAttr<int>("out_w", "output width of interpolate op.");
AddAttr<std::string>(
"interp_method",
"(string), interpolation method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest "
"neighbor interpolation.");
AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation "
"method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest "
"neighbor interpolation.")
.SetDefault("bilinear");
AddComment(R"DOC(
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
......@@ -132,11 +133,19 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker,
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad);
REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel<float>,
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
REGISTER_OP_CPU_KERNEL(nearest_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel<float>,
REGISTER_OP_CPU_KERNEL(nearest_interp_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
......@@ -284,9 +284,15 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel<float>,
REGISTER_OP_CUDA_KERNEL(bilinear_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(interpolate_grad,
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(nearest_interp, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(nearest_interp_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
......@@ -5870,9 +5870,10 @@ def image_resize(input,
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently."
)
resample_type = resample_methods[resample]
if out_shape is None and scale is None:
raise ValueError("One of out_shape and scale must not be None.")
helper = LayerHelper('interpolate', **locals())
helper = LayerHelper('{}_interp'.format(resample_type), **locals())
dtype = helper.input_dtype()
def _is_list_or_turple_(data):
......@@ -5906,18 +5907,16 @@ def image_resize(input,
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='interpolate',
type='{}_interp'.format(resample_type),
inputs=inputs,
outputs={"Out": out},
attrs={
"out_h": out_h,
"out_w": out_w,
"interp_method": resample_methods[resample]
})
attrs={"out_h": out_h,
"out_w": out_w,
"interp_method": resample_type})
return out
@templatedoc(op_type="interpolate")
@templatedoc(op_type="bilinear_interp")
def resize_bilinear(input,
out_shape=None,
scale=None,
......@@ -5973,7 +5972,7 @@ def resize_bilinear(input,
return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape)
@templatedoc(op_type="interpolate")
@templatedoc(op_type="nearest_interp")
def resize_nearest(input,
out_shape=None,
scale=None,
......
......@@ -81,12 +81,14 @@ list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
list(REMOVE_ITEM TEST_OPS test_dist_transformer)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_transformer)
list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
list(REMOVE_ITEM TEST_OPS test_interpolate_op)
list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR} SERIAL)
py_test_modules(test_interpolate_op MODULES test_interpolate_op SERIAL)
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
if(WITH_DISTRIBUTE)
py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
......
......@@ -20,36 +20,6 @@ from op_test import OpTest
import paddle.fluid.core as core
def nearest_neighbor_interp_np(X,
out_h,
out_w,
out_size=None,
actual_shape=None):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0)
if out_w > 1:
ratio_w = (in_w - 1.0) / (out_w - 1.0)
out = np.zeros((n, c, out_h, out_w))
for i in range(out_h):
in_i = int(ratio_h * i + 0.5)
for j in range(out_w):
in_j = int(ratio_w * j + 0.5)
out[:, :, i, j] = X[:, :, in_i, in_j]
return out.astype(X.dtype)
def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
......@@ -87,22 +57,16 @@ def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None):
return out.astype(input.dtype)
INTERPOLATE_FUNCS = {
'bilinear': bilinear_interp_np,
'nearest': nearest_neighbor_interp_np,
}
class TestInterpolateOp(OpTest):
class TestBilinearInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "interpolate"
self.op_type = "bilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32")
output_np = INTERPOLATE_FUNCS[self.interp_method](
input_np, self.out_h, self.out_w, self.out_size, self.actual_shape)
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size, self.actual_shape)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
......@@ -129,7 +93,7 @@ class TestInterpolateOp(OpTest):
self.out_size = np.array([3, 3]).astype("int32")
class TestBilinearInterpCase1(TestInterpolateOp):
class TestBilinearInterpCase1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
......@@ -137,7 +101,7 @@ class TestBilinearInterpCase1(TestInterpolateOp):
self.out_w = 1
class TestBilinearInterpCase2(TestInterpolateOp):
class TestBilinearInterpCase2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
......@@ -145,7 +109,7 @@ class TestBilinearInterpCase2(TestInterpolateOp):
self.out_w = 12
class TestBilinearInterpCase3(TestInterpolateOp):
class TestBilinearInterpCase3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 128, 64]
......@@ -153,7 +117,7 @@ class TestBilinearInterpCase3(TestInterpolateOp):
self.out_w = 128
class TestBilinearInterpCase4(TestInterpolateOp):
class TestBilinearInterpCase4(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
......@@ -162,7 +126,7 @@ class TestBilinearInterpCase4(TestInterpolateOp):
self.out_size = np.array([2, 2]).astype("int32")
class TestBilinearInterpCase5(TestInterpolateOp):
class TestBilinearInterpCase5(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
......@@ -171,7 +135,7 @@ class TestBilinearInterpCase5(TestInterpolateOp):
self.out_size = np.array([11, 11]).astype("int32")
class TestBilinearInterpCase6(TestInterpolateOp):
class TestBilinearInterpCase6(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 128, 64]
......@@ -180,7 +144,7 @@ class TestBilinearInterpCase6(TestInterpolateOp):
self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearInterpActualShape(TestInterpolateOp):
class TestBilinearInterpActualShape(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
......@@ -189,25 +153,16 @@ class TestBilinearInterpActualShape(TestInterpolateOp):
self.out_size = np.array([66, 40]).astype("int32")
class TestBilinearInterpBigScale(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 4, 64, 32]
self.out_h = 100
self.out_w = 50
self.out_size = np.array([101, 51]).astype('int32')
class TestInterpolateOpUint8(OpTest):
class TestBilinearInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "interpolate"
self.op_type = "bilinear_interp"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
output_np = INTERPOLATE_FUNCS[self.interp_method](
input_np, self.out_h, self.out_w, self.out_size, self.actual_shape)
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size, self.actual_shape)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
......@@ -228,7 +183,7 @@ class TestInterpolateOpUint8(OpTest):
self.out_w = 9
class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8):
class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 128, 64]
......@@ -236,7 +191,7 @@ class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8):
self.out_w = 50
class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8):
class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
......@@ -245,91 +200,5 @@ class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8):
self.out_size = np.array([6, 15]).astype("int32")
class TestNearestNeighborInterpCase1(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestNearestNeighborInterpCase2(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestNearestNeighborInterpCase3(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestNearestNeighborInterpCase4(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestNearestNeighborInterpCase5(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestNearestNeighborInterpCase6(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestNearestNeighborInterpActualShape(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.out_size = np.array([66, 40]).astype("int32")
class TestNearestNeighborInterpBigScale(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 4, 64, 32]
self.out_h = 100
self.out_w = 50
self.out_size = np.array([101, 51]).astype('int32')
class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.out_size = np.array([6, 15]).astype("int32")
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def nearest_neighbor_interp_np(X,
out_h,
out_w,
out_size=None,
actual_shape=None):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0)
if out_w > 1:
ratio_w = (in_w - 1.0) / (out_w - 1.0)
out = np.zeros((n, c, out_h, out_w))
for i in range(out_h):
in_i = int(ratio_h * i + 0.5)
for j in range(out_w):
in_j = int(ratio_w * j + 0.5)
out[:, :, i, j] = X[:, :, in_i, in_j]
return out.astype(X.dtype)
class TestNearestInterpOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "nearest_interp"
input_np = np.random.random(self.input_shape).astype("float32")
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
self.out_size, self.actual_shape)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 4, 4]
self.out_h = 2
self.out_w = 2
self.out_size = np.array([3, 3]).astype("int32")
class TestNearestNeighborInterpCase1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestNearestNeighborInterpCase2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestNearestNeighborInterpCase3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestNearestNeighborInterpCase4(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestNearestNeighborInterpCase5(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestNearestNeighborInterpCase6(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.out_size = np.array([66, 40]).astype("int32")
class TestNearestInterpOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "nearest_interp"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
self.out_size, self.actual_shape)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.out_size = np.array([6, 15]).astype("int32")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册