提交 f6e69d74 编写于 作者: C chengduoZH

fix maxpool backward functor

上级 905a462d
...@@ -23,16 +23,17 @@ ...@@ -23,16 +23,17 @@
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <typename PooType> template <typename PoolType, typename PoolGradType>
void testPool2d(paddle::platform::DeviceContext& context, PooType pool_process, void testPool2d(paddle::platform::DeviceContext& context, PoolType pool_process,
paddle::framework::Tensor& input, PoolGradType poolGrad_process, paddle::framework::Tensor& input,
paddle::framework::Tensor& input_grad, paddle::framework::Tensor& input_grad,
paddle::framework::Tensor& output, paddle::framework::Tensor& output,
paddle::framework::Tensor& output_grad, std::vector<int>& ksize, paddle::framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings) {
paddle::operators::math::Pool2dForwardFunctor<paddle::platform::GPUPlace, paddle::operators::math::Pool2dForwardFunctor<paddle::platform::GPUPlace,
PooType, float> PoolType, float>
pool2d_forward; pool2d_forward;
pool2d_forward(context, input, output, ksize, strides, paddings, pool2d_forward(context, input, output, ksize, strides, paddings,
pool_process); pool_process);
...@@ -44,10 +45,11 @@ void testPool2d(paddle::platform::DeviceContext& context, PooType pool_process, ...@@ -44,10 +45,11 @@ void testPool2d(paddle::platform::DeviceContext& context, PooType pool_process,
start = clock(); start = clock();
for (int i = 0; i < times; ++i) { for (int i = 0; i < times; ++i) {
paddle::operators::math::Pool2dBackwardFunctor<paddle::platform::GPUPlace, paddle::operators::math::Pool2dBackwardFunctor<paddle::platform::GPUPlace,
PooType, float> PoolGradType, float>
pool2d_backward; pool2d_backward;
pool2d_backward(context, input, input_grad, output, output_grad, ksize, pool2d_backward(context, input, input_grad, output, output_grad, ksize,
strides, paddings, pool_process); strides, paddings, poolGrad_process);
PADDLE_ENFORCE(cudaStreamSynchronize(0), PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in pool2d_backward CopyFrom"); "cudaStreamSynchronize failed in pool2d_backward CopyFrom");
} }
...@@ -136,10 +138,12 @@ void test2dPool() { ...@@ -136,10 +138,12 @@ void test2dPool() {
paddle::platform::DeviceContext* context = paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
paddle::operators::math::pool::maxPool<float> pool_process; paddle::operators::math::pool::maxPool<float> pool_process;
paddle::operators::math::pool::maxPoolGrad<float> poolGrad_process;
testPool2d<paddle::operators::math::pool::maxPool<float>>( testPool2d<paddle::operators::math::pool::maxPool<float>,
*context, pool_process, input, input_grad, output, output_grad, ksize, paddle::operators::math::pool::maxPoolGrad<float>>(
strides, paddings); *context, pool_process, poolGrad_process, input, input_grad, output,
output_grad, ksize, strides, paddings);
} }
int main() { int main() {
......
...@@ -23,15 +23,15 @@ ...@@ -23,15 +23,15 @@
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <typename PooType> template <typename PoolType, typename PoolGradType>
void testPool3d(paddle::platform::DeviceContext& context, PooType pool_process, void testPool3d(paddle::platform::DeviceContext& context, PoolType pool_process,
paddle::framework::Tensor& input, PoolGradType poolGrad_process, paddle::framework::Tensor& input,
paddle::framework::Tensor& input_grad, paddle::framework::Tensor& input_grad,
paddle::framework::Tensor& output, paddle::framework::Tensor& output,
paddle::framework::Tensor& output_grad, std::vector<int>& ksize, paddle::framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings) {
paddle::operators::math::Pool3dForwardFunctor<paddle::platform::GPUPlace, paddle::operators::math::Pool3dForwardFunctor<paddle::platform::GPUPlace,
PooType, float> PoolType, float>
pool3d_forward; pool3d_forward;
pool3d_forward(context, input, output, ksize, strides, paddings, pool3d_forward(context, input, output, ksize, strides, paddings,
pool_process); pool_process);
...@@ -44,10 +44,10 @@ void testPool3d(paddle::platform::DeviceContext& context, PooType pool_process, ...@@ -44,10 +44,10 @@ void testPool3d(paddle::platform::DeviceContext& context, PooType pool_process,
start = clock(); start = clock();
for (int i = 0; i < times; ++i) { for (int i = 0; i < times; ++i) {
paddle::operators::math::Pool3dBackwardFunctor<paddle::platform::GPUPlace, paddle::operators::math::Pool3dBackwardFunctor<paddle::platform::GPUPlace,
PooType, float> PoolGradType, float>
pool3d_backward; pool3d_backward;
pool3d_backward(context, input, input_grad, output, output_grad, ksize, pool3d_backward(context, input, input_grad, output, output_grad, ksize,
strides, paddings, pool_process); strides, paddings, poolGrad_process);
PADDLE_ENFORCE(cudaStreamSynchronize(0), PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in pool3d_backward CopyFrom"); "cudaStreamSynchronize failed in pool3d_backward CopyFrom");
} }
...@@ -145,9 +145,12 @@ void test3dPool() { ...@@ -145,9 +145,12 @@ void test3dPool() {
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
paddle::operators::math::pool::maxPool<float> pool_process; paddle::operators::math::pool::maxPool<float> pool_process;
testPool3d<paddle::operators::math::pool::maxPool<float>>( paddle::operators::math::pool::maxPoolGrad<float> poolGrad_process;
*context, pool_process, input, input_grad, output, output_grad, ksize,
strides, paddings); testPool3d<paddle::operators::math::pool::maxPool<float>,
paddle::operators::math::pool::maxPoolGrad<float>>(
*context, pool_process, poolGrad_process, input, input_grad, output,
output_grad, ksize, strides, paddings);
} }
int main() { test3dPool(); } int main() { test3dPool(); }
......
...@@ -196,7 +196,7 @@ class MaxPool2dBackwardFunctor<platform::CPUPlace, T> { ...@@ -196,7 +196,7 @@ class MaxPool2dBackwardFunctor<platform::CPUPlace, T> {
}; };
template class MaxPool2dBackwardFunctor<platform::CPUPlace, float>; template class MaxPool2dBackwardFunctor<platform::CPUPlace, float>;
template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>; // template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>;
template class Pool2dForwardFunctor< template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>; platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
...@@ -443,7 +443,7 @@ class MaxPool3dBackwardFunctor<platform::CPUPlace, T> { ...@@ -443,7 +443,7 @@ class MaxPool3dBackwardFunctor<platform::CPUPlace, T> {
}; };
template class MaxPool3dBackwardFunctor<platform::CPUPlace, float>; template class MaxPool3dBackwardFunctor<platform::CPUPlace, float>;
template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>; // template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>;
template class Pool3dForwardFunctor< template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>; platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
......
...@@ -111,12 +111,10 @@ class PoolGradKernel : public framework::OpKernel { ...@@ -111,12 +111,10 @@ class PoolGradKernel : public framework::OpKernel {
switch (ksize.size()) { switch (ksize.size()) {
case 2: { case 2: {
if (pooling_type == "max") { if (pooling_type == "max") {
paddle::operators::math::Pool2dBackwardFunctor< paddle::operators::math::MaxPool2dBackwardFunctor<Place, T>
Place, paddle::operators::math::pool::maxPoolGrad<T>, T>
pool2d_backward; pool2d_backward;
paddle::operators::math::pool::maxPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out, pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out,
*out_grad, ksize, strides, paddings, pool_process); *out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dBackwardFunctor< paddle::operators::math::Pool2dBackwardFunctor<
Place, paddle::operators::math::pool::avgPoolGrad<T>, T> Place, paddle::operators::math::pool::avgPoolGrad<T>, T>
...@@ -128,12 +126,10 @@ class PoolGradKernel : public framework::OpKernel { ...@@ -128,12 +126,10 @@ class PoolGradKernel : public framework::OpKernel {
} break; } break;
case 3: { case 3: {
if (pooling_type == "max") { if (pooling_type == "max") {
paddle::operators::math::Pool3dBackwardFunctor< paddle::operators::math::MaxPool3dBackwardFunctor<Place, T>
Place, paddle::operators::math::pool::maxPoolGrad<T>, T>
pool3d_backward; pool3d_backward;
paddle::operators::math::pool::maxPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out, pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out,
*out_grad, ksize, strides, paddings, pool_process); *out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool3dBackwardFunctor< paddle::operators::math::Pool3dBackwardFunctor<
Place, paddle::operators::math::pool::avgPoolGrad<T>, T> Place, paddle::operators::math::pool::avgPoolGrad<T>, T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册