Commit f6e69d74 authored by chengduoZH

fix maxpool backward functor

Parent 905a462d
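What the change does: the "max" branch of PoolGradKernel now calls the dedicated MaxPool2dBackwardFunctor / MaxPool3dBackwardFunctor instead of the generic Pool2dBackwardFunctor / Pool3dBackwardFunctor instantiated with maxPoolGrad, and it no longer passes an element-wise gradient functor to the backward call; the GPU tests gain a separate PoolGradType template parameter so the forward and backward functors can be parameterized independently. The sketch below is only a minimal, self-contained 1-D illustration of why max pooling warrants its own backward functor (the gradient is routed solely to the input element that produced the window maximum, so the backward pass must compare the input against the forward output). The name maxPool1dBackward and its signature are made up for this example and are not part of the PaddlePaddle API.

```cpp
#include <cstddef>
#include <vector>

// Illustrative 1-D max-pool backward sketch (hypothetical helper, not Paddle code).
// For each pooling window, the output gradient is routed only to the input
// position whose value equals the forward output (the arg-max of that window).
std::vector<float> maxPool1dBackward(const std::vector<float>& input,
                                     const std::vector<float>& output,
                                     const std::vector<float>& output_grad,
                                     int ksize, int stride) {
  std::vector<float> input_grad(input.size(), 0.0f);
  for (std::size_t o = 0; o < output.size(); ++o) {
    std::size_t start = o * stride;
    for (std::size_t i = start; i < start + ksize && i < input.size(); ++i) {
      if (input[i] == output[o]) {        // this element was the window max
        input_grad[i] += output_grad[o];  // gradient flows only to the max
        break;                            // route to the first max found
      }
    }
  }
  return input_grad;
}
```

Average pooling keeps the generic backward path because its rule is purely element-wise (spread the output gradient uniformly over the window), which is exactly what the avgPoolGrad functor expresses.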
@@ -23,16 +23,17 @@
 #ifndef PADDLE_ONLY_CPU
-template <typename PooType>
-void testPool2d(paddle::platform::DeviceContext& context, PooType pool_process,
-                paddle::framework::Tensor& input,
+template <typename PoolType, typename PoolGradType>
+void testPool2d(paddle::platform::DeviceContext& context, PoolType pool_process,
+                PoolGradType poolGrad_process, paddle::framework::Tensor& input,
                 paddle::framework::Tensor& input_grad,
                 paddle::framework::Tensor& output,
                 paddle::framework::Tensor& output_grad, std::vector<int>& ksize,
                 std::vector<int>& strides, std::vector<int>& paddings) {
   paddle::operators::math::Pool2dForwardFunctor<paddle::platform::GPUPlace,
-                                                PooType, float>
+                                                PoolType, float>
       pool2d_forward;
   pool2d_forward(context, input, output, ksize, strides, paddings,
                  pool_process);
@@ -44,10 +45,11 @@ void testPool2d(paddle::platform::DeviceContext& context, PooType pool_process,
   start = clock();
   for (int i = 0; i < times; ++i) {
     paddle::operators::math::Pool2dBackwardFunctor<paddle::platform::GPUPlace,
-                                                   PooType, float>
+                                                   PoolGradType, float>
         pool2d_backward;
     pool2d_backward(context, input, input_grad, output, output_grad, ksize,
-                    strides, paddings, pool_process);
+                    strides, paddings, poolGrad_process);
     PADDLE_ENFORCE(cudaStreamSynchronize(0),
                    "cudaStreamSynchronize failed in pool2d_backward CopyFrom");
   }
@@ -136,10 +138,12 @@ void test2dPool() {
   paddle::platform::DeviceContext* context =
       new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
   paddle::operators::math::pool::maxPool<float> pool_process;
-  testPool2d<paddle::operators::math::pool::maxPool<float>>(
-      *context, pool_process, input, input_grad, output, output_grad, ksize,
-      strides, paddings);
+  paddle::operators::math::pool::maxPoolGrad<float> poolGrad_process;
+  testPool2d<paddle::operators::math::pool::maxPool<float>,
+             paddle::operators::math::pool::maxPoolGrad<float>>(
+      *context, pool_process, poolGrad_process, input, input_grad, output,
+      output_grad, ksize, strides, paddings);
 }
 int main() {
......
@@ -23,15 +23,15 @@
 #ifndef PADDLE_ONLY_CPU
-template <typename PooType>
-void testPool3d(paddle::platform::DeviceContext& context, PooType pool_process,
-                paddle::framework::Tensor& input,
+template <typename PoolType, typename PoolGradType>
+void testPool3d(paddle::platform::DeviceContext& context, PoolType pool_process,
+                PoolGradType poolGrad_process, paddle::framework::Tensor& input,
                 paddle::framework::Tensor& input_grad,
                 paddle::framework::Tensor& output,
                 paddle::framework::Tensor& output_grad, std::vector<int>& ksize,
                 std::vector<int>& strides, std::vector<int>& paddings) {
   paddle::operators::math::Pool3dForwardFunctor<paddle::platform::GPUPlace,
-                                                PooType, float>
+                                                PoolType, float>
       pool3d_forward;
   pool3d_forward(context, input, output, ksize, strides, paddings,
                  pool_process);
@@ -44,10 +44,10 @@ void testPool3d(paddle::platform::DeviceContext& context, PooType pool_process,
   start = clock();
   for (int i = 0; i < times; ++i) {
     paddle::operators::math::Pool3dBackwardFunctor<paddle::platform::GPUPlace,
-                                                   PooType, float>
+                                                   PoolGradType, float>
         pool3d_backward;
     pool3d_backward(context, input, input_grad, output, output_grad, ksize,
-                    strides, paddings, pool_process);
+                    strides, paddings, poolGrad_process);
     PADDLE_ENFORCE(cudaStreamSynchronize(0),
                    "cudaStreamSynchronize failed in pool3d_backward CopyFrom");
   }
@@ -145,9 +145,12 @@ void test3dPool() {
       new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
   paddle::operators::math::pool::maxPool<float> pool_process;
-  testPool3d<paddle::operators::math::pool::maxPool<float>>(
-      *context, pool_process, input, input_grad, output, output_grad, ksize,
-      strides, paddings);
+  paddle::operators::math::pool::maxPoolGrad<float> poolGrad_process;
+  testPool3d<paddle::operators::math::pool::maxPool<float>,
+             paddle::operators::math::pool::maxPoolGrad<float>>(
+      *context, pool_process, poolGrad_process, input, input_grad, output,
+      output_grad, ksize, strides, paddings);
 }
 int main() { test3dPool(); }
......
@@ -196,7 +196,7 @@ class MaxPool2dBackwardFunctor<platform::CPUPlace, T> {
 };
 template class MaxPool2dBackwardFunctor<platform::CPUPlace, float>;
-template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>;
+// template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>;
 template class Pool2dForwardFunctor<
     platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
@@ -443,7 +443,7 @@ class MaxPool3dBackwardFunctor<platform::CPUPlace, T> {
 };
 template class MaxPool3dBackwardFunctor<platform::CPUPlace, float>;
-template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>;
+// template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>;
 template class Pool3dForwardFunctor<
     platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
......
@@ -111,12 +111,10 @@ class PoolGradKernel : public framework::OpKernel {
     switch (ksize.size()) {
       case 2: {
         if (pooling_type == "max") {
-          paddle::operators::math::Pool2dBackwardFunctor<
-              Place, paddle::operators::math::pool::maxPoolGrad<T>, T>
+          paddle::operators::math::MaxPool2dBackwardFunctor<Place, T>
               pool2d_backward;
-          paddle::operators::math::pool::maxPoolGrad<T> pool_process;
           pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out,
-                          *out_grad, ksize, strides, paddings, pool_process);
+                          *out_grad, ksize, strides, paddings);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool2dBackwardFunctor<
               Place, paddle::operators::math::pool::avgPoolGrad<T>, T>
@@ -128,12 +126,10 @@ class PoolGradKernel : public framework::OpKernel {
       } break;
       case 3: {
         if (pooling_type == "max") {
-          paddle::operators::math::Pool3dBackwardFunctor<
-              Place, paddle::operators::math::pool::maxPoolGrad<T>, T>
+          paddle::operators::math::MaxPool3dBackwardFunctor<Place, T>
              pool3d_backward;
-          paddle::operators::math::pool::maxPoolGrad<T> pool_process;
           pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out,
-                          *out_grad, ksize, strides, paddings, pool_process);
+                          *out_grad, ksize, strides, paddings);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool3dBackwardFunctor<
              Place, paddle::operators::math::pool::avgPoolGrad<T>, T>
......