PaddleDetection commit 905a462d
Authored Sep 25, 2017 by chengduoZH

Merge branch 'fix_maxpool_backward_functor_temp' into Add_pool_op_temp

Parents: 8c478b36, 9b1431b8
Showing 6 changed files with 664 additions and 0 deletions.

paddle/operators/math/CMakeLists.txt           +2    -0
paddle/operators/math/pool_test_maxPool2d.cc   +150  -0
paddle/operators/math/pool_test_maxPool3d.cc   +154  -0
paddle/operators/math/pooling.cc               +142  -0
paddle/operators/math/pooling.cu               +196  -0
paddle/operators/math/pooling.h                +20   -0
paddle/operators/math/CMakeLists.txt

@@ -7,3 +7,5 @@ endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(pool_test_maxPool2d_test SRCS pool_test_maxPool2d.cc DEPS math_function tensor)
cc_test(pool_test_maxPool3d_test SRCS pool_test_maxPool3d.cc DEPS math_function tensor)
paddle/operators/math/pool_test_maxPool2d.cc
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/operators/math/pooling.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <iostream>

#ifndef PADDLE_ONLY_CPU

template <typename PooType>
void testPool2d(paddle::platform::DeviceContext& context, PooType pool_process,
                paddle::framework::Tensor& input,
                paddle::framework::Tensor& input_grad,
                paddle::framework::Tensor& output,
                paddle::framework::Tensor& output_grad,
                std::vector<int>& ksize, std::vector<int>& strides,
                std::vector<int>& paddings) {
  paddle::operators::math::Pool2dForwardFunctor<paddle::platform::GPUPlace,
                                                PooType, float>
      pool2d_forward;
  pool2d_forward(context, input, output, ksize, strides, paddings,
                 pool_process);

  int times = 50;
  clock_t start, finish;
  double totaltime;

  // Pool2dBackwardFunctor
  start = clock();
  for (int i = 0; i < times; ++i) {
    paddle::operators::math::Pool2dBackwardFunctor<paddle::platform::GPUPlace,
                                                   PooType, float>
        pool2d_backward;
    pool2d_backward(context, input, input_grad, output, output_grad, ksize,
                    strides, paddings, pool_process);
    PADDLE_ENFORCE(cudaStreamSynchronize(0),
                   "cudaStreamSynchronize failed in pool2d_backward CopyFrom");
  }
  finish = clock();
  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
  totaltime /= times;
  std::cout << "\nPool2dBackwardFunctor: " << totaltime << "s" << std::endl;

  // MaxPool2dBackwardFunctor
  start = clock();
  for (int j = 0; j < times; ++j) {
    paddle::operators::math::MaxPool2dBackwardFunctor<
        paddle::platform::GPUPlace, float>
        maxpool2d_backward;
    maxpool2d_backward(context, input, input_grad, output, output_grad, ksize,
                       strides, paddings);
    PADDLE_ENFORCE(
        cudaStreamSynchronize(0),
        "cudaStreamSynchronize failed in maxpool2d_backward CopyFrom");
  }
  finish = clock();
  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
  totaltime /= times;
  std::cout << "\nMaxPool2dBackwardFunctor: " << totaltime << "s" << std::endl;
}

void test2dPool() {
  using paddle::platform::DeviceContext;
  using paddle::platform::CUDADeviceContext;
  using paddle::platform::GPUPlace;

  paddle::framework::Tensor input_tmp;
  paddle::framework::Tensor output_tmp;
  paddle::framework::Tensor input;
  paddle::framework::Tensor input_grad;
  paddle::framework::Tensor output;
  paddle::framework::Tensor output_grad;

  int batch = 32;
  int channel = 32;
  int input_height = 128;
  int input_width = 128;
  int in_len = batch * channel * input_height * input_width;
  std::vector<int> ksize({3, 3});
  std::vector<int> strides({1, 1});
  std::vector<int> paddings({0, 0});

  int output_height =
      (input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
  int output_width =
      (input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
  int output_len = batch * channel * output_height * output_width;

  input_tmp.mutable_data<float>({batch, channel, input_height, input_width},
                                paddle::platform::CPUPlace());
  output_tmp.mutable_data<float>({batch, channel, output_height, output_width},
                                 paddle::platform::CPUPlace());
  float* arr = new float[in_len];
  auto* place = new paddle::platform::GPUPlace();

  // input
  float* input_ptr = input_tmp.data<float>();
  for (int i = 0; i < in_len; ++i) arr[i] = i;  // rand() / double(RAND_MAX/2);
  memcpy(input_ptr, arr, in_len * sizeof(float));
  input.CopyFrom<float>(input_tmp, *place);

  // input_grad
  input_ptr = input_tmp.data<float>();
  for (int i = 0; i < in_len; ++i) arr[i] = 0;
  memcpy(input_ptr, arr, in_len * sizeof(float));
  input_grad.CopyFrom<float>(input_tmp, *place);

  // output
  input_ptr = output_tmp.data<float>();
  for (int i = 0; i < output_len; ++i) arr[i] = 0;  // rand() / double(RAND_MAX/2);
  memcpy(input_ptr, arr, output_len * sizeof(float));
  output.CopyFrom<float>(output_tmp, *place);

  // output_grad
  input_ptr = output_tmp.data<float>();
  for (int i = 0; i < output_len; ++i) arr[i] = 1;  // rand() / double(RAND_MAX/2);
  memcpy(input_ptr, arr, output_len * sizeof(float));
  output_grad.CopyFrom<float>(output_tmp, *place);

  paddle::platform::DeviceContext* context =
      new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
  paddle::operators::math::pool::maxPool<float> pool_process;

  testPool2d<paddle::operators::math::pool::maxPool<float>>(
      *context, pool_process, input, input_grad, output, output_grad, ksize,
      strides, paddings);
}

int main() {
  // testPool3d<paddle::platform::CPUPlace>();
  test2dPool();
  // testPool3d<paddle::platform::GPUPlace>();
}
#endif
\ No newline at end of file
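For reference, the output spatial sizes computed in these tests follow the standard pooling shape formula, output = (input - ksize + 2 * padding) / stride + 1 (integer division). A quick worked check against the 2-d configuration above; this snippet mirrors the test's locals and is illustration only, not part of the commit:

  int input_height = 128, ksize_h = 3, pad_h = 0, stride_h = 1;
  int output_height = (input_height - ksize_h + 2 * pad_h) / stride_h + 1;
  // (128 - 3 + 0) / 1 + 1 = 126, so each 128x128 feature map pools to 126x126.

The 3-d test below exercises the same formula with padding and stride: depth (4 - 3 + 2) / 2 + 1 = 2, height and width (128 - 3 + 2) / 2 + 1 = 64.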
paddle/operators/math/pool_test_maxPool3d.cc
0 → 100644
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/operators/math/pooling.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <iostream>

#ifndef PADDLE_ONLY_CPU

template <typename PooType>
void testPool3d(paddle::platform::DeviceContext& context, PooType pool_process,
                paddle::framework::Tensor& input,
                paddle::framework::Tensor& input_grad,
                paddle::framework::Tensor& output,
                paddle::framework::Tensor& output_grad,
                std::vector<int>& ksize, std::vector<int>& strides,
                std::vector<int>& paddings) {
  paddle::operators::math::Pool3dForwardFunctor<paddle::platform::GPUPlace,
                                                PooType, float>
      pool3d_forward;
  pool3d_forward(context, input, output, ksize, strides, paddings,
                 pool_process);

  int times = 50;
  clock_t start, finish;
  double totaltime;

  // Pool3dBackwardFunctor
  start = clock();
  for (int i = 0; i < times; ++i) {
    paddle::operators::math::Pool3dBackwardFunctor<paddle::platform::GPUPlace,
                                                   PooType, float>
        pool3d_backward;
    pool3d_backward(context, input, input_grad, output, output_grad, ksize,
                    strides, paddings, pool_process);
    PADDLE_ENFORCE(cudaStreamSynchronize(0),
                   "cudaStreamSynchronize failed in pool3d_backward CopyFrom");
  }
  finish = clock();
  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
  totaltime /= times;
  std::cout << "\nPool3dBackwardFunctor: " << totaltime << "s" << std::endl;

  // MaxPool3dBackwardFunctor
  start = clock();
  for (int j = 0; j < times; ++j) {
    paddle::operators::math::MaxPool3dBackwardFunctor<
        paddle::platform::GPUPlace, float>
        maxpool3d_backward;
    maxpool3d_backward(context, input, input_grad, output, output_grad, ksize,
                       strides, paddings);
    PADDLE_ENFORCE(
        cudaStreamSynchronize(0),
        "cudaStreamSynchronize failed in maxpool3d_backward CopyFrom");
  }
  finish = clock();
  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
  totaltime /= times;
  std::cout << "\nMaxPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
}

void test3dPool() {
  using paddle::platform::DeviceContext;
  using paddle::platform::CUDADeviceContext;
  using paddle::platform::GPUPlace;

  paddle::framework::Tensor input_tmp;
  paddle::framework::Tensor output_tmp;
  paddle::framework::Tensor input;
  paddle::framework::Tensor input_grad;
  paddle::framework::Tensor output;
  paddle::framework::Tensor output_grad;

  int batch = 32;
  int channel = 4;
  int input_depth = 4;
  int input_height = 128;
  int input_width = 128;
  int in_len = batch * channel * input_depth * input_height * input_width;
  std::vector<int> ksize({3, 3, 3});
  std::vector<int> strides({2, 2, 2});
  std::vector<int> paddings({1, 1, 1});

  int output_depth =
      (input_depth - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
  int output_height =
      (input_height - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
  int output_width =
      (input_width - ksize[2] + 2 * paddings[2]) / strides[2] + 1;
  int output_len =
      batch * channel * output_depth * output_height * output_width;

  input_tmp.mutable_data<float>(
      {batch, channel, input_depth, input_height, input_width},
      paddle::platform::CPUPlace());
  output_tmp.mutable_data<float>(
      {batch, channel, output_depth, output_height, output_width},
      paddle::platform::CPUPlace());
  float* arr = new float[in_len];
  auto* place = new paddle::platform::GPUPlace();

  // input
  float* input_ptr = input_tmp.data<float>();
  for (int i = 0; i < in_len; ++i) arr[i] = i;  // rand() / double(RAND_MAX/2);
  memcpy(input_ptr, arr, in_len * sizeof(float));
  input.CopyFrom<float>(input_tmp, *place);

  // input_grad
  input_ptr = input_tmp.data<float>();
  for (int i = 0; i < in_len; ++i) arr[i] = 0;
  memcpy(input_ptr, arr, in_len * sizeof(float));
  input_grad.CopyFrom<float>(input_tmp, *place);

  // output
  input_ptr = output_tmp.data<float>();
  for (int i = 0; i < output_len; ++i) arr[i] = 0;  // rand() / double(RAND_MAX/2);
  memcpy(input_ptr, arr, output_len * sizeof(float));
  output.CopyFrom<float>(output_tmp, *place);

  // output_grad
  input_ptr = output_tmp.data<float>();
  for (int i = 0; i < output_len; ++i) arr[i] = 1;  // rand() / double(RAND_MAX/2);
  memcpy(input_ptr, arr, output_len * sizeof(float));
  output_grad.CopyFrom<float>(output_tmp, *place);

  paddle::platform::DeviceContext* context =
      new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
  paddle::operators::math::pool::maxPool<float> pool_process;

  testPool3d<paddle::operators::math::pool::maxPool<float>>(
      *context, pool_process, input, input_grad, output, output_grad, ksize,
      strides, paddings);
}

int main() { test3dPool(); }
#endif
\ No newline at end of file
paddle/operators/math/pooling.cc
...
...
@@ -134,6 +134,70 @@ class Pool2dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
  }
};

template <class T>
class MaxPool2dBackwardFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int input_stride = input_height * input_width;
    const int output_stride = output_height * output_width;

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; i++) {
      for (int c = 0; c < output_channels; ++c) {
        for (int ph = 0; ph < output_height; ++ph) {
          int hstart = ph * stride_height - padding_height;
          int hend = std::min(hstart + ksize_height, input_height);
          hstart = std::max(hstart, 0);
          for (int pw = 0; pw < output_width; ++pw) {
            int wstart = pw * stride_width - padding_width;
            int wend = std::min(wstart + ksize_width, input_width);
            wstart = std::max(wstart, 0);

            bool stop = false;
            // Route the gradient to the first input element in the window
            // whose value equals the pooled maximum.
            for (int h = hstart; h < hend && !stop; ++h) {
              for (int w = wstart; w < wend && !stop; ++w) {
                int input_idx = h * input_width + w;
                int output_idx = ph * output_width + pw;
                if (input_data[input_idx] == output_data[output_idx]) {
                  input_grad_data[input_idx] += output_grad_data[output_idx];
                  stop = true;
                }
              }
            }
          }
        }
        input_data += input_stride;
        output_data += output_stride;
        input_grad_data += input_stride;
        output_grad_data += output_stride;
      }
    }
  }
};

template class MaxPool2dBackwardFunctor<platform::CPUPlace, float>;
template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>;

template class Pool2dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dForwardFunctor<
...
...
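The functor added above implements the backward rule for max pooling on the CPU: each output cell scans its pooling window and routes its incoming gradient to the first input element whose value equals the pooled maximum (the stop flag enforces first-match-wins when a window holds duplicated maxima). A minimal standalone sketch of the same idea for a 1-d signal, with hypothetical names chosen here for illustration only:

  #include <algorithm>
  #include <vector>

  // Max-pool backward over a 1-d signal: route each output gradient to the
  // first element of its window that matches the pooled maximum.
  void MaxPool1dBackward(const std::vector<float>& in,
                         const std::vector<float>& out,
                         const std::vector<float>& out_grad,
                         std::vector<float>& in_grad, int ksize, int stride) {
    for (int p = 0; p < static_cast<int>(out.size()); ++p) {
      int start = p * stride;
      int end = std::min(start + ksize, static_cast<int>(in.size()));
      for (int i = start; i < end; ++i) {
        if (in[i] == out[p]) {   // first element matching the max wins
          in_grad[i] += out_grad[p];
          break;                 // mirrors the `stop` flag in the functor
        }
      }
    }
  }

Recomputing the argmax by equality comparison avoids storing a mask during the forward pass, at the cost of rescanning each window.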
@@ -303,6 +367,84 @@ class Pool3dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
  }
};

template <class T>
class MaxPool3dBackwardFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings) {
    const int batch_size = input.dims()[0];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];
    const int input_stride = input_depth * input_height * input_width;
    const int output_stride = output_depth * output_height * output_width;

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; i++) {
      for (int c = 0; c < output_channels; ++c) {
        for (int pd = 0; pd < output_depth; ++pd) {
          int dstart = pd * stride_depth - padding_depth;
          int dend = std::min(dstart + ksize_depth, input_depth);
          dstart = std::max(dstart, 0);
          for (int ph = 0; ph < output_height; ++ph) {
            int hstart = ph * stride_height - padding_height;
            int hend = std::min(hstart + ksize_height, input_height);
            hstart = std::max(hstart, 0);
            for (int pw = 0; pw < output_width; ++pw) {
              int wstart = pw * stride_width - padding_width;
              int wend = std::min(wstart + ksize_width, input_width);
              wstart = std::max(wstart, 0);

              bool stop = false;
              for (int d = dstart; d < dend && !stop; ++d) {
                for (int h = hstart; h < hend && !stop; ++h) {
                  for (int w = wstart; w < wend && !stop; ++w) {
                    int input_idx = (d * input_height + h) * input_width + w;
                    int output_idx =
                        (pd * output_height + ph) * output_width + pw;
                    if (input_data[input_idx] == output_data[output_idx]) {
                      input_grad_data[input_idx] +=
                          output_grad_data[output_idx];
                      stop = true;
                    }
                  }
                }
              }
            }
          }
        }
        input_data += input_stride;
        output_data += output_stride;
        input_grad_data += input_stride;
        output_grad_data += output_stride;
      }
    }
  }
};

template class MaxPool3dBackwardFunctor<platform::CPUPlace, float>;
template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>;

template class Pool3dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dForwardFunctor<
...
...
paddle/operators/math/pooling.cu
...
...
@@ -102,6 +102,51 @@ __global__ void KernelPool2dBackward(
  }
}

template <typename T>
__global__ void KernelMaxPool2dBackward(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Decompose the flat thread index into (batch_idx, c, ph, pw).
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    input_data += (batch_idx * channels + c) * input_height * input_width;
    input_grad += (batch_idx * channels + c) * input_height * input_width;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_data[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // atomic add
      atomicAdd(input_grad + maxIndex, output_grad[index]);
    }
  }
}

template <typename PoolProcess, typename T>
class Pool2dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
...
...
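In the kernel above, each thread handles one output element and recovers its (batch_idx, c, ph, pw) coordinates by repeated division of the flat thread index; this is plain mixed-radix arithmetic. Illustrated with the 2-d test's dimensions (126x126 output, 32 channels; example values only, not required by the kernel):

  // pw        = index % 126
  // ph        = (index / 126) % 126
  // c         = (index / (126 * 126)) % 32
  // batch_idx = index / (126 * 126 * 32)

The final atomicAdd is what makes the write safe: when the stride is smaller than the kernel size, windows overlap, so several threads may route their gradient to the same input element, and a plain += from concurrent threads would lose updates.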
@@ -187,6 +232,52 @@ class Pool2dBackwardFunctor<platform::GPUPlace, PoolProcess, T> {
  }
};

template <typename T>
class MaxPool2dBackwardFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    // One thread per output element.
    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2dBackward<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, output_data,
                              output_grad_data, input_grad_data,
                              input_channels, input_height, input_width,
                              output_height, output_width, ksize_height,
                              ksize_width, stride_height, stride_width,
                              padding_height, padding_width);
  }
};

template class MaxPool2dBackwardFunctor<platform::GPUPlace, float>;
// template class MaxPool2dBackwardFunctor<platform::GPUPlace, double>;

template class Pool2dForwardFunctor<
    platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dForwardFunctor<
...
...
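The launch configuration in the functor above is the usual one-thread-per-output scheme: nthreads counts output elements and blocks rounds up to whole 1024-thread blocks. Plugging in the 2-d benchmark's shapes as a worked example:

  int nthreads = 32 * 32 * 126 * 126;         // 16,257,024 output elements
  int blocks = (nthreads + 1024 - 1) / 1024;  // ceiling division -> 15,876 blocks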
@@ -315,6 +406,58 @@ __global__ void KernelPool3DBackward(
  }
}

template <typename T>
__global__ void KernelMaxPool3DBackward(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height,
    const int padding_width) {
  // Grid-stride loop over all output elements.
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    input_grad +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_data[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // atomic add
      atomicAdd(input_grad + maxIdx, output_grad[index]);
    }
  }
}

template <typename PoolProcess, class T>
class Pool3dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
...
...
@@ -415,6 +558,59 @@ class Pool3dBackwardFunctor<platform::GPUPlace, PoolProcess, T> {
  }
};

template <class T>
class MaxPool3dBackwardFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth *
                   output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DBackward<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_depth, input_height, input_width, output_depth,
        output_height, output_width, ksize_depth, ksize_height, ksize_width,
        stride_depth, stride_height, stride_width, padding_depth,
        padding_height, padding_width);
  }
};

template class MaxPool3dBackwardFunctor<platform::GPUPlace, float>;
// template class MaxPool3dBackwardFunctor<platform::GPUPlace, double>;

template class Pool3dForwardFunctor<
    platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dForwardFunctor<
...
...
paddle/operators/math/pooling.h
...
...
@@ -81,6 +81,16 @@ class Pool2dBackwardFunctor {
                  PoolProcess pool_process);
};

template <typename Place, class T>
class MaxPool2dBackwardFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings);
};

template <typename Place, typename PoolProcess, typename T>
class Pool3dForwardFunctor {
 public:
...
...
@@ -101,6 +111,16 @@ class Pool3dBackwardFunctor {
                  PoolProcess pool_process);
};

template <typename Place, class T>
class MaxPool3dBackwardFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings);
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
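Call shape: the max-pool backward functors declared here take the same arguments as the existing Pool2dBackwardFunctor minus the PoolProcess object, since the argmax is recovered by comparing input against output. A hedged usage sketch, assuming tensors and a device context prepared as in the test files above:

  // Hypothetical call site; ctx, in, in_grad, out, out_grad, ksize,
  // strides, paddings are set up as in pool_test_maxPool2d.cc.
  paddle::operators::math::MaxPool2dBackwardFunctor<paddle::platform::GPUPlace,
                                                    float>
      maxpool2d_backward;
  maxpool2d_backward(ctx, in, in_grad, out, out_grad, ksize, strides, paddings);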