diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc
index 0e4d9007a63cd342668d5e6fa622deea06f90d9c..da0e8ff3d2776bd538323cb9d0e7f94457d07ed1 100644
--- a/paddle/operators/math/pooling.cc
+++ b/paddle/operators/math/pooling.cc
@@ -26,7 +26,6 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
                   framework::Tensor& mask, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings) {
     const int batch_size = input.dims()[0];
-
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
     const int output_channels = output.dims()[1];
@@ -112,11 +111,11 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
               input_grad_data[input_idx] += output_grad_data[output_idx];
             }
           }
+          // offset
+          input_grad_data += input_stride;
+          output_grad_data += output_stride;
+          mask_data += output_stride;
         }
-        // offset
-        input_grad_data += input_stride;
-        output_grad_data += output_stride;
-        mask_data += output_stride;
       }
     }
 };
@@ -152,6 +151,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
     const int padding_width = paddings[2];
     const int input_stride = input_depth * input_height * input_width;
     const int output_stride = output_depth * output_height * output_width;
+
     const T* input_data = input.data<T>();
     T* output_data = output.mutable_data<T>(context.GetPlace());
     T* mask_data = mask.mutable_data<T>(context.GetPlace());
@@ -170,17 +170,17 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
               int wstart = pw * stride_width - padding_width;
               int wend = std::min(wstart + ksize_width, input_width);
               wstart = std::max(wstart, 0);
+
               int output_idx = (pd * output_height + ph) * output_width + pw;
               T ele = static_cast<T>(-FLT_MAX);
               int index = -1;
               for (int d = dstart; d < dend; ++d) {
                 for (int h = hstart; h < hend; ++h) {
                   for (int w = wstart; w < wend; ++w) {
-                    if (ele <
-                        input_data[(d * input_height + h) * input_width + w]) {
-                      index = (d * input_height + h) * input_width + w;
-                      ele =
-                          input_data[(d * input_height + h) * input_width + w];
+                    int input_idx = (d * input_height + h) * input_width + w;
+                    if (ele < input_data[input_idx]) {
+                      index = input_idx;
+                      ele = input_data[input_idx];
                     }
                   }
                 }
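The pooling.cu diff below, besides renaming the kernels (dropping Forward, using Grad instead of Backward), replaces the single bounds-checked thread index, if (index < nthreads), with a grid-stride loop, so a launch whose grid covers fewer than nthreads threads still visits every element. A standalone Python sketch of that iteration pattern (all names here are ours, not patch code):

    # Grid-stride iteration modeled in plain Python: thread t of total_threads
    # handles indices t, t + total_threads, t + 2 * total_threads, ...
    def indices_for_thread(t, total_threads, nthreads):
        return list(range(t, nthreads, total_threads))

    nthreads, total_threads = 10, 4  # more work items than threads
    covered = sorted(i for t in range(total_threads)
                     for i in indices_for_thread(t, total_threads, nthreads))
    assert covered == list(range(nthreads))  # every index visited exactly once

The old bounds check disappears because the loop condition subsumes it.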
diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu
index f32e6a26d0db0bc2b2d493f5c5f5ba292f76b62c..5321ed21632367a26160f724ccd4f6d9db977f99 100644
--- a/paddle/operators/math/pooling.cu
+++ b/paddle/operators/math/pooling.cu
@@ -20,14 +20,14 @@ namespace operators {
 namespace math {

 template <typename T>
-__global__ void KernelMaxPool2dWithIdxForward(
+__global__ void KernelMaxPool2dWithIdx(
     const int nthreads, const T* input_data, T* output_data, T* mask_data,
     const int channels, const int input_height, const int input_width,
     const int output_height, const int output_width, const int ksize_height,
     const int ksize_width, const int stride_height, const int stride_width,
     const int padding_height, const int padding_width) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+       index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
     int c = (index / output_width / output_height) % channels;
@@ -43,51 +43,58 @@ __global__ void KernelMaxPool2dWithIdxForward(
     input_data += (batch_idx * channels + c) * input_height * input_width;

     T ele = -FLT_MAX;
-    int index = -1;
+    int max_index = -1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        if (ele < input_data[h * input_width + w]) {
-          index = h * input_width + w;
-          ele = input_data[h * input_width + w];
+        int input_index = h * input_width + w;
+        if (ele < input_data[input_index]) {
+          max_index = input_index;
+          ele = input_data[input_index];
         }
       }
     }
     output_data[index] = ele;
-    mask_data[index] = index;
+    mask_data[index] = max_index;
   }
 }

 template <typename T>
-__global__ void KernelMaxPool2DWithIdxBackward(
+__global__ void KernelMaxPool2DWithIdxGrad(
     const int nthreads, T* input_grad, const T* output_grad,
     const T* mask_data, const int channels, const int input_height,
     const int input_width, const int output_height, const int output_width,
     const int ksize_height, const int ksize_width, const int stride_height,
     const int stride_width, const int padding_height,
     const int padding_width) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
-    int offsetW = index % input_width + padding_width;
-    int offsetH = (index / input_width) % input_height + padding_height;
-    int offsetC = (index / input_width / input_height) % channels;
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+       index += blockDim.x * gridDim.x) {
+    int w_offset = index % input_width;
+    int h_offset = (index / input_width) % input_height;
+    int c_offset = (index / input_width / input_height) % channels;
     int batch_idx = index / input_width / input_height / channels;

-    int phstart = (offsetH < ksize_height)
-                      ? 0
-                      : (offsetH - ksize_height) / stride_height + 1;
-    int pwstart = (offsetW < ksize_width)
-                      ? 0
-                      : (offsetW - ksize_width) / stride_width + 1;
-    int phend = min(offsetH / stride_height + 1, output_height);
-    int pwend = min(offsetW / stride_width + 1, output_width);
+    int ph_start =
+        (h_offset + padding_height < ksize_height)
+            ? 0
+            : (h_offset + padding_height - ksize_height) / stride_height + 1;
+    int pw_start =
+        (w_offset + padding_width < ksize_width)
+            ? 0
+            : (w_offset + padding_width - ksize_width) / stride_width + 1;
+    int ph_end =
+        min((h_offset + padding_height) / stride_height + 1, output_height);
+    int pw_end =
+        min((w_offset + padding_width) / stride_width + 1, output_width);
+
     T gradient = 0;
+    int input_current_featuremap_idx = h_offset * input_width + w_offset;
     int output_idx =
-        (batch_idx * channels + offsetC) * output_height * output_width;
+        (batch_idx * channels + c_offset) * output_height * output_width;
+
     mask_data += output_idx;
     output_grad += output_idx;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        if ((offsetH * input_width + offsetW) ==
-            mask_data[ph * output_width + pw])
+    for (int ph = ph_start; ph < ph_end; ++ph) {
+      for (int pw = pw_start; pw < pw_end; ++pw) {
+        if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += output_grad[ph * output_width + pw];
       }
     }
@@ -125,7 +132,7 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool2dWithIdxForward<
+    KernelMaxPool2dWithIdx<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(nthreads, input_data, output_data, mask_data,
@@ -167,7 +174,7 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool2DWithIdxBackward<
+    KernelMaxPool2DWithIdxGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(nthreads, input_grad_data, output_grad_data,
@@ -184,7 +191,7 @@ template class MaxPool2dWithIndexFunctor<platform::GPUPlace, float>;
 template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, float>;

 template <typename T>
-__global__ void KernelMaxPool3DWithIdxForward(
+__global__ void KernelMaxPool3DWithIdx(
     const int nthreads, const T* input_data, T* output_data, T* mask_data,
     const int channels, const int input_depth, const int input_height,
     const int input_width, const int output_depth, const int output_height,
@@ -200,6 +207,7 @@ __global__ void KernelMaxPool3DWithIdxForward(
     int c = (index / output_width / output_height / output_depth) % channels;
     int batch_idx =
         index / output_width / output_height / output_depth / channels;
+
     int dstart = pd * stride_depth - padding_depth;
     int hstart = ph * stride_height - padding_height;
     int wstart = pw * stride_width - padding_width;
@@ -209,8 +217,9 @@ __global__ void KernelMaxPool3DWithIdxForward(
     dstart = max(dstart, 0);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
+
     T ele = -FLT_MAX;
-    int index = -1;
+    int max_index = -1;
     input_data +=
         (batch_idx * channels + c) * input_depth * input_height * input_width;
@@ -218,19 +227,19 @@ __global__ void KernelMaxPool3DWithIdxForward(
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
           if (ele < input_data[(d * input_height + h) * input_width + w]) {
-            index = (d * input_height + h) * input_width + w;
-            ele = input_data[(d * input_height + h) * input_width + w];
+            max_index = (d * input_height + h) * input_width + w;
+            ele = input_data[max_index];
           }
         }
       }
     }
     output_data[index] = ele;
-    mask_data[index] = index;
+    mask_data[index] = max_index;
   }
 }

 template <typename T>
-__global__ void KernelMaxPool3DWithIdxBackward(
+__global__ void KernelMaxPool3DWithIdxGrad(
     const int nthreads, T* input_grad, const T* output_grad, const T* mask,
     const int channels, const int input_depth, const int input_height,
     const int input_width, const int output_depth, const int output_height,
@@ -240,37 +249,45 @@ __global__ void KernelMaxPool3DWithIdxGrad(
     const int padding_width) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
        index += blockDim.x * gridDim.x) {
-    int offsetW = index % input_width + padding_width;
-    int offsetH = (index / input_width) % input_height + padding_height;
-    int offsetD =
-        (index / input_width / input_height) % input_depth + padding_depth;
-    int offsetC = (index / input_width / input_height / input_depth) % channels;
+    int w_offset = index % input_width;
+    int h_offset = (index / input_width) % input_height;
+    int d_offset = (index / input_width / input_height) % input_depth;
+    int c_offset =
+        (index / input_width / input_height / input_depth) % channels;
     int batch_idx =
         index / input_width / input_height / input_depth / channels;

-    int pdstart = (offsetD < ksize_depth)
-                      ? 0
-                      : (offsetD - ksize_depth) / stride_depth + 1;
-    int phstart = (offsetH < ksize_height)
-                      ? 0
-                      : (offsetH - ksize_height) / stride_height + 1;
-    int pwstart = (offsetW < ksize_width)
-                      ? 0
-                      : (offsetW - ksize_width) / stride_width + 1;
-    int pdend = min((offsetD) / stride_depth + 1, output_depth);
-    int phend = min((offsetH) / stride_height + 1, output_height);
-    int pwend = min((offsetW) / stride_width + 1, output_width);
+    int pd_start =
+        (d_offset + padding_depth < ksize_depth)
+            ? 0
+            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
+    int ph_start =
+        (h_offset + padding_height < ksize_height)
+            ? 0
+            : (h_offset + padding_height - ksize_height) / stride_height + 1;
+    int pw_start =
+        (w_offset + padding_width < ksize_width)
+            ? 0
+            : (w_offset + padding_width - ksize_width) / stride_width + 1;
+    int pd_end =
+        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
+    int ph_end =
+        min((h_offset + padding_height) / stride_height + 1, output_height);
+    int pw_end =
+        min((w_offset + padding_width) / stride_width + 1, output_width);

     T gradient = 0;
-    int output_idx = (batch_idx * channels + offsetC) * output_depth *
+    int input_current_feature_map_idx =
+        (d_offset * input_height + h_offset) * input_width + w_offset;
+    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                      output_height * output_width;
     mask += output_idx;
     output_grad += output_idx;
-    for (int pd = pdstart; pd < pdend; ++pd) {
-      for (int ph = phstart; ph < phend; ++ph) {
-        for (int pw = pwstart; pw < pwend; ++pw) {
-          if (((offsetD * input_height + offsetH) * input_width + offsetW) ==
-              mask[(pd * output_height + ph) * output_width + pw])
+    for (int pd = pd_start; pd < pd_end; ++pd) {
+      for (int ph = ph_start; ph < ph_end; ++ph) {
+        for (int pw = pw_start; pw < pw_end; ++pw) {
+          if (mask[(pd * output_height + ph) * output_width + pw] ==
+              input_current_feature_map_idx)
             gradient +=
                 output_grad[(pd * output_height + ph) * output_width + pw];
         }
@@ -308,7 +325,7 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
     const T* input_data = input.data<T>();
     T* output_data = output.mutable_data<T>(context.GetPlace());
-    T* mask_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());

     int nthreads = batch_size * output_channels * output_depth *
                    output_height * output_width;
@@ -316,7 +333,7 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool3DWithIdxForward<
+    KernelMaxPool3DWithIdx<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
@@ -341,10 +358,10 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
     const int input_depth = input_grad.dims()[2];
     const int input_height = input_grad.dims()[3];
     const int input_width = input_grad.dims()[4];
-    const int output_channels = input_grad.dims()[1];
-    const int output_depth = input_grad.dims()[2];
-    const int output_height = input_grad.dims()[3];
-    const int output_width = input_grad.dims()[4];
+    const int output_channels = output_grad.dims()[1];
+    const int output_depth = output_grad.dims()[2];
+    const int output_height = output_grad.dims()[3];
+    const int output_width = output_grad.dims()[4];
     const int ksize_depth = ksize[0];
     const int ksize_height = ksize[1];
     const int ksize_width = ksize[2];
@@ -365,7 +382,7 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool3DWithIdxBackward<
+    KernelMaxPool3DWithIdxGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
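The rewritten grad kernels above stop pre-adding the padding into the offsets and instead fold it into the window-range arithmetic. What that arithmetic must guarantee is that [ph_start, ph_end) enumerates exactly the output positions whose pooling windows cover a given input position. A standalone Python brute-force check of that identity for one axis (function names are ours, not from the patch):

    # Brute-force check of the ph_start/ph_end arithmetic used by the grad
    # kernels above (one axis; the kernels apply it per axis in 2-D and 3-D).
    def covering_windows(h, k, s, pad, out_size):
        # Same integer arithmetic as the kernel; C and Python agree here,
        # since every quantity involved is non-negative.
        ph_start = 0 if h + pad < k else (h + pad - k) // s + 1
        ph_end = min((h + pad) // s + 1, out_size)
        return list(range(ph_start, ph_end))

    def covering_windows_naive(h, k, s, pad, out_size):
        # Window of output ph spans input rows [ph*s - pad, ph*s - pad + k).
        return [ph for ph in range(out_size)
                if ph * s - pad <= h < ph * s - pad + k]

    for k in (2, 3):
        for s in (1, 2, 3):
            for pad in (0, 1):
                for in_size in (5, 7):
                    out_size = (in_size + 2 * pad - k) // s + 1
                    for h in range(in_size):
                        assert covering_windows(h, k, s, pad, out_size) == \
                            covering_windows_naive(h, k, s, pad, out_size)
    print("window-range identity holds")

A too-narrow range would silently drop gradient; a too-wide one only wastes mask comparisons, since the mask test already filters out non-argmax positions.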
+ "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Add checker) AddComment(R"DOC( The maxPooling2d with index operation calculates the output and the mask based on @@ -140,30 +149,40 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { "image."); AddOutput("Out", "The output tensor of pooling operator." - "The format of output tensor is also NCDHW."); + "The format of output tensor is also NCDHW." + "Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and " + "width of image."); AddOutput("Mask", "The Mask tensor of pooling operator." - "The format of output tensor is also NCDHW."); + "The format of output tensor is also NCDHW." + "Where N is batch size, C is the number of channels, D, H and W " + "is the depth, height and width of image." + "The value in it is the index in current feature map"); AddAttr>( - "ksize", "pooling size(depth, height, width) of pooling operator."); + "ksize", + "Pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Add checker) AddAttr( "globalPooling", - "whether to use the globalPooling." - "int constant equal to false or true" - "default false" + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." "If globalPooling = true, ksize is ignored and need not be specified.") .SetDefault(false); AddAttr>( "strides", - "strides(depth, height, width) of pooling operator." - "default {1,1,1}") - .SetDefault({1, 1, 1}); + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Add checker) AddAttr>( "paddings", - "paddings(depth, height, width) of pooling operator." - "default {0,0,0}") - .SetDefault({0, 0, 0}); + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Add checker) + AddComment(R"DOC( The maxpooling3d with index operation calculates the output and the mask based on the input and ksize, strides, paddings parameters. 
diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h
index 91abeed016fe9d84bef321e9e7633ad658eb120c..5fe2f5df931c514e784e2734cce807aacf0f8736 100644
--- a/paddle/operators/pool_with_index_op.h
+++ b/paddle/operators/pool_with_index_op.h
@@ -32,11 +32,10 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
     Tensor* out = context.Output<Tensor>("Out");
     Tensor* mask = context.Output<Tensor>("Mask");

-    bool global_pooling = context.Attr<bool>("globalPooling");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (global_pooling) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
@@ -63,7 +62,7 @@ template <typename Place, typename T>
 class MaxPoolWithIndexGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* mask = context.Input<Tensor>("Maks");
+    const Tensor* mask = context.Input<Tensor>("Mask");
     const Tensor* out_grad =
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
@@ -71,6 +70,11 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel {
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    if (context.Attr<bool>("globalPooling")) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
+      }
+    }

     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py
index 2945c8b7a4eecf6cd04d69d25b0686c8add5f196..ffc345198daa98ce3a575dec8eea09b7b1dad4f8 100644
--- a/python/paddle/v2/framework/tests/test_pool_max_op.py
+++ b/python/paddle/v2/framework/tests/test_pool_max_op.py
@@ -3,7 +3,11 @@ import numpy as np
 from op_test import OpTest


-def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
+def max_pool3D_forward_naive(x,
+                             ksize,
+                             strides,
+                             paddings=[0, 0, 0],
+                             global_pool=0):
     N, C, D, H, W = x.shape
     if global_pool == 1:
@@ -25,8 +29,19 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
                 x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]

                 out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
-                # mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
-    return out
+
+                for n in xrange(N):
+                    for c in xrange(C):
+                        arr = x_masked[n, c, :, :, :]
+                        index = np.where(arr == np.max(arr))
+                        sub_deep = index[0][0]
+                        sub_row = index[1][0]
+                        sub_col = index[2][0]
+                        index = ((d_start + sub_deep) * H +
+                                 (h_start + sub_row)) * W + w_start + sub_col
+                        mask[n, c, k, i, j] = index
+
+    return out, mask


 def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
@@ -47,19 +62,25 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
             x_masked = x[:, :, r_start:r_end, c_start:c_end]

             out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
-            # mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3))
-    return out
+            for n in xrange(N):
+                for c in xrange(C):
+                    arr = x_masked[n, c, :, :]
+                    index = np.where(arr == np.max(arr))
+                    sub_row = index[0][0]
+                    sub_col = index[1][0]
+                    index = (r_start + sub_row) * W + c_start + sub_col
+                    mask[n, c, i, j] = index
+
+    return out, mask


 class TestMaxPoolWithIndex_Op(OpTest):
     def setUp(self):
         self.initTestCase()
-        self.op_type = "maxPool3dWithIndex"
         input = np.random.random(self.shape).astype("float32")
-        output = self.pool_forward_naive(input, self.ksize, self.strides,
-                                         self.paddings, self.global_pool)
-        # mask = np.zeros(output.shape)
+        output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
+                                               self.paddings, self.global_pool)

         self.attrs = {
             'strides': self.strides,
@@ -69,7 +90,7 @@
         }

         self.inputs = {'X': input}
-        self.outputs = {'Out': output}
+        self.outputs = {'Out': output, "Mask": mask}

     def test_check_output(self):
         self.check_output()
@@ -78,7 +99,8 @@
     #     self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)

     def initTestCase(self):
-        self.global_pool = 0
+        self.global_pool = False
+        self.op_type = "maxPool3dWithIndex"
         self.pool_forward_naive = max_pool3D_forward_naive
         self.shape = [2, 3, 7, 7, 7]
         self.ksize = [3, 3, 3]
@@ -86,10 +108,9 @@
         self.strides = [1, 1, 1]
         self.paddings = [1, 1, 1]

-""""
 class TestCase1(TestMaxPoolWithIndex_Op):
     def initTestCase(self):
-        self.global_pool = 1
+        self.global_pool = True
         self.op_type = "maxPool3dWithIndex"
         self.pool_forward_naive = max_pool3D_forward_naive
         self.shape = [2, 3, 5, 5, 5]
@@ -100,7 +121,7 @@

 class TestCase2(TestMaxPoolWithIndex_Op):
     def initTestCase(self):
-        self.global_pool = 0
+        self.global_pool = False
         self.op_type = "maxPool2dWithIndex"
         self.pool_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 7, 7]
@@ -111,7 +132,7 @@

 class TestCase3(TestMaxPoolWithIndex_Op):
     def initTestCase(self):
-        self.global_pool = 1
+        self.global_pool = True
         self.op_type = "maxPool2dWithIndex"
         self.pool_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 5, 5]
@@ -122,4 +143,3 @@

 if __name__ == '__main__':
     unittest.main()
-"""