提交 cfda1fde 编写于 作者: D dengkaipeng

add attr scale. test=develop

上级 f9061796
...@@ -36,10 +36,19 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -36,10 +36,19 @@ class InterpolateOp : public framework::OperatorWithKernel {
"Interpolation method can only be \"bilinear\" or \"nearest\"."); "Interpolation method can only be \"bilinear\" or \"nearest\".");
auto dim_x = ctx->GetInputDim("X"); // NCHW format auto dim_x = ctx->GetInputDim("X"); // NCHW format
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
// priority: OutSize > scale > out_h/out_w
int out_h, out_w;
float scale = ctx->Attrs().Get<float>("scale");
if (scale > 0) {
out_h = dim_x[2] * scale;
out_w = dim_x[3] * scale;
} else {
out_h = ctx->Attrs().Get<int>("out_h");
out_w = ctx->Attrs().Get<int>("out_w");
}
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) { if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize"); auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1, PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
...@@ -76,6 +85,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -76,6 +85,8 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("out_h", "output height of interpolate op."); AddAttr<int>("out_h", "output height of interpolate op.");
AddAttr<int>("out_w", "output width of interpolate op."); AddAttr<int>("out_w", "output width of interpolate op.");
AddAttr<float>("scale", "scale factor of interpolate op.")
.SetDefault(0.);
AddAttr<std::string>("interp_method", AddAttr<std::string>("interp_method",
"(string, default \"bilinear\"), interpolation " "(string, default \"bilinear\"), interpolation "
"method, can be \"bilinear\" for " "method, can be \"bilinear\" for "
......
...@@ -192,9 +192,21 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> { ...@@ -192,9 +192,21 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* input_data = input->data<T>(); auto* input_data = input->data<T>();
int n = input->dims()[0];
int c = input->dims()[1];
int in_h = input->dims()[2];
int in_w = input->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {
Tensor sizes; Tensor sizes;
...@@ -207,11 +219,6 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> { ...@@ -207,11 +219,6 @@ class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode"); int align_mode = ctx.Attr<int>("align_mode");
int n = input->dims()[0];
int c = input->dims()[1];
int in_h = input->dims()[2];
int in_w = input->dims()[3];
auto* output_data = auto* output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace()); output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
...@@ -268,14 +275,20 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -268,14 +275,20 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<platform::CUDADeviceContext, T> zero; math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0)); zero(device_ctx, input_grad, static_cast<T>(0.0));
int n = input_grad->dims()[0];
int c = input_grad->dims()[1];
int in_h = input_grad->dims()[2];
int in_w = input_grad->dims()[3];
auto interp_method = ctx.Attr<std::string>("interp_method"); auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w - in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
if (out_size != nullptr) { if (out_size != nullptr) {
Tensor sizes; Tensor sizes;
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes); framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
...@@ -284,10 +297,9 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -284,10 +297,9 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
out_w = size_data[1]; out_w = size_data[1];
} }
int n = input_grad->dims()[0]; bool align_corners = ctx.Attr<bool>("align_corners");
int c = input_grad->dims()[1]; int align_mode = ctx.Attr<int>("align_mode");
int in_h = input_grad->dims()[2];
int in_w = input_grad->dims()[3];
int in_hw = in_h * in_w; int in_hw = in_h * in_w;
int out_hw = out_h * out_w; int out_hw = out_h * out_w;
......
...@@ -163,9 +163,21 @@ class InterpolateKernel : public framework::OpKernel<T> { ...@@ -163,9 +163,21 @@ class InterpolateKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
std::string interp_method = ctx.Attr<std::string>("interp_method"); std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {
auto out_size_data = out_size->data<int>(); auto out_size_data = out_size->data<int>();
...@@ -175,11 +187,6 @@ class InterpolateKernel : public framework::OpKernel<T> { ...@@ -175,11 +187,6 @@ class InterpolateKernel : public framework::OpKernel<T> {
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode"); int align_mode = ctx.Attr<int>("align_mode");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace()); output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
auto& device_ctx = auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>(); ctx.template device_context<platform::CPUDeviceContext>();
...@@ -221,23 +228,31 @@ class InterpolateGradKernel : public framework::OpKernel<T> { ...@@ -221,23 +228,31 @@ class InterpolateGradKernel : public framework::OpKernel<T> {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
std::string interp_method = ctx.Attr<std::string>("interp_method"); std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h"); int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w"); int out_w = ctx.Attr<int>("out_w");
float scale = ctx.Attr<float>("scale");
if (scale > 0) {
out_h = in_h * scale;
out_w = in_w * scale;
}
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {
auto out_size_data = out_size->data<int>(); auto out_size_data = out_size->data<int>();
out_h = out_size_data[0]; out_h = out_size_data[0];
out_w = out_size_data[1]; out_w = out_size_data[1];
} }
bool align_corners = ctx.Attr<bool>("align_corners"); bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode"); int align_mode = ctx.Attr<int>("align_mode");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace()); input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
auto& device_ctx = auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>(); ctx.template device_context<platform::CPUDeviceContext>();
......
...@@ -7056,10 +7056,10 @@ def image_resize(input, ...@@ -7056,10 +7056,10 @@ def image_resize(input,
out_shape(list|tuple|Variable|None): Output shape of image resize out_shape(list|tuple|Variable|None): Output shape of image resize
layer, the shape is (out_h, out_w). layer, the shape is (out_h, out_w).
Default: None Default: None
scale(float|None): The multiplier for the input height or width. scale(float|None): The multiplier for the input height or width. At
At least one of out_shape or scale must be set. least one of :attr:`out_shape` or :attr:`scale`` must be set.
And out_shape has a higher priority than scale. And :attr:`scale` has a higher priority than :attr:`out_shape`.
Default: None Default: None.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST' resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST'
...@@ -7093,10 +7093,12 @@ def image_resize(input, ...@@ -7093,10 +7093,12 @@ def image_resize(input,
Raises: Raises:
TypeError: out_shape should be a list or tuple or Variable. TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None. TypeError: actual_shape should either be Variable or None.
TypeError: scale should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR' ValueError: The 'resample' of image_resize can only be 'BILINEAR'
or 'NEAREST' currently. or 'NEAREST' currently.
ValueError: One of out_shape and scale must not be None. ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2. ValueError: out_shape length should be 2.
ValueError: scale should be greater than zero.
TypeError: align_corners shoule be a bool value TypeError: align_corners shoule be a bool value
ValueError: align_mode can only be '0' or '1' ValueError: align_mode can only be '0' or '1'
...@@ -7128,9 +7130,15 @@ def image_resize(input, ...@@ -7128,9 +7130,15 @@ def image_resize(input,
def _is_list_or_turple_(data): def _is_list_or_turple_(data):
return (isinstance(data, list) or isinstance(data, tuple)) return (isinstance(data, list) or isinstance(data, tuple))
inputs = {"X": input}
attrs={
"interp_method": resample_type,
"align_corners": align_corners,
"align_mode": align_mode
}
out_h = 0 out_h = 0
out_w = 0 out_w = 0
inputs = {"X": input}
if out_shape is not None: if out_shape is not None:
if isinstance(out_shape, Variable): if isinstance(out_shape, Variable):
warnings.warn("out_shape as Variable type is deprecated, \ warnings.warn("out_shape as Variable type is deprecated, \
...@@ -7143,11 +7151,14 @@ def image_resize(input, ...@@ -7143,11 +7151,14 @@ def image_resize(input,
raise ValueError("out_shape length should be 2.") raise ValueError("out_shape length should be 2.")
out_shape = list(map(int, out_shape)) out_shape = list(map(int, out_shape))
out_h = out_shape[0] attrs['out_h'] = out_shape[0]
out_w = out_shape[1] attrs['out_w'] = out_shape[1]
else: else:
out_h = int(input.shape[2] * scale) if not isinstance(scale, float):
out_w = int(input.shape[3] * scale) raise TypeError("scale should either be Variable or None.")
if scale <= 0:
raise ValueError("scale should be greater than zero.")
attrs['scale'] = scale
if isinstance(actual_shape, Variable): if isinstance(actual_shape, Variable):
inputs["OutSize"] = actual_shape inputs["OutSize"] = actual_shape
...@@ -7159,13 +7170,7 @@ def image_resize(input, ...@@ -7159,13 +7170,7 @@ def image_resize(input,
type='{}_interp'.format(resample_type), type='{}_interp'.format(resample_type),
inputs=inputs, inputs=inputs,
outputs={"Out": out}, outputs={"Out": out},
attrs={ attrs=attrs)
"out_h": out_h,
"out_w": out_w,
"interp_method": resample_type,
"align_corners": align_corners,
"align_mode": align_mode
})
return out return out
...@@ -7236,8 +7241,9 @@ def resize_bilinear(input, ...@@ -7236,8 +7241,9 @@ def resize_bilinear(input,
out_shape(${out_size_type}): ${out_size_comment}. out_shape(${out_size_type}): ${out_size_comment}.
scale(float|None): The multiplier for the input height or width. At scale(float|None): The multiplier for the input height or width. At
least one of out_shape or scale must be set. And out_shape has least one of :attr:`out_shape` or :attr:`scale`` must be set.
a higher priority than scale. Default: None. And :attr:`scale` has a higher priority than :attr:`out_shape`.
Default: None.
name(str|None): The output variable name. name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape actual_shape(Variable): An optional input to specify output shape
...@@ -7327,8 +7333,9 @@ def resize_nearest(input, ...@@ -7327,8 +7333,9 @@ def resize_nearest(input,
out_shape(${out_size_type}): ${out_size_comment}. out_shape(${out_size_type}): ${out_size_comment}.
scale(float|None): The multiplier for the input height or width. At scale(float|None): The multiplier for the input height or width. At
least one of out_shape or scale must be set. And out_shape has least one of :attr:`out_shape` or :attr:`scale`` must be set.
a higher priority than scale. Default: None. And :attr:`scale` has a higher priority than :attr:`out_shape`.
Default: None.
name(str|None): The output variable name. name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape actual_shape(Variable): An optional input to specify output shape
......
...@@ -91,7 +91,14 @@ class TestBilinearInterpOp(OpTest): ...@@ -91,7 +91,14 @@ class TestBilinearInterpOp(OpTest):
self.op_type = "bilinear_interp" self.op_type = "bilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32") input_np = np.random.random(self.input_shape).astype("float32")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w, if self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape, self.out_size, self.actual_shape,
self.align_corners, self.align_mode) self.align_corners, self.align_mode)
self.inputs = {'X': input_np} self.inputs = {'X': input_np}
...@@ -99,9 +106,11 @@ class TestBilinearInterpOp(OpTest): ...@@ -99,9 +106,11 @@ class TestBilinearInterpOp(OpTest):
self.inputs['OutSize'] = self.out_size self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None: if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape self.inputs['OutSize'] = self.actual_shape
self.attrs = { self.attrs = {
'out_h': self.out_h, 'out_h': self.out_h,
'out_w': self.out_w, 'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method, 'interp_method': self.interp_method,
'align_corners': self.align_corners, 'align_corners': self.align_corners,
'align_mode': self.align_mode 'align_mode': self.align_mode
...@@ -119,6 +128,7 @@ class TestBilinearInterpOp(OpTest): ...@@ -119,6 +128,7 @@ class TestBilinearInterpOp(OpTest):
self.input_shape = [2, 3, 4, 4] self.input_shape = [2, 3, 4, 4]
self.out_h = 2 self.out_h = 2
self.out_w = 2 self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32") self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -130,6 +140,7 @@ class TestBilinearInterpCase1(TestBilinearInterpOp): ...@@ -130,6 +140,7 @@ class TestBilinearInterpCase1(TestBilinearInterpOp):
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 7, 8]
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
self.scale = 0.
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -140,6 +151,7 @@ class TestBilinearInterpCase2(TestBilinearInterpOp): ...@@ -140,6 +151,7 @@ class TestBilinearInterpCase2(TestBilinearInterpOp):
self.input_shape = [3, 3, 9, 6] self.input_shape = [3, 3, 9, 6]
self.out_h = 12 self.out_h = 12
self.out_w = 12 self.out_w = 12
self.scale = 0.
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -150,6 +162,7 @@ class TestBilinearInterpCase3(TestBilinearInterpOp): ...@@ -150,6 +162,7 @@ class TestBilinearInterpCase3(TestBilinearInterpOp):
self.input_shape = [1, 1, 128, 64] self.input_shape = [1, 1, 128, 64]
self.out_h = 64 self.out_h = 64
self.out_w = 128 self.out_w = 128
self.scale = 0.
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -160,6 +173,7 @@ class TestBilinearInterpCase4(TestBilinearInterpOp): ...@@ -160,6 +173,7 @@ class TestBilinearInterpCase4(TestBilinearInterpOp):
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 7, 8]
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32") self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -171,6 +185,7 @@ class TestBilinearInterpCase5(TestBilinearInterpOp): ...@@ -171,6 +185,7 @@ class TestBilinearInterpCase5(TestBilinearInterpOp):
self.input_shape = [3, 3, 9, 6] self.input_shape = [3, 3, 9, 6]
self.out_h = 12 self.out_h = 12
self.out_w = 12 self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11]).astype("int32") self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -182,6 +197,7 @@ class TestBilinearInterpCase6(TestBilinearInterpOp): ...@@ -182,6 +197,7 @@ class TestBilinearInterpCase6(TestBilinearInterpOp):
self.input_shape = [1, 1, 128, 64] self.input_shape = [1, 1, 128, 64]
self.out_h = 64 self.out_h = 64
self.out_w = 128 self.out_w = 128
self.scale = 0.
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -193,6 +209,7 @@ class TestBilinearInterpActualShape(TestBilinearInterpOp): ...@@ -193,6 +209,7 @@ class TestBilinearInterpActualShape(TestBilinearInterpOp):
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 32, 16]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32") self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -206,15 +223,25 @@ class TestBilinearInterpOpUint8(OpTest): ...@@ -206,15 +223,25 @@ class TestBilinearInterpOpUint8(OpTest):
self.op_type = "bilinear_interp" self.op_type = "bilinear_interp"
input_np = np.random.randint( input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8") low=0, high=256, size=self.input_shape).astype("uint8")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
if self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape, self.out_size, self.actual_shape,
self.align_corners, self.align_mode) self.align_corners, self.align_mode)
self.inputs = {'X': input_np} self.inputs = {'X': input_np}
if self.out_size is not None: if self.out_size is not None:
self.inputs['OutSize'] = self.out_size self.inputs['OutSize'] = self.out_size
self.attrs = { self.attrs = {
'out_h': self.out_h, 'out_h': self.out_h,
'out_w': self.out_w, 'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method, 'interp_method': self.interp_method,
'align_corners': self.align_corners, 'align_corners': self.align_corners,
'align_mode': self.align_mode 'align_mode': self.align_mode
...@@ -229,6 +256,7 @@ class TestBilinearInterpOpUint8(OpTest): ...@@ -229,6 +256,7 @@ class TestBilinearInterpOpUint8(OpTest):
self.input_shape = [1, 3, 9, 6] self.input_shape = [1, 3, 9, 6]
self.out_h = 10 self.out_h = 10
self.out_w = 9 self.out_w = 9
self.scale = 0.
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -239,6 +267,7 @@ class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8): ...@@ -239,6 +267,7 @@ class TestBilinearInterpCase1Uint8(TestBilinearInterpOpUint8):
self.input_shape = [2, 3, 128, 64] self.input_shape = [2, 3, 128, 64]
self.out_h = 120 self.out_h = 120
self.out_w = 50 self.out_w = 50
self.scale = 0.
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -249,6 +278,7 @@ class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8): ...@@ -249,6 +278,7 @@ class TestBilinearInterpCase2Uint8(TestBilinearInterpOpUint8):
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 7, 8]
self.out_h = 5 self.out_h = 5
self.out_w = 13 self.out_w = 13
self.scale = 0.
self.out_size = np.array([6, 15]).astype("int32") self.out_size = np.array([6, 15]).astype("int32")
self.align_corners = True self.align_corners = True
self.align_mode = 1 self.align_mode = 1
...@@ -272,5 +302,38 @@ class TestBilinearInterpWithMethod3(TestBilinearInterpOp): ...@@ -272,5 +302,38 @@ class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
self.align_mode = 0 self.align_mode = 0
class TestBilinearInterpScale1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32]
self.out_h = 60
self.out_w = 25
self.scale = 2.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpScale2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32]
self.out_h = 60
self.out_w = 25
self.scale = 1.
self.align_corners = True
self.align_mode = 1
class TestBilinearInterpScale3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32]
self.out_h = 60
self.out_w = 25
self.scale = 1.5
self.align_corners = True
self.align_mode = 1
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -73,7 +73,14 @@ class TestNearestInterpOp(OpTest): ...@@ -73,7 +73,14 @@ class TestNearestInterpOp(OpTest):
self.op_type = "nearest_interp" self.op_type = "nearest_interp"
input_np = np.random.random(self.input_shape).astype("float32") input_np = np.random.random(self.input_shape).astype("float32")
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w, if self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape, self.out_size, self.actual_shape,
self.align_corners) self.align_corners)
self.inputs = {'X': input_np} self.inputs = {'X': input_np}
...@@ -84,6 +91,7 @@ class TestNearestInterpOp(OpTest): ...@@ -84,6 +91,7 @@ class TestNearestInterpOp(OpTest):
self.attrs = { self.attrs = {
'out_h': self.out_h, 'out_h': self.out_h,
'out_w': self.out_w, 'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method, 'interp_method': self.interp_method,
'align_corners': self.align_corners, 'align_corners': self.align_corners,
} }
...@@ -100,6 +108,7 @@ class TestNearestInterpOp(OpTest): ...@@ -100,6 +108,7 @@ class TestNearestInterpOp(OpTest):
self.input_shape = [2, 3, 4, 4] self.input_shape = [2, 3, 4, 4]
self.out_h = 2 self.out_h = 2
self.out_w = 2 self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32") self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True self.align_corners = True
...@@ -110,6 +119,7 @@ class TestNearestNeighborInterpCase1(TestNearestInterpOp): ...@@ -110,6 +119,7 @@ class TestNearestNeighborInterpCase1(TestNearestInterpOp):
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 7, 8]
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
self.scale = 0.
self.align_corners = True self.align_corners = True
...@@ -119,6 +129,7 @@ class TestNearestNeighborInterpCase2(TestNearestInterpOp): ...@@ -119,6 +129,7 @@ class TestNearestNeighborInterpCase2(TestNearestInterpOp):
self.input_shape = [3, 3, 9, 6] self.input_shape = [3, 3, 9, 6]
self.out_h = 12 self.out_h = 12
self.out_w = 12 self.out_w = 12
self.scale = 0.
self.align_corners = True self.align_corners = True
...@@ -128,6 +139,7 @@ class TestNearestNeighborInterpCase3(TestNearestInterpOp): ...@@ -128,6 +139,7 @@ class TestNearestNeighborInterpCase3(TestNearestInterpOp):
self.input_shape = [1, 1, 128, 64] self.input_shape = [1, 1, 128, 64]
self.out_h = 64 self.out_h = 64
self.out_w = 128 self.out_w = 128
self.scale = 0.
self.align_corners = True self.align_corners = True
...@@ -137,6 +149,7 @@ class TestNearestNeighborInterpCase4(TestNearestInterpOp): ...@@ -137,6 +149,7 @@ class TestNearestNeighborInterpCase4(TestNearestInterpOp):
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 7, 8]
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32") self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True self.align_corners = True
...@@ -147,6 +160,7 @@ class TestNearestNeighborInterpCase5(TestNearestInterpOp): ...@@ -147,6 +160,7 @@ class TestNearestNeighborInterpCase5(TestNearestInterpOp):
self.input_shape = [3, 3, 9, 6] self.input_shape = [3, 3, 9, 6]
self.out_h = 12 self.out_h = 12
self.out_w = 12 self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11]).astype("int32") self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = True self.align_corners = True
...@@ -157,6 +171,7 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp): ...@@ -157,6 +171,7 @@ class TestNearestNeighborInterpCase6(TestNearestInterpOp):
self.input_shape = [1, 1, 128, 64] self.input_shape = [1, 1, 128, 64]
self.out_h = 64 self.out_h = 64
self.out_w = 128 self.out_w = 128
self.scale = 0.
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
self.align_corners = True self.align_corners = True
...@@ -167,6 +182,7 @@ class TestNearestNeighborInterpActualShape(TestNearestInterpOp): ...@@ -167,6 +182,7 @@ class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 32, 16]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32") self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True self.align_corners = True
...@@ -179,7 +195,15 @@ class TestNearestInterpOpUint8(OpTest): ...@@ -179,7 +195,15 @@ class TestNearestInterpOpUint8(OpTest):
self.op_type = "nearest_interp" self.op_type = "nearest_interp"
input_np = np.random.randint( input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8") low=0, high=256, size=self.input_shape).astype("uint8")
output_np = nearest_neighbor_interp_np(input_np, self.out_h, self.out_w,
if self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w,
self.out_size, self.actual_shape, self.out_size, self.actual_shape,
self.align_corners) self.align_corners)
self.inputs = {'X': input_np} self.inputs = {'X': input_np}
...@@ -188,6 +212,7 @@ class TestNearestInterpOpUint8(OpTest): ...@@ -188,6 +212,7 @@ class TestNearestInterpOpUint8(OpTest):
self.attrs = { self.attrs = {
'out_h': self.out_h, 'out_h': self.out_h,
'out_w': self.out_w, 'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method, 'interp_method': self.interp_method,
'align_corners': self.align_corners 'align_corners': self.align_corners
} }
...@@ -201,6 +226,7 @@ class TestNearestInterpOpUint8(OpTest): ...@@ -201,6 +226,7 @@ class TestNearestInterpOpUint8(OpTest):
self.input_shape = [1, 3, 9, 6] self.input_shape = [1, 3, 9, 6]
self.out_h = 10 self.out_h = 10
self.out_w = 9 self.out_w = 9
self.scale = 0.
self.align_corners = True self.align_corners = True
...@@ -210,6 +236,7 @@ class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8): ...@@ -210,6 +236,7 @@ class TestNearestNeighborInterpCase1Uint8(TestNearestInterpOpUint8):
self.input_shape = [2, 3, 128, 64] self.input_shape = [2, 3, 128, 64]
self.out_h = 120 self.out_h = 120
self.out_w = 50 self.out_w = 50
self.scale = 0.
self.align_corners = True self.align_corners = True
...@@ -219,6 +246,7 @@ class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8): ...@@ -219,6 +246,7 @@ class TestNearestNeighborInterpCase2Uint8(TestNearestInterpOpUint8):
self.input_shape = [4, 1, 7, 8] self.input_shape = [4, 1, 7, 8]
self.out_h = 5 self.out_h = 5
self.out_w = 13 self.out_w = 13
self.scale = 0.
self.out_size = np.array([6, 15]).astype("int32") self.out_size = np.array([6, 15]).astype("int32")
self.align_corners = True self.align_corners = True
...@@ -228,5 +256,38 @@ class TestNearestInterpWithoutCorners(TestNearestInterpOp): ...@@ -228,5 +256,38 @@ class TestNearestInterpWithoutCorners(TestNearestInterpOp):
self.align_corners = False self.align_corners = False
class TestNearestNeighborInterpScale1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpScale2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 1.5
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
class TestNearestNeighborInterpScale3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 1.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册