Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
34bfae24
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
34bfae24
编写于
11月 02, 2018
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Interpolate operation. test=develop
上级
df4a3544
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
874 addition
and
678 deletion
+874
-678
paddle/fluid/operators/bilinear_interp_op.cc
paddle/fluid/operators/bilinear_interp_op.cc
+0
-116
paddle/fluid/operators/bilinear_interp_op.cu
paddle/fluid/operators/bilinear_interp_op.cu
+0
-207
paddle/fluid/operators/bilinear_interp_op.h
paddle/fluid/operators/bilinear_interp_op.h
+0
-163
paddle/fluid/operators/interpolate_op.cc
paddle/fluid/operators/interpolate_op.cc
+47
-23
paddle/fluid/operators/interpolate_op.cu
paddle/fluid/operators/interpolate_op.cu
+286
-0
paddle/fluid/operators/interpolate_op.h
paddle/fluid/operators/interpolate_op.h
+236
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+10
-10
python/paddle/fluid/tests/unittests/test_interpolate_op.py
python/paddle/fluid/tests/unittests/test_interpolate_op.py
+294
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+1
-1
python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py
.../fluid/tests/unittests/test_nearest_neighbor_interp_op.py
+0
-158
未找到文件。
paddle/fluid/operators/bilinear_interp_op.cc
已删除
100644 → 0
浏览文件 @
df4a3544
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
class
BilinearInterpOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of BilinearInterOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of BilinearInterOp should not be null."
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
// NCHW format
int
out_h
=
ctx
->
Attrs
().
Get
<
int
>
(
"out_h"
);
int
out_w
=
ctx
->
Attrs
().
Get
<
int
>
(
"out_w"
);
PADDLE_ENFORCE_EQ
(
dim_x
.
size
(),
4
,
"X's dimension must be 4"
);
if
(
ctx
->
HasInput
(
"OutSize"
))
{
auto
out_size_dim
=
ctx
->
GetInputDim
(
"OutSize"
);
PADDLE_ENFORCE_EQ
(
out_size_dim
.
size
(),
1
,
"OutSize's dimension size must be 1"
);
PADDLE_ENFORCE_EQ
(
out_size_dim
[
0
],
2
,
"OutSize's dim[0] must be 2"
);
}
std
::
vector
<
int64_t
>
dim_out
({
dim_x
[
0
],
dim_x
[
1
],
out_h
,
out_w
});
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
dim_out
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
}
};
class
BilinearInterpOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"
);
AddInput
(
"OutSize"
,
"This is a 1-D tensor with two number. "
"The first number is height and the second number is width."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"The dimension of output is (N x C x out_h x out_w)"
);
AddAttr
<
int
>
(
"out_h"
,
"output height of bilinear interpolation op."
);
AddAttr
<
int
>
(
"out_w"
,
"output width of bilinear interpolation op."
);
AddComment
(
R"DOC(
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid.
The key idea is to perform linear interpolation first in one
direction, and then again in the other direction.
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC"
);
}
};
class
BilinearInterpOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
dim_x
);
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
bilinear_interp
,
ops
::
BilinearInterpOp
,
ops
::
BilinearInterpOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
bilinear_interp_grad
,
ops
::
BilinearInterpOpGrad
);
REGISTER_OP_CPU_KERNEL
(
bilinear_interp
,
ops
::
BilinearInterpKernel
<
float
>
,
ops
::
BilinearInterpKernel
<
uint8_t
>
);
REGISTER_OP_CPU_KERNEL
(
bilinear_interp_grad
,
ops
::
BilinearInterpGradKernel
<
float
>
);
paddle/fluid/operators/bilinear_interp_op.cu
已删除
100644 → 0
浏览文件 @
df4a3544
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
template
<
typename
T
>
__global__
void
KeBilinearInterpFw
(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratioW
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
ratio_h
*
out_img_idy
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
ratioW
*
out_img_idx
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
ratioW
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
// bilinear interpolation
out
[
out_id_h
*
output_w
+
out_id_w
]
=
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
w_id
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
h_id
*
in_img_w
]
+
w1lambda
*
in_pos
[
h_id
*
in_img_w
+
w_id
]);
}
}
template
<
typename
T
>
__global__
void
KeBilinearInterpBw
(
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
const
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratioW
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
ratio_h
*
out_img_idy
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
ratioW
*
out_img_idx
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
ratioW
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
const
T
*
out_pos
=
&
out
[
out_id_h
*
output_w
+
out_id_w
];
atomicAdd
(
&
in_pos
[
0
],
h2lambda
*
w2lambda
*
out_pos
[
0
]);
atomicAdd
(
&
in_pos
[
w_id
],
h2lambda
*
w1lambda
*
out_pos
[
0
]);
atomicAdd
(
&
in_pos
[
h_id
*
in_img_w
],
h1lambda
*
w2lambda
*
out_pos
[
0
]);
atomicAdd
(
&
in_pos
[
h_id
*
in_img_w
+
w_id
],
h1lambda
*
w1lambda
*
out_pos
[
0
]);
}
}
template
<
typename
T
>
class
BilinearInterpOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
*
input_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
*
input
=
input_t
->
data
<
T
>
();
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_dims
=
output_t
->
dims
();
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
Tensor
sizes
;
framework
::
TensorCopy
(
*
out_size_t
,
platform
::
CPUPlace
(),
&
sizes
);
auto
size_data
=
sizes
.
data
<
int
>
();
out_h
=
size_data
[
0
];
out_w
=
size_data
[
1
];
}
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
{
out_dims
[
0
],
out_dims
[
1
],
out_h
,
out_w
},
ctx
.
GetPlace
());
int
batch_size
=
input_t
->
dims
()[
0
];
int
channels
=
input_t
->
dims
()[
1
];
int
in_h
=
input_t
->
dims
()[
2
];
int
in_w
=
input_t
->
dims
()[
3
];
int
in_hw
=
in_h
*
in_w
;
int
out_hw
=
out_h
*
out_w
;
int
in_chw
=
channels
*
in_hw
;
int
out_chw
=
channels
*
out_hw
;
T
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
T
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
T
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
T
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
memcpy
(
output
,
input
,
input_t
->
numel
()
*
sizeof
(
T
));
}
else
{
int
threadNum
=
batch_size
*
out_chw
;
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
KeBilinearInterpFw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input
,
in_h
,
in_w
,
batch_size
,
in_chw
,
output
,
out_h
,
out_w
,
batch_size
,
out_chw
,
channels
,
ratio_h
,
ratio_w
);
}
}
};
template
<
typename
T
>
class
BilinearInterpGradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_input_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_output_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_output
=
d_output_t
->
data
<
T
>
();
auto
*
d_input
=
d_input_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
device_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
zero
;
zero
(
device_ctx
,
d_input_t
,
static_cast
<
T
>
(
0.0
));
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
Tensor
sizes
;
framework
::
TensorCopy
(
*
out_size_t
,
platform
::
CPUPlace
(),
&
sizes
);
auto
size_data
=
sizes
.
data
<
int
>
();
out_h
=
size_data
[
0
];
out_w
=
size_data
[
1
];
}
int
batch_size
=
d_input_t
->
dims
()[
0
];
int
channels
=
d_input_t
->
dims
()[
1
];
int
in_h
=
d_input_t
->
dims
()[
2
];
int
in_w
=
d_input_t
->
dims
()[
3
];
int
in_hw
=
in_h
*
in_w
;
int
out_hw
=
out_h
*
out_w
;
int
in_chw
=
channels
*
in_hw
;
int
out_chw
=
channels
*
out_hw
;
T
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
T
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
T
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
T
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
memcpy
(
d_input
,
d_output
,
d_input_t
->
numel
()
*
sizeof
(
T
));
}
else
{
int
threadNum
=
batch_size
*
out_chw
;
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
KeBilinearInterpBw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
d_input
,
in_h
,
in_w
,
batch_size
,
in_chw
,
d_output
,
out_h
,
out_w
,
batch_size
,
out_chw
,
channels
,
ratio_h
,
ratio_w
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
bilinear_interp
,
ops
::
BilinearInterpOpCUDAKernel
<
float
>
);
REGISTER_OP_CUDA_KERNEL
(
bilinear_interp_grad
,
ops
::
BilinearInterpGradOpCUDAKernel
<
float
>
);
paddle/fluid/operators/bilinear_interp_op.h
已删除
100644 → 0
浏览文件 @
df4a3544
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
BilinearInterpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
out_dims
=
output_t
->
dims
();
auto
*
input
=
input_t
->
data
<
T
>
();
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
auto
out_size_data
=
out_size_t
->
data
<
int
>
();
out_h
=
out_size_data
[
0
];
out_w
=
out_size_data
[
1
];
}
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
{
out_dims
[
0
],
out_dims
[
1
],
out_h
,
out_w
},
ctx
.
GetPlace
());
int
batch_size
=
input_t
->
dims
()[
0
];
int
channels
=
input_t
->
dims
()[
1
];
int
in_h
=
input_t
->
dims
()[
2
];
int
in_w
=
input_t
->
dims
()[
3
];
int
in_hw
=
in_h
*
in_w
;
int
out_hw
=
out_h
*
out_w
;
int
in_chw
=
channels
*
in_hw
;
int
out_chw
=
channels
*
out_hw
;
float
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
float
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
float
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
float
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
memcpy
(
output
,
input
,
input_t
->
numel
()
*
sizeof
(
T
));
}
else
{
for
(
int
k
=
0
;
k
<
batch_size
;
++
k
)
{
// loop for batches
for
(
int
i
=
0
;
i
<
out_h
;
++
i
)
{
// loop for images
int
h
=
ratio_h
*
i
;
int
hid
=
(
h
<
in_h
-
1
)
?
1
:
0
;
float
h1lambda
=
ratio_h
*
i
-
h
;
float
h2lambda
=
1.
f
-
h1lambda
;
for
(
int
j
=
0
;
j
<
out_w
;
++
j
)
{
int
w
=
ratio_w
*
j
;
int
wid
=
(
w
<
in_w
-
1
)
?
1
:
0
;
float
w1lambda
=
ratio_w
*
j
-
w
;
float
w2lambda
=
1.
f
-
w1lambda
;
// calculate four position for bilinear interpolation
const
T
*
in_pos
=
&
input
[
k
*
in_chw
+
h
*
in_w
+
w
];
T
*
out_pos
=
&
output
[
k
*
out_chw
+
i
*
out_w
+
j
];
for
(
int
c
=
0
;
c
<
channels
;
++
c
)
{
// loop for channels
// bilinear interpolation
out_pos
[
0
]
=
static_cast
<
T
>
(
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
wid
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
hid
*
in_w
]
+
w1lambda
*
in_pos
[
hid
*
in_w
+
wid
]));
in_pos
+=
in_hw
;
out_pos
+=
out_hw
;
}
}
}
}
}
}
};
template
<
typename
T
>
class
BilinearInterpGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
d_input_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_output_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_output
=
d_output_t
->
data
<
T
>
();
auto
*
d_input
=
d_input_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
device_ctx
=
ctx
.
template
device_context
<
platform
::
CPUDeviceContext
>();
math
::
SetConstant
<
platform
::
CPUDeviceContext
,
T
>
zero
;
zero
(
device_ctx
,
d_input_t
,
static_cast
<
T
>
(
0.0
));
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size_t
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size_t
!=
nullptr
)
{
auto
out_size_data
=
out_size_t
->
data
<
int
>
();
out_h
=
out_size_data
[
0
];
out_w
=
out_size_data
[
1
];
}
int
batch_size
=
d_input_t
->
dims
()[
0
];
int
channels
=
d_input_t
->
dims
()[
1
];
int
in_h
=
d_input_t
->
dims
()[
2
];
int
in_w
=
d_input_t
->
dims
()[
3
];
int
in_hw
=
in_h
*
in_w
;
int
out_hw
=
out_h
*
out_w
;
int
in_chw
=
channels
*
in_hw
;
int
out_chw
=
channels
*
out_hw
;
float
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
float
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
float
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
float
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
memcpy
(
d_input
,
d_output
,
d_input_t
->
numel
()
*
sizeof
(
T
));
}
else
{
for
(
int
k
=
0
;
k
<
batch_size
;
++
k
)
{
// loop for batches
for
(
int
i
=
0
;
i
<
out_h
;
++
i
)
{
// loop for images
int
h
=
ratio_h
*
i
;
int
hid
=
(
h
<
in_h
-
1
)
?
1
:
0
;
float
h1lambda
=
ratio_h
*
i
-
h
;
float
h2lambda
=
1
-
h1lambda
;
for
(
int
j
=
0
;
j
<
out_w
;
++
j
)
{
int
w
=
ratio_w
*
j
;
int
wid
=
(
w
<
in_w
-
1
)
?
1
:
0
;
float
w1lambda
=
ratio_w
*
j
-
w
;
float
w2lambda
=
1
-
w1lambda
;
T
*
in_pos
=
&
d_input
[
k
*
in_chw
+
h
*
in_w
+
w
];
const
T
*
out_pos
=
&
d_output
[
k
*
out_chw
+
i
*
out_w
+
j
];
for
(
int
c
=
0
;
c
<
channels
;
++
c
)
{
// loop for channels
in_pos
[
0
]
+=
static_cast
<
T
>
(
h2lambda
*
w2lambda
*
out_pos
[
0
]);
in_pos
[
wid
]
+=
static_cast
<
T
>
(
h2lambda
*
w1lambda
*
out_pos
[
0
]);
in_pos
[
hid
*
in_w
]
+=
static_cast
<
T
>
(
h1lambda
*
w2lambda
*
out_pos
[
0
]);
in_pos
[
hid
*
in_w
+
wid
]
+=
static_cast
<
T
>
(
h1lambda
*
w1lambda
*
out_pos
[
0
]);
in_pos
+=
in_hw
;
out_pos
+=
out_hw
;
}
}
}
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/
nearest_neighbor_interp
_op.cc
→
paddle/fluid/operators/
interpolate
_op.cc
浏览文件 @
34bfae24
...
...
@@ -9,7 +9,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nearest_neighbor_interp_op.h"
#include "paddle/fluid/operators/interpolate_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -18,16 +19,21 @@ namespace operators {
using
framework
::
Tensor
;
class
NearestNeighborInterp
Op
:
public
framework
::
OperatorWithKernel
{
class
Interpolate
Op
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of
NearestNeighborInter
Op should not be null."
);
"Input(X) of
Interpolate
Op should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of NearestNeighborInterOp should not be null."
);
"Output(Out) of InterpolationOp should not be null."
);
auto
interp_method
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"interp_method"
);
PADDLE_ENFORCE
(
"bilinear"
==
interp_method
||
"nearest"
==
interp_method
,
"Interpolation method can only be
\"
bilinear
\"
or
\"
nearest
\"
."
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
// NCHW format
int
out_h
=
ctx
->
Attrs
().
Get
<
int
>
(
"out_h"
);
...
...
@@ -52,33 +58,53 @@ class NearestNeighborInterpOp : public framework::OperatorWithKernel {
}
};
class
NearestNeighborInterp
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
Interpolate
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input tensor of
nearest neighbor interpolation
, "
"This is a 4-D tensor with shape of
(N x C x h x w)
"
);
"The input tensor of
interpolate operator
, "
"This is a 4-D tensor with shape of
[N, C, H, w].
"
);
AddInput
(
"OutSize"
,
"This is a 1-D tensor with two number. "
"This is a 1-D tensor with two number
s to specify output size
. "
"The first number is height and the second number is width."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"The dimension of output is (N x C x out_h x out_w)"
);
AddAttr
<
int
>
(
"out_h"
,
"output height of nearest neighbor interpolation op."
);
AddAttr
<
int
>
(
"out_w"
,
"output width of nearest neighbor interpolation op."
);
AddOutput
(
"Out"
,
"The output tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W]."
);
AddAttr
<
int
>
(
"out_h"
,
"output height of interpolate op."
);
AddAttr
<
int
>
(
"out_w"
,
"output width of interpolate op."
);
AddAttr
<
std
::
string
>
(
"interp_method"
,
"(string), interpolation method, can be
\"
bilinear
\"
for "
"bilinear interpolation and
\"
nearest
\"
for nearest "
"neighbor interpolation."
);
AddComment
(
R"DOC(
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
for nearest neighbor interpolation and \"bilinear\" for bilinear
interpolation.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in bot the 3rd dimention(in height direction) and the 4th dimention(in width
direction) on input tensor.
For details, please refer to Wikipedia:
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC"
);
}
};
class
NearestNeighborInterp
OpGrad
:
public
framework
::
OperatorWithKernel
{
class
Interpolate
OpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
...
@@ -104,13 +130,11 @@ class NearestNeighborInterpOpGrad : public framework::OperatorWithKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
nearest_neighbor_interp
,
ops
::
NearestNeighborInterpOp
,
ops
::
NearestNeighborInterpOpMaker
,
REGISTER_OPERATOR
(
interpolate
,
ops
::
InterpolateOp
,
ops
::
InterpolateOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
nearest_neighbor_interp_grad
,
ops
::
NearestNeighborInterpOpGrad
);
REGISTER_OP_CPU_KERNEL
(
nearest_neighbor_interp
,
ops
::
NearestNeighborInterpKernel
<
float
>
,
ops
::
NearestNeighborInterpKernel
<
uint8_t
>
);
REGISTER_OP_CPU_KERNEL
(
nearest_neighbor_interp_grad
,
ops
::
NearestNeighborInterpGradKernel
<
float
>
);
REGISTER_OPERATOR
(
interpolate_grad
,
ops
::
InterpolateOpGrad
);
REGISTER_OP_CPU_KERNEL
(
interpolate
,
ops
::
InterpolateKernel
<
float
>
,
ops
::
InterpolateKernel
<
double
>
,
ops
::
InterpolateKernel
<
uint8_t
>
);
REGISTER_OP_CPU_KERNEL
(
interpolate_grad
,
ops
::
InterpolateGradKernel
<
float
>
,
ops
::
InterpolateGradKernel
<
double
>
);
paddle/fluid/operators/
nearest_neighbor_interp
_op.cu
→
paddle/fluid/operators/
interpolate
_op.cu
浏览文件 @
34bfae24
...
...
@@ -9,7 +9,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/nearest_neighbor_interp_op.h"
#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
...
...
@@ -22,7 +23,7 @@ __global__ void KeNearestNeighborInterpFw(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratio_w
)
{
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
...
...
@@ -33,10 +34,10 @@ __global__ void KeNearestNeighborInterpFw(
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
static_cast
<
int
>
(
r
ound
(
ratio_h
*
out_img_idy
)
);
int
in_img_idy
=
static_cast
<
int
>
(
r
atio_h
*
out_img_idy
+
0.5
);
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
static_cast
<
int
>
(
r
ound
(
ratio_w
*
out_img_idx
)
);
int
in_img_idx
=
static_cast
<
int
>
(
r
atio_w
*
out_img_idx
+
0.5
);
out
[
tid
]
=
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
...
...
@@ -48,7 +49,7 @@ __global__ void KeNearestNeighborInterpBw(
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
const
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratio_w
)
{
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
...
...
@@ -59,28 +60,106 @@ __global__ void KeNearestNeighborInterpBw(
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
static_cast
<
int
>
(
r
ound
(
ratio_h
*
out_img_idy
)
);
int
in_img_idy
=
static_cast
<
int
>
(
r
atio_h
*
out_img_idy
+
0.5
);
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
static_cast
<
int
>
(
r
ound
(
ratio_w
*
out_img_idx
)
);
int
in_img_idx
=
static_cast
<
int
>
(
r
atio_w
*
out_img_idx
+
0.5
);
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
const
T
out_pos
=
out
[
out_id_h
*
output_w
+
out_id_w
];
atomicAdd
(
in_pos
,
out_pos
);
platform
::
CudaAtomicAdd
(
in_pos
,
out_pos
);
}
}
template
<
typename
T
>
__global__
void
KeBilinearInterpFw
(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
float
ratio_h
,
const
float
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
ratio_h
*
out_img_idy
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
ratio_w
*
out_img_idx
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
// bilinear interpolation
out
[
out_id_h
*
output_w
+
out_id_w
]
=
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
w_id
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
h_id
*
in_img_w
]
+
w1lambda
*
in_pos
[
h_id
*
in_img_w
+
w_id
]);
}
}
template
<
typename
T
>
__global__
void
KeBilinearInterpBw
(
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
const
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratio_w
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
ratio_h
*
out_img_idy
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
ratio_w
*
out_img_idx
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
const
T
*
out_pos
=
&
out
[
out_id_h
*
output_w
+
out_id_w
];
platform
::
CudaAtomicAdd
(
&
in_pos
[
0
],
h2lambda
*
w2lambda
*
out_pos
[
0
]);
platform
::
CudaAtomicAdd
(
&
in_pos
[
w_id
],
h2lambda
*
w1lambda
*
out_pos
[
0
]);
platform
::
CudaAtomicAdd
(
&
in_pos
[
h_id
*
in_img_w
],
h1lambda
*
w2lambda
*
out_pos
[
0
]);
platform
::
CudaAtomicAdd
(
&
in_pos
[
h_id
*
in_img_w
+
w_id
],
h1lambda
*
w1lambda
*
out_pos
[
0
]);
}
}
template
<
typename
T
>
class
NearestNeighborInterp
OpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
Interpolate
OpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
// float tensor
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
// float tensor
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
input_data
=
input
->
data
<
T
>
();
auto
interp_method
=
ctx
.
Attr
<
std
::
string
>
(
"interp_method"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
...
...
@@ -105,26 +184,35 @@ class NearestNeighborInterpOpCUDAKernel : public framework::OpKernel<T> {
int
in_chw
=
c
*
in_hw
;
int
out_chw
=
c
*
out_hw
;
T
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
T
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
T
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
T
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
float
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
float
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
float
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
float
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
memcpy
(
output_data
,
input_data
,
input
->
numel
()
*
sizeof
(
T
)
);
framework
::
TensorCopy
(
*
input
,
ctx
.
GetPlace
(),
output
);
return
;
}
int
threadNum
=
n
*
out_chw
;
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
if
(
"nearest"
==
interp_method
)
{
KeNearestNeighborInterpFw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
else
if
(
"bilinear"
==
interp_method
)
{
KeBilinearInterpFw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
}
};
template
<
typename
T
>
class
NearestNeighborInterp
GradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
Interpolate
GradOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
...
...
@@ -137,9 +225,9 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
zero
;
zero
(
device_ctx
,
input_grad
,
static_cast
<
T
>
(
0.0
));
auto
interp_method
=
ctx
.
Attr
<
std
::
string
>
(
"interp_method"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
if
(
out_size
!=
nullptr
)
{
Tensor
sizes
;
...
...
@@ -159,21 +247,30 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
int
in_chw
=
c
*
in_hw
;
int
out_chw
=
c
*
out_hw
;
T
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
T
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
T
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
T
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
float
ratio_h
=
(
out_h
>
1
)
?
static_cast
<
float
>
(
in_h
-
1
)
/
(
out_h
-
1
)
:
0.
f
;
float
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
float
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
if
(
in_h
==
out_h
&&
in_w
==
out_w
)
{
memcpy
(
input_grad
,
output_grad
,
input_grad
->
numel
()
*
sizeof
(
T
)
);
framework
::
TensorCopy
(
*
output_grad
,
ctx
.
GetPlace
(),
input_grad
);
return
;
}
int
threadNum
=
n
*
out_chw
;
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
if
(
"nearest"
==
interp_method
)
{
KeNearestNeighborInterpBw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_grad_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_grad_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
input_grad_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_grad_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
else
if
(
"bilinear"
==
interp_method
)
{
KeBilinearInterpBw
<
T
><<<
blocks
,
1024
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
input_grad_data
,
in_h
,
in_w
,
n
,
in_chw
,
output_grad_data
,
out_h
,
out_w
,
n
,
out_chw
,
c
,
ratio_h
,
ratio_w
);
}
}
};
...
...
@@ -181,7 +278,9 @@ class NearestNeighborInterpGradOpCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
nearest_neighbor_interp
,
ops
::
NearestNeighborInterpOpCUDAKernel
<
float
>
);
REGISTER_OP_CUDA_KERNEL
(
nearest_neighbor_interp_grad
,
ops
::
NearestNeighborInterpGradOpCUDAKernel
<
float
>
);
REGISTER_OP_CUDA_KERNEL
(
interpolate
,
ops
::
InterpolateOpCUDAKernel
<
float
>
,
ops
::
InterpolateOpCUDAKernel
<
double
>
,
ops
::
InterpolateOpCUDAKernel
<
int
>
);
REGISTER_OP_CUDA_KERNEL
(
interpolate_grad
,
ops
::
InterpolateGradOpCUDAKernel
<
float
>
,
ops
::
InterpolateGradOpCUDAKernel
<
double
>
);
paddle/fluid/operators/
nearest_neighbor_interp
_op.h
→
paddle/fluid/operators/
interpolate
_op.h
浏览文件 @
34bfae24
...
...
@@ -10,6 +10,7 @@
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
...
...
@@ -22,12 +23,126 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
NearestNeighborInterpKernel
:
public
framework
::
OpKernel
<
T
>
{
static
void
NearestNeighborInterpolate
(
const
Tensor
&
input
,
Tensor
*
output
,
const
float
ratio_h
,
const
float
ratio_w
,
const
int
n
,
const
int
c
,
const
int
out_h
,
const
int
out_w
)
{
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
input
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
in_k
=
static_cast
<
int
>
(
ratio_h
*
k
+
0.5
);
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
in_l
=
static_cast
<
int
>
(
ratio_w
*
l
+
0.5
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
// loop for channels
output_t
(
i
,
j
,
k
,
l
)
=
input_t
(
i
,
j
,
in_k
,
in_l
);
}
}
}
}
}
template
<
typename
T
>
static
void
BilinearInterpolation
(
const
Tensor
&
input
,
Tensor
*
output
,
const
float
ratio_h
,
const
float
ratio_w
,
const
int
in_h
,
const
int
in_w
,
const
int
n
,
const
int
c
,
const
int
out_h
,
const
int
out_w
)
{
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
input
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
y_n
=
static_cast
<
int
>
(
ratio_h
*
k
);
int
y_s
=
(
y_n
+
1
)
<
(
in_h
-
1
)
?
(
y_n
+
1
)
:
(
in_h
-
1
);
float
d_n
=
ratio_h
*
k
-
y_n
;
float
d_s
=
1.
f
-
d_n
;
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
x_w
=
static_cast
<
int
>
(
ratio_w
*
l
);
int
x_e
=
(
x_w
+
1
)
<
(
in_w
-
1
)
?
(
x_w
+
1
)
:
(
in_w
-
1
);
float
d_w
=
ratio_w
*
l
-
x_w
;
float
d_e
=
1.
f
-
d_w
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
// loop for channels
// bilinear interpolation
output_t
(
i
,
j
,
k
,
l
)
=
input_t
(
i
,
j
,
y_n
,
x_w
)
*
d_s
*
d_e
+
input_t
(
i
,
j
,
y_s
,
x_w
)
*
d_n
*
d_e
+
input_t
(
i
,
j
,
y_n
,
x_e
)
*
d_s
*
d_w
+
input_t
(
i
,
j
,
y_s
,
x_e
)
*
d_n
*
d_w
;
}
}
}
}
}
template
<
typename
T
>
static
void
NearestNeighborInterpolateGrad
(
const
Tensor
&
output_grad
,
Tensor
*
input_grad
,
const
float
ratio_h
,
const
float
ratio_w
,
const
int
n
,
const
int
c
,
const
int
out_h
,
const
int
out_w
)
{
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
in_k
=
static_cast
<
int
>
(
ratio_h
*
k
+
0.5
);
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
in_l
=
static_cast
<
int
>
(
ratio_w
*
l
+
0.5
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
// loop for channels
input_grad_t
(
i
,
j
,
in_k
,
in_l
)
+=
output_grad_t
(
i
,
j
,
k
,
l
);
}
}
}
}
}
template
<
typename
T
>
static
void
BilinearInterpolationGrad
(
const
Tensor
&
output_grad
,
Tensor
*
input_grad
,
const
float
ratio_h
,
const
float
ratio_w
,
const
int
in_h
,
const
int
in_w
,
const
int
n
,
const
int
c
,
const
int
out_h
,
const
int
out_w
)
{
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
y_n
=
static_cast
<
int
>
(
ratio_h
*
k
);
int
y_s
=
(
y_n
+
1
)
<
(
in_h
-
1
)
?
(
y_n
+
1
)
:
(
in_h
-
1
);
float
d_n
=
ratio_h
*
k
-
y_n
;
float
d_s
=
1.
f
-
d_n
;
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
x_w
=
static_cast
<
int
>
(
ratio_w
*
l
);
int
x_e
=
(
x_w
+
1
)
<
(
in_w
-
1
)
?
(
x_w
+
1
)
:
(
in_w
-
1
);
float
d_w
=
ratio_w
*
l
-
x_w
;
float
d_e
=
1.
f
-
d_w
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
// loop for channels
// bilinear interpolation grad
const
T
grad
=
output_grad_t
(
i
,
j
,
k
,
l
);
input_grad_t
(
i
,
j
,
y_n
,
x_w
)
+=
static_cast
<
T
>
(
grad
*
d_s
*
d_e
);
input_grad_t
(
i
,
j
,
y_s
,
x_w
)
+=
static_cast
<
T
>
(
grad
*
d_n
*
d_e
);
input_grad_t
(
i
,
j
,
y_n
,
x_e
)
+=
static_cast
<
T
>
(
grad
*
d_s
*
d_w
);
input_grad_t
(
i
,
j
,
y_s
,
x_e
)
+=
static_cast
<
T
>
(
grad
*
d_n
*
d_w
);
}
}
}
}
}
template
<
typename
T
>
class
InterpolateKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
std
::
string
interp_method
=
ctx
.
Attr
<
std
::
string
>
(
"interp_method"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
...
...
@@ -58,30 +173,25 @@ class NearestNeighborInterpKernel : public framework::OpKernel<T> {
float
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
float
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
in_k
=
static_cast
<
int
>
(
round
(
ratio_h
*
k
));
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
in_l
=
static_cast
<
int
>
(
round
(
ratio_w
*
l
));
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
// loop for channels
output_t
(
i
,
j
,
k
,
l
)
=
input_t
(
i
,
j
,
in_k
,
in_l
);
}
}
}
if
(
"bilinear"
==
interp_method
)
{
BilinearInterpolation
<
T
>
(
*
input
,
output
,
ratio_h
,
ratio_w
,
in_h
,
in_w
,
n
,
c
,
out_h
,
out_w
);
}
else
if
(
"nearest"
==
interp_method
)
{
NearestNeighborInterpolate
<
T
>
(
*
input
,
output
,
ratio_h
,
ratio_w
,
n
,
c
,
out_h
,
out_w
);
}
}
};
template
<
typename
T
>
class
NearestNeighborInterp
GradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
Interpolate
GradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
output_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
std
::
string
interp_method
=
ctx
.
Attr
<
std
::
string
>
(
"interp_method"
);
int
out_h
=
ctx
.
Attr
<
int
>
(
"out_h"
);
int
out_w
=
ctx
.
Attr
<
int
>
(
"out_w"
);
auto
out_size
=
ctx
.
Input
<
Tensor
>
(
"OutSize"
);
...
...
@@ -112,18 +222,12 @@ class NearestNeighborInterpGradKernel : public framework::OpKernel<T> {
float
ratio_w
=
(
out_w
>
1
)
?
static_cast
<
float
>
(
in_w
-
1
)
/
(
out_w
-
1
)
:
0.
f
;
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output_grad
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
in_k
=
static_cast
<
int
>
(
round
(
ratio_h
*
k
));
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
in_l
=
static_cast
<
int
>
(
round
(
ratio_w
*
l
));
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
for
(
int
j
=
0
;
j
<
c
;
j
++
)
{
// loop for channels
input_grad_t
(
i
,
j
,
in_k
,
in_l
)
+=
output_grad_t
(
i
,
j
,
k
,
l
);
}
}
}
if
(
"bilinear"
==
interp_method
)
{
BilinearInterpolationGrad
<
T
>
(
*
output_grad
,
input_grad
,
ratio_h
,
ratio_w
,
in_h
,
in_w
,
n
,
c
,
out_h
,
out_w
);
}
else
if
(
"nearest"
==
interp_method
)
{
NearestNeighborInterpolateGrad
<
T
>
(
*
output_grad
,
input_grad
,
ratio_h
,
ratio_w
,
n
,
c
,
out_h
,
out_w
);
}
}
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
34bfae24
...
...
@@ -5612,17 +5612,14 @@ def image_resize(input,
out = fluid.layers.image_resize(input, out_shape=[12, 12])
"""
resample_methods
=
{
'BILINEAR'
:
'bilinear_interp'
,
'NEAREST'
:
'nearest_neighbor_interp'
}
resample_methods
=
{
'BILINEAR'
:
'bilinear'
,
'NEAREST'
:
'nearest'
}
if
resample
not
in
resample_methods
:
raise
ValueError
(
"The 'resample' of image_resize can only be 'BILINEAR' and 'NEAREST' currently."
)
if
out_shape
is
None
and
scale
is
None
:
raise
ValueError
(
"One of out_shape and scale must not be None"
)
helper
=
LayerHelper
(
resample_methods
[
resample
]
,
**
locals
())
helper
=
LayerHelper
(
'interpolate'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
def
_is_list_or_turple_
(
data
):
...
...
@@ -5647,15 +5644,18 @@ def image_resize(input,
out
=
helper
.
create_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
type
=
resample_methods
[
resample
]
,
type
=
'interpolate'
,
inputs
=
inputs
,
outputs
=
{
"Out"
:
out
},
attrs
=
{
"out_h"
:
out_h
,
"out_w"
:
out_w
})
attrs
=
{
"out_h"
:
out_h
,
"out_w"
:
out_w
,
"interp_method"
:
resample_methods
[
resample
]
})
return
out
@
templatedoc
(
op_type
=
"
bilinear_interp
"
)
@
templatedoc
(
op_type
=
"
interpolate
"
)
def
resize_bilinear
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
"""
${comment}
...
...
@@ -5678,7 +5678,7 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
return
image_resize
(
input
,
out_shape
,
scale
,
name
,
'BILINEAR'
)
@
templatedoc
(
op_type
=
"
bilinear_interp
"
)
@
templatedoc
(
op_type
=
"
interpolate
"
)
def
resize_nearest
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
"""
${comment}
...
...
python/paddle/fluid/tests/unittests/test_
bilinear_interp
_op.py
→
python/paddle/fluid/tests/unittests/test_
interpolate
_op.py
浏览文件 @
34bfae24
...
...
@@ -20,7 +20,31 @@ from op_test import OpTest
import
paddle.fluid.core
as
core
def
nearest_neighbor_interp_np
(
X
,
out_h
,
out_w
,
out_size
=
None
):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if
out_size
is
not
None
:
out_h
=
out_size
[
0
]
out_w
=
out_size
[
1
]
n
,
c
,
in_h
,
in_w
=
X
.
shape
ratio_h
=
ratio_w
=
0.0
if
out_h
>
1
:
ratio_h
=
(
in_h
-
1.0
)
/
(
out_h
-
1.0
)
if
out_w
>
1
:
ratio_w
=
(
in_w
-
1.0
)
/
(
out_w
-
1.0
)
out
=
np
.
zeros
((
n
,
c
,
out_h
,
out_w
))
for
i
in
range
(
out_h
):
in_i
=
int
(
ratio_h
*
i
+
0.5
)
for
j
in
range
(
out_w
):
in_j
=
int
(
ratio_w
*
j
+
0.5
)
out
[:,
:,
i
,
j
]
=
X
[:,
:,
in_i
,
in_j
]
return
out
.
astype
(
X
.
dtype
)
def
bilinear_interp_np
(
input
,
out_h
,
out_w
,
out_size
):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if
out_size
is
not
None
:
out_h
=
out_size
[
0
]
out_w
=
out_size
[
1
]
...
...
@@ -53,18 +77,29 @@ def bilinear_interp_np(input, out_h, out_w, out_size):
return
out
.
astype
(
input
.
dtype
)
class
TestBilinearInterpOp
(
OpTest
):
INTERPOLATE_FUNCS
=
{
'bilinear'
:
bilinear_interp_np
,
'nearest'
:
nearest_neighbor_interp_np
,
}
class
TestInterpolateOp
(
OpTest
):
def
setUp
(
self
):
self
.
out_size
=
None
self
.
init_test_case
()
self
.
op_type
=
"
bilinear_interp
"
self
.
op_type
=
"
interpolate
"
input_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
)
output_np
=
bilinear_interp_np
(
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
output_np
=
INTERPOLATE_FUNCS
[
self
.
interp_method
](
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
self
.
inputs
=
{
'X'
:
input_np
}
if
self
.
out_size
is
not
None
:
self
.
inputs
[
'OutSize'
]
=
self
.
out_size
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
}
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
,
'interp_method'
:
self
.
interp_method
}
self
.
outputs
=
{
'Out'
:
output_np
}
def
test_check_output
(
self
):
...
...
@@ -74,90 +109,181 @@ class TestBilinearInterpOp(OpTest):
self
.
check_grad
([
'X'
],
'Out'
,
in_place
=
True
)
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
2
,
3
,
4
,
4
]
self
.
out_h
=
2
self
.
out_w
=
2
self
.
out_size
=
np
.
array
([
3
,
3
]).
astype
(
"int32"
)
class
Test
Case1
(
TestBilinearInterp
Op
):
class
Test
BilinearInterpCase1
(
TestInterpolate
Op
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
class
Test
Case2
(
TestBilinearInterp
Op
):
class
Test
BilinearInterpCase2
(
TestInterpolate
Op
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
class
Test
Case3
(
TestBilinearInterp
Op
):
class
Test
BilinearInterpCase3
(
TestInterpolate
Op
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
class
Test
Case4
(
TestBilinearInterp
Op
):
class
Test
BilinearInterpCase4
(
TestInterpolate
Op
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
self
.
out_size
=
np
.
array
([
2
,
2
]).
astype
(
"int32"
)
class
Test
Case5
(
TestBilinearInterp
Op
):
class
Test
BilinearInterpCase5
(
TestInterpolate
Op
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
self
.
out_size
=
np
.
array
([
11
,
11
]).
astype
(
"int32"
)
class
Test
Case6
(
TestBilinearInterp
Op
):
class
Test
BilinearInterpCase6
(
TestInterpolate
Op
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
self
.
out_size
=
np
.
array
([
65
,
129
]).
astype
(
"int32"
)
class
TestBilinearInterpOpUint8
(
OpTest
):
# class TestBilinearInterpBigScale(TestInterpolateOp):
# def init_test_case(self):
# self.interp_method = 'bilinear'
# self.input_shape = [32, 16, 128, 64]
# self.out_h = 200
# self.out_w = 100
# self.out_size = np.array([201, 101]).astype('int32')
class
TestInterpolateOpUint8
(
OpTest
):
def
setUp
(
self
):
self
.
out_size
=
None
self
.
init_test_case
()
self
.
op_type
=
"
bilinear_interp
"
self
.
op_type
=
"
interpolate
"
input_np
=
np
.
random
.
randint
(
low
=
0
,
high
=
256
,
size
=
self
.
input_shape
).
astype
(
"uint8"
)
output_np
=
bilinear_interp_np
(
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
output_np
=
INTERPOLATE_FUNCS
[
self
.
interp_method
](
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
self
.
inputs
=
{
'X'
:
input_np
}
if
self
.
out_size
is
not
None
:
self
.
inputs
[
'OutSize'
]
=
self
.
out_size
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
}
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
,
'interp_method'
:
self
.
interp_method
}
self
.
outputs
=
{
'Out'
:
output_np
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
place
=
core
.
CPUPlace
(),
atol
=
1
)
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
1
,
3
,
9
,
6
]
self
.
out_h
=
10
self
.
out_w
=
9
class
TestCase1Uint8
(
TestBilinearInterpOpUint8
):
class
TestBilinearInterpCase1Uint8
(
TestInterpolateOpUint8
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
2
,
3
,
128
,
64
]
self
.
out_h
=
120
self
.
out_w
=
50
class
TestBilinearInterpCase2Uint8
(
TestInterpolateOpUint8
):
def
init_test_case
(
self
):
self
.
interp_method
=
'bilinear'
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
5
self
.
out_w
=
13
self
.
out_size
=
np
.
array
([
6
,
15
]).
astype
(
"int32"
)
class
TestNearestNeighborInterpCase1
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
class
TestNearestNeighborInterpCase2
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
class
TestNearestNeighborInterpCase3
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
class
TestNearestNeighborInterpCase4
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
self
.
out_size
=
np
.
array
([
2
,
2
]).
astype
(
"int32"
)
class
TestNearestNeighborInterpCase5
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
self
.
out_size
=
np
.
array
([
11
,
11
]).
astype
(
"int32"
)
class
TestNearestNeighborInterpCase6
(
TestInterpolateOp
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
self
.
out_size
=
np
.
array
([
65
,
129
]).
astype
(
"int32"
)
class
TestNearestNeighborInterpCase1Uint8
(
TestInterpolateOpUint8
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
2
,
3
,
128
,
64
]
self
.
out_h
=
120
self
.
out_w
=
50
class
Test
Case2Uint8
(
TestBilinearInterp
OpUint8
):
class
Test
NearestNeighborInterpCase2Uint8
(
TestInterpolate
OpUint8
):
def
init_test_case
(
self
):
self
.
interp_method
=
'nearest'
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
5
self
.
out_w
=
13
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
34bfae24
...
...
@@ -485,7 +485,7 @@ class TestBook(unittest.TestCase):
self
.
assertIsNotNone
(
output
)
print
(
str
(
program
))
def
test_resize_
bilinear
(
self
):
def
test_resize_
nearest
(
self
):
program
=
Program
()
with
program_guard
(
program
):
x
=
layers
.
data
(
name
=
'x'
,
shape
=
[
3
,
9
,
6
],
dtype
=
"float32"
)
...
...
python/paddle/fluid/tests/unittests/test_nearest_neighbor_interp_op.py
已删除
100644 → 0
浏览文件 @
df4a3544
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
def
nearest_neighbor_interp_np
(
X
,
out_h
,
out_w
,
out_size
=
None
):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if
out_size
is
not
None
:
out_h
=
out_size
[
0
]
out_w
=
out_size
[
1
]
n
,
c
,
in_h
,
in_w
=
X
.
shape
ratio_h
=
ratio_w
=
0.0
if
out_h
>
1
:
ratio_h
=
(
in_h
-
1.0
)
/
(
out_h
-
1.0
)
if
out_w
>
1
:
ratio_w
=
(
in_w
-
1.0
)
/
(
out_w
-
1.0
)
out
=
np
.
zeros
((
n
,
c
,
out_h
,
out_w
))
for
i
in
range
(
out_h
):
in_i
=
int
(
round
(
ratio_h
*
i
))
for
j
in
range
(
out_w
):
in_j
=
int
(
round
(
ratio_w
*
j
))
out
[:,
:,
i
,
j
]
=
X
[:,
:,
in_i
,
in_j
]
return
out
.
astype
(
X
.
dtype
)
class
TestBilinearInterpOp
(
OpTest
):
def
setUp
(
self
):
self
.
out_size
=
None
self
.
init_test_case
()
self
.
op_type
=
"nearest_neighbor_interp"
input_np
=
np
.
random
.
random
(
self
.
input_shape
).
astype
(
"float32"
)
output_np
=
nearest_neighbor_interp_np
(
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
self
.
inputs
=
{
'X'
:
input_np
}
if
self
.
out_size
is
not
None
:
self
.
inputs
[
'OutSize'
]
=
self
.
out_size
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
}
self
.
outputs
=
{
'Out'
:
output_np
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Out'
,
in_place
=
True
)
def
init_test_case
(
self
):
self
.
input_shape
=
[
2
,
3
,
4
,
4
]
self
.
out_h
=
2
self
.
out_w
=
2
self
.
out_size
=
np
.
array
([
3
,
3
]).
astype
(
"int32"
)
class
TestCase1
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
class
TestCase2
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
class
TestCase3
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
class
TestCase4
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
1
self
.
out_w
=
1
self
.
out_size
=
np
.
array
([
2
,
2
]).
astype
(
"int32"
)
class
TestCase5
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
3
,
3
,
9
,
6
]
self
.
out_h
=
12
self
.
out_w
=
12
self
.
out_size
=
np
.
array
([
11
,
11
]).
astype
(
"int32"
)
class
TestCase6
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
1
,
1
,
128
,
64
]
self
.
out_h
=
64
self
.
out_w
=
128
self
.
out_size
=
np
.
array
([
65
,
129
]).
astype
(
"int32"
)
class
TestBilinearInterpOpUint8
(
OpTest
):
def
setUp
(
self
):
self
.
out_size
=
None
self
.
init_test_case
()
self
.
op_type
=
"nearest_neighbor_interp"
input_np
=
np
.
random
.
randint
(
low
=
0
,
high
=
256
,
size
=
self
.
input_shape
).
astype
(
"uint8"
)
output_np
=
nearest_neighbor_interp_np
(
input_np
,
self
.
out_h
,
self
.
out_w
,
self
.
out_size
)
self
.
inputs
=
{
'X'
:
input_np
}
if
self
.
out_size
is
not
None
:
self
.
inputs
[
'OutSize'
]
=
self
.
out_size
self
.
attrs
=
{
'out_h'
:
self
.
out_h
,
'out_w'
:
self
.
out_w
}
self
.
outputs
=
{
'Out'
:
output_np
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
place
=
core
.
CPUPlace
(),
atol
=
1
)
def
init_test_case
(
self
):
self
.
input_shape
=
[
1
,
3
,
9
,
6
]
self
.
out_h
=
10
self
.
out_w
=
9
class
TestCase1Uint8
(
TestBilinearInterpOpUint8
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
2
,
3
,
128
,
64
]
self
.
out_h
=
120
self
.
out_w
=
50
class
TestCase2Uint8
(
TestBilinearInterpOpUint8
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
4
,
1
,
7
,
8
]
self
.
out_h
=
5
self
.
out_w
=
13
self
.
out_size
=
np
.
array
([
6
,
15
]).
astype
(
"int32"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录