提交 bee95fc8 编写于 作者: C chengduoZH

fix code format and some bug

上级 6326c40d
......@@ -26,7 +26,6 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
......@@ -112,11 +111,11 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
};
......@@ -152,6 +151,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
......@@ -170,17 +170,17 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele <
input_data[(d * input_height + h) * input_width + w]) {
index = (d * input_height + h) * input_width + w;
ele =
input_data[(d * input_height + h) * input_width + w];
int input_idx = (d * input_height + h) * input_width + w;
if (ele < input_data[input_idx]) {
index = input_idx;
ele = input_data[input_idx];
}
}
}
......
......@@ -20,14 +20,14 @@ namespace operators {
namespace math {
template <typename T>
__global__ void KernelMaxPool2dWithIdxForward(
__global__ void KernelMaxPool2dWithIdx(
const int nthreads, const T* input_data, T* output_data, T* mask_data,
const int channels, const int input_height, const int input_width,
const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
......@@ -43,51 +43,58 @@ __global__ void KernelMaxPool2dWithIdxForward(
input_data += (batch_idx * channels + c) * input_height * input_width;
T ele = -FLT_MAX;
int index = -1;
int max_index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[h * input_width + w]) {
index = h * input_width + w;
ele = input_data[h * input_width + w];
int input_index = h * input_width + w;
if (ele < input_data[input_index]) {
max_index = input_index;
ele = input_data[input_index];
}
}
}
output_data[index] = ele;
mask_data[index] = index;
mask_data[index] = max_index;
}
}
template <typename T>
__global__ void KernelMaxPool2DWithIdxBackward(
__global__ void KernelMaxPool2DWithIdxGrad(
const int nthreads, T* input_grad, const T* output_grad, const T* mask_data,
const int channels, const int input_height, const int input_width,
const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetC = (index / input_width / input_height) % channels;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int w_offset = index % input_width;
int h_offset = (index / input_width) % input_height;
int c_offset = (index / input_width / input_height) % channels;
int batch_idx = index / input_width / input_height / channels;
int phstart = (offsetH < ksize_height)
? 0
: (offsetH - ksize_height) / stride_height + 1;
int pwstart = (offsetW < ksize_width)
? 0
: (offsetW - ksize_width) / stride_width + 1;
int phend = min(offsetH / stride_height + 1, output_height);
int pwend = min(offsetW / stride_width + 1, output_width);
int ph_start =
(h_offset + padding_height < ksize_height)
? 0
: (h_offset + padding_height - ksize_height) / stride_height + 1;
int pw_start =
(w_offset + padding_width < ksize_width)
? 0
: (w_offset + padding_width - ksize_width) / stride_width + 1;
int ph_end =
min((h_offset + padding_height) / stride_height + 1, output_height);
int pw_end =
min((w_offset + padding_width) / stride_width + 1, output_width);
T gradient = 0;
int input_current_featuremap_idx = h_offset * input_width + w_offset;
int output_idx =
(batch_idx * channels + offsetC) * output_height * output_width;
(batch_idx * channels + c_offset) * output_height * output_width;
mask_data += output_idx;
output_grad += output_idx;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if ((offsetH * input_width + offsetW) ==
mask_data[ph * output_width + pw])
for (int ph = ph_start; ph < ph_end; ++ph) {
for (int pw = pw_start; pw < pw_end; ++pw) {
if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
gradient += output_grad[ph * output_width + pw];
}
}
......@@ -125,7 +132,7 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool2dWithIdxForward<
KernelMaxPool2dWithIdx<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, output_data, mask_data,
......@@ -167,7 +174,7 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool2DWithIdxBackward<
KernelMaxPool2DWithIdxGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_grad_data, output_grad_data,
......@@ -184,7 +191,7 @@ template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>;
template <typename T>
__global__ void KernelMaxPool3DWithIdxForward(
__global__ void KernelMaxPool3DWithIdx(
const int nthreads, const T* input_data, T* output_data, T* mask_data,
const int channels, const int input_depth, const int input_height,
const int input_width, const int output_depth, const int output_height,
......@@ -200,6 +207,7 @@ __global__ void KernelMaxPool3DWithIdxForward(
int c = (index / output_width / output_height / output_depth) % channels;
int batch_idx =
index / output_width / output_height / output_depth / channels;
int dstart = pd * stride_depth - padding_depth;
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
......@@ -209,8 +217,9 @@ __global__ void KernelMaxPool3DWithIdxForward(
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
T ele = -FLT_MAX;
int index = -1;
int max_index = -1;
input_data +=
(batch_idx * channels + c) * input_depth * input_height * input_width;
......@@ -218,19 +227,19 @@ __global__ void KernelMaxPool3DWithIdxForward(
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[(d * input_height + h) * input_width + w]) {
index = (d * input_height + h) * input_width + w;
ele = input_data[(d * input_height + h) * input_width + w];
max_index = (d * input_height + h) * input_width + w;
ele = input_data[max_index];
}
}
}
}
output_data[index] = ele;
mask_data[index] = index;
mask_data[index] = max_index;
}
}
template <typename T>
__global__ void KernelMaxPool3DWithIdxBackward(
__global__ void KernelMaxPool3DWithIdxGrad(
const int nthreads, T* input_grad, const T* output_grad, const T* mask,
const int channels, const int input_depth, const int input_height,
const int input_width, const int output_depth, const int output_height,
......@@ -240,37 +249,45 @@ __global__ void KernelMaxPool3DWithIdxBackward(
const int padding_width) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetD =
(index / input_width / input_height) % input_depth + padding_depth;
int offsetC = (index / input_width / input_height / input_depth) % channels;
int w_offset = index % input_width;
int h_offset = (index / input_width) % input_height;
int d_offset = (index / input_width / input_height) % input_depth;
int c_offset =
(index / input_width / input_height / input_depth) % channels;
int batch_idx = index / input_width / input_height / input_depth / channels;
int pdstart = (offsetD < ksize_depth)
? 0
: (offsetD - ksize_depth) / stride_depth + 1;
int phstart = (offsetH < ksize_height)
? 0
: (offsetH - ksize_height) / stride_height + 1;
int pwstart = (offsetW < ksize_width)
? 0
: (offsetW - ksize_width) / stride_width + 1;
int pdend = min((offsetD) / stride_depth + 1, output_depth);
int phend = min((offsetH) / stride_height + 1, output_height);
int pwend = min((offsetW) / stride_width + 1, output_width);
int pd_start =
(d_offset + padding_depth < ksize_depth)
? 0
: (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
int ph_start =
(h_offset + padding_height < ksize_height)
? 0
: (h_offset + padding_height - ksize_height) / stride_height + 1;
int pw_start =
(w_offset + padding_width < ksize_width)
? 0
: (w_offset + padding_width - ksize_width) / stride_width + 1;
int pd_end =
min((d_offset + padding_depth) / stride_depth + 1, output_depth);
int ph_end =
min((h_offset + padding_height) / stride_height + 1, output_height);
int pw_end =
min((w_offset + padding_width) / stride_width + 1, output_width);
T gradient = 0;
int output_idx = (batch_idx * channels + offsetC) * output_depth *
int input_current_feature_map_idx =
(d_offset * input_height + h_offset) * input_width + w_offset;
int output_idx = (batch_idx * channels + c_offset) * output_depth *
output_height * output_width;
mask += output_idx;
output_grad += output_idx;
for (int pd = pdstart; pd < pdend; ++pd) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (((offsetD * input_height + offsetH) * input_width + offsetW) ==
mask[(pd * output_height + ph) * output_width + pw])
for (int pd = pd_start; pd < pd_end; ++pd) {
for (int ph = ph_start; ph < ph_end; ++ph) {
for (int pw = pw_start; pw < pw_end; ++pw) {
if (mask[(pd * output_height + ph) * output_width + pw] ==
input_current_feature_map_idx)
gradient +=
output_grad[(pd * output_height + ph) * output_width + pw];
}
......@@ -308,7 +325,7 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * output_depth * output_height *
output_width;
......@@ -316,7 +333,7 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool3DWithIdxForward<
KernelMaxPool3DWithIdx<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
......@@ -341,10 +358,10 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
const int input_depth = input_grad.dims()[2];
const int input_height = input_grad.dims()[3];
const int input_width = input_grad.dims()[4];
const int output_channels = input_grad.dims()[1];
const int output_depth = input_grad.dims()[2];
const int output_height = input_grad.dims()[3];
const int output_width = input_grad.dims()[4];
const int output_channels = output_grad.dims()[1];
const int output_depth = output_grad.dims()[2];
const int output_height = output_grad.dims()[3];
const int output_width = output_grad.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
......@@ -365,7 +382,7 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool3DWithIdxBackward<
KernelMaxPool3DWithIdxGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
......
......@@ -23,7 +23,6 @@ namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__
/////////////////////
template <typename Place, typename T>
class MaxPool2dWithIndexFunctor {
......
......@@ -76,8 +76,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")),
"X(Input) of MaxPoolWithIndexOpGrad should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput(framework::GradVarName("X")),
"X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
......@@ -97,28 +97,37 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"number of channels, H and W is the height and width of image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW.");
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCHW.");
"The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W "
"is the height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize", "pooling size(height, width) of pooling operator.");
"ksize",
"Pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator."
"default {1,1}")
.SetDefault({1, 1});
"Strides(height, width) of pooling operator."
"Default {1,1}.")
.SetDefault({1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator."
"default {0,0}")
.SetDefault({0, 0});
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The maxPooling2d with index operation calculates the output and the mask based on
......@@ -140,30 +149,40 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW.");
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW.");
"The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize", "pooling size(depth, height, width) of pooling operator.");
"ksize",
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"strides(depth, height, width) of pooling operator."
"default {1,1,1}")
.SetDefault({1, 1, 1});
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>(
"paddings",
"paddings(depth, height, width) of pooling operator."
"default {0,0,0}")
.SetDefault({0, 0, 0});
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The maxpooling3d with index operation calculates the output and the mask based on
the input and ksize, strides, paddings parameters.
......
......@@ -32,11 +32,10 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
Tensor* out = context.Output<Tensor>("Out");
Tensor* mask = context.Output<Tensor>("Mask");
bool global_pooling = context.Attr<bool>("globalPooling");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) {
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
......@@ -63,7 +62,7 @@ template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Maks");
const Tensor* mask = context.Input<Tensor>("Mask");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
......@@ -71,6 +70,11 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
}
}
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
......
......@@ -3,7 +3,11 @@ import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
def max_pool3D_forward_naive(x,
ksize,
strides,
paddings=[0, 0, 0],
global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
......@@ -25,8 +29,19 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
# mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
return out
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :, :]
index = np.where(arr == np.max(arr))
sub_deep = index[0][0]
sub_row = index[1][0]
sub_col = index[2][0]
index = ((d_start + sub_deep) * H +
(h_start + sub_row)) * W + w_start + sub_col
mask[n, c, k, i, j] = index
return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
......@@ -47,19 +62,25 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
# mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3))
return out
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :]
index = np.where(arr == np.max(arr))
sub_row = index[0][0]
sub_col = index[1][0]
index = (r_start + sub_row) * W + c_start + sub_col
mask[n, c, i, j] = index
return out, mask
class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "maxPool3dWithIndex"
input = np.random.random(self.shape).astype("float32")
output = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
# mask = np.zeros(output.shape)
output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
self.attrs = {
'strides': self.strides,
......@@ -69,7 +90,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
}
self.inputs = {'X': input}
self.outputs = {'Out': output}
self.outputs = {'Out': output, "Mask": mask}
def test_check_output(self):
self.check_output()
......@@ -78,7 +99,8 @@ class TestMaxPoolWithIndex_Op(OpTest):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self):
self.global_pool = 0
self.global_pool = False
self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
......@@ -86,10 +108,9 @@ class TestMaxPoolWithIndex_Op(OpTest):
self.paddings = [1, 1, 1]
""""
class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 1
self.global_pool = True
self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
......@@ -100,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op):
class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 0
self.global_pool = False
self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
......@@ -111,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op):
class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 1
self.global_pool = True
self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
......@@ -122,4 +143,3 @@ class TestCase3(TestMaxPoolWithIndex_Op):
if __name__ == '__main__':
unittest.main()
"""
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册