Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a0c63f11
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0c63f11
编写于
1月 27, 2019
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add align_flag
test=develop
上级
b64cdaf6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
36 addition
and
45 deletion
+36
-45
paddle/fluid/operators/interpolate_op.cc
paddle/fluid/operators/interpolate_op.cc
+1
-1
paddle/fluid/operators/interpolate_op.cu
paddle/fluid/operators/interpolate_op.cu
+16
-20
paddle/fluid/operators/interpolate_op.h
paddle/fluid/operators/interpolate_op.h
+16
-21
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+3
-3
未找到文件。
paddle/fluid/operators/interpolate_op.cc
浏览文件 @
a0c63f11
...
...
@@ -110,7 +110,7 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
to perform linear interpolation first in one direction, and then
again in the other direction.
Align_corners and align_mode are optinal parameters,
T
he calculation method
Align_corners and align_mode are optinal parameters,
t
he calculation method
of interpolation can be selected by them.
Example:
...
...
paddle/fluid/operators/interpolate_op.cu
浏览文件 @
a0c63f11
...
...
@@ -94,6 +94,7 @@ __global__ void KeBilinearInterpFw(
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
bool
align_flag
=
(
align_mode
==
0
&&
!
align_corners
);
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
...
...
@@ -102,25 +103,23 @@ __global__ void KeBilinearInterpFw(
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
(
align_mode
==
0
&&
!
align_corners
)
int
in_img_idy
=
align_flag
?
static_cast
<
int
>
(
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
out_img_idy
);
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h1lambda
=
align_flag
?
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
(
align_mode
==
0
&&
!
align_corners
)
int
in_img_idx
=
align_flag
?
static_cast
<
int
>
(
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_w
*
out_img_idx
);
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w1lambda
=
align_flag
?
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
...
...
@@ -144,6 +143,7 @@ __global__ void KeBilinearInterpBw(
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
bool
align_flag
=
(
align_mode
==
0
&&
!
align_corners
);
for
(;
tid
<
nthreads
;
tid
+=
stride
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
...
...
@@ -152,26 +152,22 @@ __global__ void KeBilinearInterpBw(
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
:
ratio_h
*
out_img_idy
;
int
in_img_idy
=
align_flag
?
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
:
ratio_h
*
out_img_idy
;
in_img_idy
=
(
in_img_idy
>
0
)
?
in_img_idy
:
0
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h1lambda
=
align_flag
?
ratio_h
*
(
out_img_idy
+
0.5
)
-
0.5
-
in_img_idy
:
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
:
ratio_w
*
out_img_idx
;
int
in_img_idx
=
align_flag
?
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
:
ratio_w
*
out_img_idx
;
in_img_idx
=
(
in_img_idx
>
0
)
?
in_img_idx
:
0
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w1lambda
=
align_flag
?
ratio_w
*
(
out_img_idx
+
0.5
)
-
0.5
-
in_img_idx
:
ratio_w
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
...
...
paddle/fluid/operators/interpolate_op.h
浏览文件 @
a0c63f11
...
...
@@ -56,15 +56,14 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
const
bool
align_mode
)
{
auto
input_t
=
EigenTensor
<
T
,
4
>::
From
(
input
);
auto
output_t
=
EigenTensor
<
T
,
4
>::
From
(
*
output
);
bool
align_flag
=
(
align_mode
==
0
&&
!
align_corners
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
y_n
=
(
align_mode
==
0
&&
!
align_corners
)
?
static_cast
<
int
>
(
ratio_h
*
(
k
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
k
);
int
y_n
=
align_flag
?
static_cast
<
int
>
(
ratio_h
*
(
k
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
k
);
y_n
=
(
y_n
>
0
)
?
y_n
:
0
;
int
y_s
=
(
y_n
+
1
)
<
(
in_h
-
1
)
?
(
y_n
+
1
)
:
(
in_h
-
1
);
float
d_n
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_h
*
(
k
+
0.5
)
-
0.5
-
y_n
:
ratio_h
*
k
-
y_n
;
float
d_n
=
align_flag
?
ratio_h
*
(
k
+
0.5
)
-
0.5
-
y_n
:
ratio_h
*
k
-
y_n
;
float
d_s
=
1.
f
-
d_n
;
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
...
...
@@ -73,9 +72,8 @@ static void BilinearInterpolation(const Tensor& input, Tensor* output,
:
static_cast
<
int
>
(
ratio_w
*
l
);
x_w
=
(
x_w
>
0
)
?
x_w
:
0
;
int
x_e
=
(
x_w
+
1
)
<
(
in_w
-
1
)
?
(
x_w
+
1
)
:
(
in_w
-
1
);
float
d_w
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_w
*
(
l
+
0.5
)
-
0.5
-
x_w
:
ratio_w
*
l
-
x_w
;
float
d_w
=
align_flag
?
ratio_w
*
(
l
+
0.5
)
-
0.5
-
x_w
:
ratio_w
*
l
-
x_w
;
float
d_e
=
1.
f
-
d_w
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
...
...
@@ -126,26 +124,23 @@ static void BilinearInterpolationGrad(const Tensor& output_grad,
const
int
align_mode
)
{
auto
input_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
*
input_grad
);
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
bool
align_flag
=
(
align_mode
==
0
&&
!
align_corners
);
for
(
int
k
=
0
;
k
<
out_h
;
k
++
)
{
// loop for images
int
y_n
=
(
align_mode
==
0
&&
!
align_corners
)
?
static_cast
<
int
>
(
ratio_h
*
(
k
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
k
);
int
y_n
=
align_flag
?
static_cast
<
int
>
(
ratio_h
*
(
k
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_h
*
k
);
y_n
=
(
y_n
>
0
)
?
y_n
:
0
;
int
y_s
=
(
y_n
+
1
)
<
(
in_h
-
1
)
?
(
y_n
+
1
)
:
(
in_h
-
1
);
float
d_n
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_h
*
(
k
+
0.5
)
-
0.5
-
y_n
:
ratio_h
*
k
-
y_n
;
float
d_n
=
align_flag
?
ratio_h
*
(
k
+
0.5
)
-
0.5
-
y_n
:
ratio_h
*
k
-
y_n
;
float
d_s
=
1.
f
-
d_n
;
for
(
int
l
=
0
;
l
<
out_w
;
l
++
)
{
int
x_w
=
(
align_mode
==
0
&&
!
align_corners
)
?
static_cast
<
int
>
(
ratio_w
*
(
l
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_w
*
l
);
int
x_w
=
align_flag
?
static_cast
<
int
>
(
ratio_w
*
(
l
+
0.5
)
-
0.5
)
:
static_cast
<
int
>
(
ratio_w
*
l
);
x_w
=
(
x_w
>
0
)
?
x_w
:
0
;
int
x_e
=
(
x_w
+
1
)
<
(
in_w
-
1
)
?
(
x_w
+
1
)
:
(
in_w
-
1
);
float
d_w
=
(
align_mode
==
0
&&
!
align_corners
)
?
ratio_w
*
(
l
+
0.5
)
-
0.5
-
x_w
:
ratio_w
*
l
-
x_w
;
float
d_w
=
align_flag
?
ratio_w
*
(
l
+
0.5
)
-
0.5
-
x_w
:
ratio_w
*
l
-
x_w
;
float
d_e
=
1.
f
-
d_w
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
// loop for batches
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
a0c63f11
...
...
@@ -6552,7 +6552,7 @@ def image_resize(input,
to perform linear interpolation first in one direction, and then
again in the other direction.
Align_corners and align_mode are optinal parameters,
T
he calculation method
Align_corners and align_mode are optinal parameters,
t
he calculation method
of interpolation can be selected by them.
Example:
...
...
@@ -6758,11 +6758,11 @@ def resize_bilinear(input,
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
Align_corners and align_mode are optinal parameters,
T
he calculation
Align_corners and align_mode are optinal parameters,
t
he calculation
method of interpolation can be selected by them.
Align_corners and align_mode are optinal parameters,
T
he calculation method
Align_corners and align_mode are optinal parameters,
t
he calculation method
of interpolation can be selected by them.
Example:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录