未验证 提交 6452ab3b 编写于 作者: H HongyuJia 提交者: GitHub

transfer nearest_interp op to phi, change name from nearest_interp_v2 to nearest_interp (#45148)

上级 2fb65e44
......@@ -464,9 +464,9 @@
args : (Tensor label, int num_classes, int num_samples, int ring_id, int rank, int nranks, bool fix_seed, int seed)
output : Tensor(remapped_label), Tensor(sampled_local_class_center)
infer_meta :
func : ClassCenterSampleInferMeta
func : ClassCenterSampleInferMeta
kernel :
func : class_center_sample
func : class_center_sample
- api : clip
args : (Tensor x, Scalar(float) min, Scalar(float) max)
......@@ -1852,6 +1852,17 @@
func : multiply
backward : multiply_grad
- api : nearest_interp
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode)
output : Tensor(output)
infer_meta :
func : InterpolateInferMeta
optional: out_size, size_tensor, scale_tensor
kernel :
func : nearest_interp
data_type : x
backward : nearest_interp_grad
- api : nll_loss
args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction)
output : Tensor(out), Tensor(total_weight)
......
......@@ -1619,6 +1619,18 @@
func : multiply_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_grad_out_grad
- backward_api : nearest_interp_grad
forward : nearest_interp (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode) -> Tensor(output)
args : (Tensor x, Tensor out_size, Tensor[] size_tensor, Tensor scale_tensor, Tensor output_grad, str data_layout, int out_d, int out_h, int out_w, float[] scale, str interp_method, bool align_corners, int align_mode)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
optional: out_size, size_tensor, scale_tensor
kernel :
func : nearest_interp_grad
data_type : output_grad
- backward_api : nll_loss_grad
forward : nll_loss (Tensor input, Tensor label, Tensor weight, int64_t ignore_index, str reduction) -> Tensor(out), Tensor(total_weight)
args : (Tensor input, Tensor label, Tensor weight, Tensor total_weight, Tensor out_grad, int64_t ignore_index, str reduction)
......
......@@ -73,7 +73,9 @@ const std::unordered_set<std::string> deprecated_op_names(
"top_k",
"top_k_grad",
"linear_interp",
"linear_interp_grad"});
"linear_interp_grad",
"nearest_interp",
"nearest_interp_grad"});
class DefaultKernelSignatureMap {
public:
......
......@@ -1045,7 +1045,7 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
PD_REGISTER_KERNEL(nearest_interp_grad,
CPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
......
......@@ -1197,7 +1197,7 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2,
PD_REGISTER_KERNEL(nearest_interp,
CPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
......
......@@ -1578,7 +1578,7 @@ PD_REGISTER_KERNEL(bilinear_interp_v2_grad,
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2_grad,
PD_REGISTER_KERNEL(nearest_interp_grad,
GPU,
ALL_LAYOUT,
phi::NearestInterpGradKernel,
......
......@@ -1450,7 +1450,7 @@ PD_REGISTER_KERNEL(bilinear_interp_v2,
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(nearest_interp_v2,
PD_REGISTER_KERNEL(nearest_interp,
GPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
......
......@@ -33,7 +33,7 @@ KernelSignature BilinearInterpOpArgumentMapping(
KernelSignature NearestInterpOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("nearest_interp_v2",
return KernelSignature("nearest_interp",
{"X", "OutSize", "SizeTensor", "Scale"},
{"data_layout",
"out_d",
......@@ -107,7 +107,7 @@ KernelSignature BilinearInterpGradOpArgumentMapping(
KernelSignature NearestInterpGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("nearest_interp_v2_grad",
return KernelSignature("nearest_interp_grad",
{"X", "OutSize", "SizeTensor", "Scale", "Out@GRAD"},
{"data_layout",
"out_d",
......@@ -167,7 +167,10 @@ KernelSignature BicubicInterpGradOpArgumentMapping(
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(linear_interp_v2, linear_interp);
PD_REGISTER_BASE_KERNEL_NAME(nearest_interp_v2, nearest_interp);
PD_REGISTER_BASE_KERNEL_NAME(linear_interp_v2_grad, linear_interp_grad);
PD_REGISTER_BASE_KERNEL_NAME(nearest_interp_v2_grad, nearest_interp_grad);
PD_REGISTER_ARG_MAPPING_FN(bilinear_interp_v2,
phi::BilinearInterpOpArgumentMapping);
......
......@@ -22,10 +22,39 @@ import paddle.fluid as fluid
import paddle.nn as nn
import paddle
from paddle.nn.functional import interpolate
from paddle._C_ops import final_state_nearest_interp
paddle.enable_static()
def nearest_interp_test(x,
OutSize=None,
SizeTensor=None,
Scale=None,
data_layout='NCHW',
out_d=-1,
out_h=-1,
out_w=-1,
scale=[],
interp_method='linear',
align_corners=False,
align_mode=1):
if isinstance(scale, float) or isinstance(scale, int):
scale_list = []
for _ in range(len(x.shape) - 2):
scale_list.append(scale)
scale = list(map(float, scale_list))
elif isinstance(scale, list) or isinstance(scale, tuple):
scale = list(map(float, scale))
if SizeTensor is not None:
if not isinstance(SizeTensor, list) and not isinstance(
SizeTensor, tuple):
SizeTensor = [SizeTensor]
return final_state_nearest_interp(x, OutSize, SizeTensor, Scale,
data_layout, out_d, out_h, out_w, scale,
interp_method, align_corners, align_mode)
def nearest_neighbor_interp_np(X,
out_h,
out_w,
......@@ -160,6 +189,7 @@ def nearest_neighbor_interp3d_np(X,
class TestNearestInterpOp(OpTest):
def setUp(self):
self.python_api = nearest_interp_test
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
......@@ -254,17 +284,17 @@ class TestNearestInterpOp(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
self.check_grad(['X'], 'Out', in_place=True, check_eager=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 4, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.scale = []
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
......@@ -277,7 +307,7 @@ class TestNearestNeighborInterpCase1(TestNearestInterpOp):
self.out_d = 1
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.scale = []
self.align_corners = True
......@@ -288,7 +318,7 @@ class TestNearestNeighborInterpCase2(TestNearestInterpOp):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.scale = []
self.align_corners = True
......@@ -299,7 +329,7 @@ class TestNearestNeighborInterpCase3(TestNearestInterpOp):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.scale = []
self.align_corners = True
......@@ -310,7 +340,7 @@ class TestNearestNeighborInterpCase4(TestNearestInterpOp):
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.scale = []
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True
......@@ -322,7 +352,7 @@ class TestNearestNeighborInterpCase5(TestNearestInterpOp):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.scale = []
self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = True
......@@ -334,7 +364,7 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.scale = []
self.out_size = np.array([65, 129]).astype("int32")
self.align_corners = True
......@@ -346,7 +376,7 @@ class TestNearestNeighborInterpSame(TestNearestInterpOp):
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.scale = []
self.align_corners = True
......@@ -357,7 +387,7 @@ class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.scale = []
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
......@@ -369,7 +399,7 @@ class TestNearestNeighborInterpDataLayout(TestNearestInterpOp):
self.input_shape = [2, 4, 4, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.scale = []
self.out_size = np.array([3, 8]).astype("int32")
self.align_corners = True
self.data_layout = "NHWC"
......@@ -378,6 +408,7 @@ class TestNearestNeighborInterpDataLayout(TestNearestInterpOp):
class TestNearestInterpOpUint8(OpTest):
def setUp(self):
self.python_api = nearest_interp_test
self.out_size = None
self.actual_shape = None
self.init_test_case()
......@@ -422,14 +453,16 @@ class TestNearestInterpOpUint8(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
self.check_output_with_place(place=core.CPUPlace(),
atol=1,
check_eager=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
self.scale = 0.
self.scale = []
self.align_corners = True
......@@ -440,7 +473,7 @@ class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8):
self.input_shape = [2, 3, 32, 64]
self.out_h = 80
self.out_w = 40
self.scale = 0.
self.scale = []
self.align_corners = True
......@@ -451,7 +484,7 @@ class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8):
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.scale = 0.
self.scale = []
self.out_size = np.array([6, 15]).astype("int32")
self.align_corners = True
......@@ -514,6 +547,7 @@ class TestNearestNeighbor3DInterp(TestNearestInterpOp):
class TestNearestInterpOp_attr_tensor(OpTest):
def setUp(self):
self.python_api = nearest_interp_test
self.out_size = None
self.actual_shape = None
self.init_test_case()
......@@ -569,17 +603,17 @@ class TestNearestInterpOp_attr_tensor(OpTest):
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', in_place=True)
self.check_grad(['X'], 'Out', in_place=True, check_eager=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 5, 4, 4]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.scale = []
self.out_size = [3, 3]
self.align_corners = True
......@@ -592,7 +626,7 @@ class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.scale = []
self.out_size = [8, 12]
self.align_corners = True
......@@ -605,7 +639,7 @@ class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor):
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.scale = []
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.shape_by_1Dtensor = True
......
......@@ -607,7 +607,17 @@ def interpolate(x,
elif resample_type == "trilinear":
out = _C_ops.trilinear_interp_v2(x, *dy_attr)
elif resample_type == "nearest":
out = _C_ops.nearest_interp_v2(x, *dy_attr)
if in_dygraph_mode():
out = _C_ops.final_state_nearest_interp(
x, inputs['OutSize'] if 'OutSize' in inputs else None,
inputs['SizeTensor'] if 'SizeTensor' in inputs else None,
inputs['Scale'] if 'Scale' in inputs else None,
attrs['data_layout'], attrs['out_d'], attrs['out_h'],
attrs['out_w'], attrs['scale'] if 'scale' in attrs else [],
attrs['interp_method'], attrs['align_corners'],
attrs['align_mode'])
else:
out = _C_ops.nearest_interp_v2(x, *dy_attr)
elif resample_type == "bicubic":
out = _C_ops.bicubic_interp_v2(x, *dy_attr)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册