Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
469a349a
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
469a349a
编写于
4月 23, 2018
作者:
W
wangyang59
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polishing after qingqing's comments
上级
7436b368
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
99 addition
and
108 deletion
+99
-108
paddle/fluid/operators/bilinear_interp_op.cc
paddle/fluid/operators/bilinear_interp_op.cc
+7
-4
paddle/fluid/operators/bilinear_interp_op.cu
paddle/fluid/operators/bilinear_interp_op.cu
+78
-3
paddle/fluid/operators/bilinear_interp_op.cu.h
paddle/fluid/operators/bilinear_interp_op.cu.h
+0
-101
python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
...n/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
+14
-0
未找到文件。
paddle/fluid/operators/bilinear_interp_op.cc
浏览文件 @
469a349a
...
...
@@ -44,10 +44,13 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
BilinearInterpOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of bilinear interpolation, 4-D with NCHW shape"
);
AddOutput
(
"Out"
,
"The output tensor with the same shape as X"
);
AddAttr
<
int
>
(
"out_h"
,
"output height of bilinear interpolation op."
);
AddAttr
<
int
>
(
"out_w"
,
"output weight of bilinear interpolation op."
);
"(Tensor) The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"
);
AddOutput
(
"Out"
,
"(Tensor) The dimension of output is (N x C x out_h x out_w]"
);
AddAttr
<
int
>
(
"out_h"
,
"(int) output height of bilinear interpolation op."
);
AddAttr
<
int
>
(
"out_w"
,
"(int) output width of bilinear interpolation op."
);
AddComment
(
R"DOC(
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
...
...
paddle/fluid/operators/bilinear_interp_op.cu
浏览文件 @
469a349a
...
...
@@ -9,15 +9,90 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/bilinear_interp_op.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
template
<
typename
T
>
__global__
void
KeBilinearInterpFw
(
const
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratioW
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
ratio_h
*
out_img_idy
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
ratioW
*
out_img_idx
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
ratioW
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
const
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
// bilinear interpolation
out
[
out_id_h
*
output_w
+
out_id_w
]
=
h2lambda
*
(
w2lambda
*
in_pos
[
0
]
+
w1lambda
*
in_pos
[
w_id
])
+
h1lambda
*
(
w2lambda
*
in_pos
[
h_id
*
in_img_w
]
+
w1lambda
*
in_pos
[
h_id
*
in_img_w
+
w_id
]);
}
}
template
<
typename
T
>
__global__
void
KeBilinearInterpBw
(
T
*
in
,
const
size_t
in_img_h
,
const
size_t
in_img_w
,
const
size_t
input_h
,
const
size_t
input_w
,
const
T
*
out
,
const
size_t
out_img_h
,
const
size_t
out_img_w
,
const
size_t
output_h
,
const
size_t
output_w
,
const
size_t
num_channels
,
const
T
ratio_h
,
const
T
ratioW
)
{
int
nthreads
=
output_h
*
output_w
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
out_id_h
=
tid
/
output_w
;
int
out_id_w
=
tid
%
output_w
;
int
in_img_size
=
input_w
/
num_channels
;
int
out_img_size
=
output_w
/
num_channels
;
int
channel_id
=
out_id_w
/
out_img_size
;
int
out_img_idy
=
(
out_id_w
%
out_img_size
)
/
out_img_w
;
int
in_img_idy
=
ratio_h
*
out_img_idy
;
int
h_id
=
(
in_img_idy
<
in_img_h
-
1
)
?
1
:
0
;
T
h1lambda
=
ratio_h
*
out_img_idy
-
in_img_idy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
out_img_idx
=
tid
%
out_img_w
;
int
in_img_idx
=
ratioW
*
out_img_idx
;
int
w_id
=
(
in_img_idx
<
in_img_w
-
1
)
?
1
:
0
;
T
w1lambda
=
ratioW
*
out_img_idx
-
in_img_idx
;
T
w2lambda
=
1.
f
-
w1lambda
;
T
*
in_pos
=
&
in
[
out_id_h
*
input_w
+
channel_id
*
in_img_size
+
in_img_idy
*
in_img_w
+
in_img_idx
];
const
T
*
out_pos
=
&
out
[
out_id_h
*
output_w
+
out_id_w
];
atomicAdd
(
&
in_pos
[
0
],
h2lambda
*
w2lambda
*
out_pos
[
0
]);
atomicAdd
(
&
in_pos
[
w_id
],
h2lambda
*
w1lambda
*
out_pos
[
0
]);
atomicAdd
(
&
in_pos
[
h_id
*
in_img_w
],
h1lambda
*
w2lambda
*
out_pos
[
0
]);
atomicAdd
(
&
in_pos
[
h_id
*
in_img_w
+
w_id
],
h1lambda
*
w1lambda
*
out_pos
[
0
]);
}
}
template
<
typename
T
>
class
BilinearInterpOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
paddle/fluid/operators/bilinear_interp_op.cu.h
已删除
100644 → 0
浏览文件 @
7436b368
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__global__
void
KeBilinearInterpFw
(
const
T
*
in
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
inputH
,
const
size_t
inputW
,
T
*
out
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
outputH
,
const
size_t
outputW
,
const
size_t
numChannels
,
const
T
ratioH
,
const
T
ratioW
)
{
int
nthreads
=
outputH
*
outputW
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
outIdH
=
tid
/
outputW
;
int
outIdW
=
tid
%
outputW
;
int
inImgSize
=
inputW
/
numChannels
;
int
outImgSize
=
outputW
/
numChannels
;
int
channelId
=
outIdW
/
outImgSize
;
int
outImgIdy
=
(
outIdW
%
outImgSize
)
/
outImgW
;
int
inImgIdy
=
ratioH
*
outImgIdy
;
int
hId
=
(
inImgIdy
<
inImgH
-
1
)
?
1
:
0
;
T
h1lambda
=
ratioH
*
outImgIdy
-
inImgIdy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
outImgIdx
=
tid
%
outImgW
;
int
inImgIdx
=
ratioW
*
outImgIdx
;
int
wId
=
(
inImgIdx
<
inImgW
-
1
)
?
1
:
0
;
T
w1lambda
=
ratioW
*
outImgIdx
-
inImgIdx
;
T
w2lambda
=
1.
f
-
w1lambda
;
const
T
*
inPos
=
&
in
[
outIdH
*
inputW
+
channelId
*
inImgSize
+
inImgIdy
*
inImgW
+
inImgIdx
];
// bilinear interpolation
out
[
outIdH
*
outputW
+
outIdW
]
=
h2lambda
*
(
w2lambda
*
inPos
[
0
]
+
w1lambda
*
inPos
[
wId
])
+
h1lambda
*
(
w2lambda
*
inPos
[
hId
*
inImgW
]
+
w1lambda
*
inPos
[
hId
*
inImgW
+
wId
]);
}
}
template
<
typename
T
>
__global__
void
KeBilinearInterpBw
(
T
*
in
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
inputH
,
const
size_t
inputW
,
const
T
*
out
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
outputH
,
const
size_t
outputW
,
const
size_t
numChannels
,
const
T
ratioH
,
const
T
ratioW
)
{
int
nthreads
=
outputH
*
outputW
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
outIdH
=
tid
/
outputW
;
int
outIdW
=
tid
%
outputW
;
int
inImgSize
=
inputW
/
numChannels
;
int
outImgSize
=
outputW
/
numChannels
;
int
channelId
=
outIdW
/
outImgSize
;
int
outImgIdy
=
(
outIdW
%
outImgSize
)
/
outImgW
;
int
inImgIdy
=
ratioH
*
outImgIdy
;
int
hId
=
(
inImgIdy
<
inImgH
-
1
)
?
1
:
0
;
T
h1lambda
=
ratioH
*
outImgIdy
-
inImgIdy
;
T
h2lambda
=
1.
f
-
h1lambda
;
int
outImgIdx
=
tid
%
outImgW
;
int
inImgIdx
=
ratioW
*
outImgIdx
;
int
wId
=
(
inImgIdx
<
inImgW
-
1
)
?
1
:
0
;
T
w1lambda
=
ratioW
*
outImgIdx
-
inImgIdx
;
T
w2lambda
=
1.
f
-
w1lambda
;
T
*
inPos
=
&
in
[
outIdH
*
inputW
+
channelId
*
inImgSize
+
inImgIdy
*
inImgW
+
inImgIdx
];
const
T
*
outPos
=
&
out
[
outIdH
*
outputW
+
outIdW
];
atomicAdd
(
&
inPos
[
0
],
h2lambda
*
w2lambda
*
outPos
[
0
]);
atomicAdd
(
&
inPos
[
wId
],
h2lambda
*
w1lambda
*
outPos
[
0
]);
atomicAdd
(
&
inPos
[
hId
*
inImgW
],
h1lambda
*
w2lambda
*
outPos
[
0
]);
atomicAdd
(
&
inPos
[
hId
*
inImgW
+
wId
],
h1lambda
*
w1lambda
*
outPos
[
0
]);
}
}
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_bilinear_interp_op.py
浏览文件 @
469a349a
...
...
@@ -84,5 +84,19 @@ class TestCase2(TestBilinearInterpOp):
self
.
out_w
=
12
class
TestCase2
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
16
,
3
,
512
,
1024
]
self
.
out_h
=
128
self
.
out_w
=
256
class
TestCase2
(
TestBilinearInterpOp
):
def
init_test_case
(
self
):
self
.
input_shape
=
[
8
,
1
,
256
,
128
]
self
.
out_h
=
1024
self
.
out_w
=
1024
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录