未验证 提交 8efd0876 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #5187 from chengduoZH/fix_pool_op

fix pool op
...@@ -43,6 +43,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> { ...@@ -43,6 +43,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("globalPooling")) { if (ctx.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(input->dims()[i + 2]); ksize[i] = static_cast<int>(input->dims()[i + 2]);
} }
} }
...@@ -97,8 +98,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -97,8 +98,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("globalPooling")) { if (ctx.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(input->dims()[i + 2]); ksize[i] = static_cast<int>(input->dims()[i + 2]);
}
} }
const T *input_data = input->data<T>(); const T *input_data = input->data<T>();
......
...@@ -39,8 +39,10 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -39,8 +39,10 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
if (ctx->Attrs().Get<bool>("globalPooling")) { if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2); ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_dims[i + 2]); ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
} }
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
...@@ -84,15 +86,16 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, ...@@ -84,15 +86,16 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.") "and \"avg\" for average-pooling.")
.InEnum({"max", "avg"}); .InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector ), the pooling window size(height, width) "
"(vector ), the pooling window size(height, width) of pooling operator." "of pooling operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings will "
"specified."); // TODO(Chengduo): Add checker. (Currently, "be ignored."); // TODO(Chengduo): Add checker.
// (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>("globalPooling",
"(bool default: false), whether to use the global pooling." "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize is ignored.") "If globalPooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -101,7 +104,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, ...@@ -101,7 +104,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"(vector defalut:{0,0}), paddings(height, width) of pooling operator.") "(vector defalut:{0,0}), paddings(height, width) of pooling operator."
"If globalPooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
...@@ -145,25 +149,28 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, ...@@ -145,25 +149,28 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.") "and \"avg\" for average-pooling.")
.InEnum({"max", "avg"}); .InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector ), the pooling window size(depth, height, "
"(vector ), the pooling window size(depth, height, width) of pooling " "width) of pooling "
"operator." "operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings wille "
"specified."); // TODO(Chengduo): Add checker. (Currently, "be ignored."); // TODO(Chengduo): Add checker.
// TypedAttrChecker don't support vector type.) // (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>("globalPooling",
"(bool default: false), whether to use the global pooling." "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize is ignored.") "If globalPooling = true, ksize and paddings wille be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1,1,1}), strides(depth, height, " "(vector, default:{1,1,1}), strides(depth, height, "
"width) of pooling operator.") "width) of pooling operator.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>(
"(vector defalut:{0,0,0}), paddings(depth, height, " "paddings",
"width) of pooling operator.") "(vector defalut:{0,0,0}), paddings(depth, height, "
"width) of pooling operator."
"If globalPooling = true, ksize and paddings wille be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
......
...@@ -63,6 +63,7 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -63,6 +63,7 @@ class PoolKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
} }
...@@ -103,6 +104,7 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -103,6 +104,7 @@ class PoolKernel : public framework::OpKernel<T> {
paddings, pool_process); paddings, pool_process);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
}; };
...@@ -123,8 +125,10 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -123,8 +125,10 @@ class PoolGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
} }
if (in_x_grad) { if (in_x_grad) {
...@@ -164,6 +168,7 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -164,6 +168,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
*out_grad, ksize, strides, paddings, pool_process); *out_grad, ksize, strides, paddings, pool_process);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
} }
......
...@@ -46,8 +46,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,10 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
if (ctx->Attrs().Get<bool>("globalPooling")) { if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2); ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_dims[i + 2]); ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
} }
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
...@@ -87,31 +89,33 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -87,31 +89,33 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of pooling operator. " "(Tensor), the input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the " "The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image."); "number of channels, H and W is the height and width of image.");
AddOutput("Out", AddOutput("Out",
"(Tensor) The output tensor of pooling operator." "(Tensor), the output tensor of pooling operator."
"The format of output tensor is also NCHW." "The format of output tensor is also NCHW."
"Where N is batch size, C is " "Where N is batch size, C is "
"the number of channels, H and W is the height and " "the number of channels, H and W is the height and "
"width of image."); "width of image.");
AddOutput("Mask", AddOutput("Mask",
"(Tensor) The Mask tensor of pooling operator." "(Tensor), the Mask tensor of pooling operator."
"The format of output tensor is also NCHW." "The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W " "Where N is batch size, C is the number of channels, H and W "
"is the height and width of image." "is the height and width of image."
"The value in it is the index in current feature map"); "The value in it is the index in current feature map");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector ), the pooling window size(height, "
"(vector ), the pooling window size(height, width) of pooling operator." "width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings "
"specified."); // TODO(Chengduo): Add checker. (Currently, "will be ignored."); // TODO(Chengduo): Add
// checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>(
"(bool default: false), whether to use the global pooling." "globalPooling",
"If globalPooling = true, ksize is ignored.") "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -120,7 +124,8 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -120,7 +124,8 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"(vector defalut:{0,0}), paddings(height, width) of pooling operator.") "(vector defalut:{0, 0}), paddings(height, width) of pooling operator."
"If globalPooling = true, paddings and will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
...@@ -153,42 +158,46 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -153,42 +158,46 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor of pooling operator. " "(Tensor), the input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is " "The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of " "the number of channels, D, H and W is the depth, height and width of "
"image."); "image.");
AddOutput("Out", AddOutput("Out",
"(Tensor) The output tensor of pooling operator." "(Tensor), the output tensor of pooling operator."
"The format of output tensor is also NCDHW." "The format of output tensor is also NCDHW."
"Where N is batch size, C is " "Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and " "the number of channels, D, H and W is the depth, height and "
"width of image."); "width of image.");
AddOutput("Mask", AddOutput("Mask",
"(Tensor) The Mask tensor of pooling operator." "(Tensor), the Mask tensor of pooling operator."
"The format of output tensor is also NCDHW." "The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W " "Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image." "is the depth, height and width of image."
"The value in it is the index in current feature map"); "The value in it is the index in current feature map");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>("ksize",
"ksize", "(vector), the pooling window size(depth, "
"(vector ), the pooling window size(depth, height, width) of pooling " "height, width) of pooling "
"operator." "operator."
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize and paddings "
"specified."); // TODO(Chengduo): Add checker. (Currently, "will be ignored."); // TODO(Chengduo): Add
// checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>("globalPooling", AddAttr<bool>(
"(bool default: false), whether to use the global pooling." "globalPooling",
"If globalPooling = true, ksize is ignored.") "(bool default: false), whether to use the global pooling."
"If globalPooling = true, ksize and paddings will be ignored.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1,1,1}), strides(depth, " "(vector, default:{1,1,1}), strides(depth, "
"height, width) of pooling operator.") "height, width) of pooling operator.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>(
"(vector defalut:{0,0,0}), paddings(depth, " "paddings",
"height, width) of pooling operator.") "(vector defalut:{0,0,0}), paddings(depth, "
"height, width) of pooling operator."
"If globalPooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
......
...@@ -37,6 +37,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -37,6 +37,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
} }
...@@ -54,6 +55,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -54,6 +55,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings); strides, paddings);
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
}; };
...@@ -72,6 +74,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { ...@@ -72,6 +74,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]); ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
} }
} }
...@@ -95,6 +98,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { ...@@ -95,6 +98,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
pool3d_backward(context.device_context(), *in_x_grad, *out_grad, pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings); *mask, ksize, strides, paddings);
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
} }
} }
......
...@@ -49,9 +49,12 @@ class TestPool2d_Op(OpTest): ...@@ -49,9 +49,12 @@ class TestPool2d_Op(OpTest):
self.init_test_case() self.init_test_case()
self.init_op_type() self.init_op_type()
self.init_pool_type() self.init_pool_type()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output = self.pool2D_forward_naive(input, self.ksize, self.strides, output = self.pool2D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool) self.paddings,
self.global_pool).astype("float32")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = { self.attrs = {
......
...@@ -54,10 +54,13 @@ def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): ...@@ -54,10 +54,13 @@ def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
class TestPool3d_Op(OpTest): class TestPool3d_Op(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.init_test_case()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output = self.pool3D_forward_naive(input, self.ksize, self.strides, output = self.pool3D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool) self.paddings,
self.global_pool).astype("float32")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = { self.attrs = {
...@@ -77,7 +80,7 @@ class TestPool3d_Op(OpTest): ...@@ -77,7 +80,7 @@ class TestPool3d_Op(OpTest):
if self.pool_type != "max": if self.pool_type != "max":
self.check_grad(set(['X']), 'Out', max_relative_error=0.07) self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "pool3d" self.op_type = "pool3d"
self.pool_type = "avg" self.pool_type = "avg"
...@@ -89,7 +92,7 @@ class TestPool3d_Op(OpTest): ...@@ -89,7 +92,7 @@ class TestPool3d_Op(OpTest):
class TestCase1(TestPool3d_Op): class TestCase1(TestPool3d_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "pool3d" self.op_type = "pool3d"
self.pool_type = "avg" self.pool_type = "avg"
...@@ -101,7 +104,7 @@ class TestCase1(TestPool3d_Op): ...@@ -101,7 +104,7 @@ class TestCase1(TestPool3d_Op):
class TestCase2(TestPool3d_Op): class TestCase2(TestPool3d_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "pool3d" self.op_type = "pool3d"
self.pool_type = "avg" self.pool_type = "avg"
...@@ -113,7 +116,7 @@ class TestCase2(TestPool3d_Op): ...@@ -113,7 +116,7 @@ class TestCase2(TestPool3d_Op):
class TestCase3(TestPool3d_Op): class TestCase3(TestPool3d_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "pool3d" self.op_type = "pool3d"
self.pool_type = "max" self.pool_type = "max"
...@@ -125,7 +128,7 @@ class TestCase3(TestPool3d_Op): ...@@ -125,7 +128,7 @@ class TestCase3(TestPool3d_Op):
class TestCase4(TestPool3d_Op): class TestCase4(TestPool3d_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "pool3d" self.op_type = "pool3d"
self.pool_type = "max" self.pool_type = "max"
...@@ -137,7 +140,7 @@ class TestCase4(TestPool3d_Op): ...@@ -137,7 +140,7 @@ class TestCase4(TestPool3d_Op):
class TestCase5(TestPool3d_Op): class TestCase5(TestPool3d_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "pool3d" self.op_type = "pool3d"
self.pool_type = "max" self.pool_type = "max"
......
...@@ -3,11 +3,7 @@ import numpy as np ...@@ -3,11 +3,7 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def max_pool3D_forward_naive(x, def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0):
ksize,
strides,
paddings=[0, 0, 0],
global_pool=0):
N, C, D, H, W = x.shape N, C, D, H, W = x.shape
if global_pool == 1: if global_pool == 1:
...@@ -44,7 +40,7 @@ def max_pool3D_forward_naive(x, ...@@ -44,7 +40,7 @@ def max_pool3D_forward_naive(x,
return out, mask return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
N, C, H, W = x.shape N, C, H, W = x.shape
if global_pool == 1: if global_pool == 1:
...@@ -77,10 +73,14 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): ...@@ -77,10 +73,14 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
class TestMaxPoolWithIndex_Op(OpTest): class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self): def setUp(self):
self.initTestCase() self.init_test_case()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output, mask = self.pool_forward_naive(input, self.ksize, self.strides, output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool) self.paddings, self.global_pool)
output = output.astype("float32")
mask = mask.astype("float32")
self.attrs = { self.attrs = {
'strides': self.strides, 'strides': self.strides,
...@@ -98,7 +98,7 @@ class TestMaxPoolWithIndex_Op(OpTest): ...@@ -98,7 +98,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
# def test_check_grad(self): # def test_check_grad(self):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07) # self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.index = "max_pool3d_with_index" self.index = "max_pool3d_with_index"
self.op_type = "%s" % self.index self.op_type = "%s" % self.index
...@@ -110,7 +110,7 @@ class TestMaxPoolWithIndex_Op(OpTest): ...@@ -110,7 +110,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
class TestCase1(TestMaxPoolWithIndex_Op): class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool3d_with_index" self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
...@@ -121,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op): ...@@ -121,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op):
class TestCase2(TestMaxPoolWithIndex_Op): class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "max_pool3d_with_index" self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
...@@ -132,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op): ...@@ -132,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op):
class TestCase3(TestMaxPoolWithIndex_Op): class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "max_pool3d_with_index" self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
...@@ -143,7 +143,7 @@ class TestCase3(TestMaxPoolWithIndex_Op): ...@@ -143,7 +143,7 @@ class TestCase3(TestMaxPoolWithIndex_Op):
class TestCase4(TestMaxPoolWithIndex_Op): class TestCase4(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool3d_with_index" self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
...@@ -154,7 +154,7 @@ class TestCase4(TestMaxPoolWithIndex_Op): ...@@ -154,7 +154,7 @@ class TestCase4(TestMaxPoolWithIndex_Op):
class TestCase5(TestMaxPoolWithIndex_Op): class TestCase5(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool3d_with_index" self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive self.pool_forward_naive = max_pool3D_forward_naive
...@@ -165,7 +165,7 @@ class TestCase5(TestMaxPoolWithIndex_Op): ...@@ -165,7 +165,7 @@ class TestCase5(TestMaxPoolWithIndex_Op):
class TestCase6(TestMaxPoolWithIndex_Op): class TestCase6(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "max_pool2d_with_index" self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
...@@ -176,7 +176,7 @@ class TestCase6(TestMaxPoolWithIndex_Op): ...@@ -176,7 +176,7 @@ class TestCase6(TestMaxPoolWithIndex_Op):
class TestCase7(TestMaxPoolWithIndex_Op): class TestCase7(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = False self.global_pool = False
self.op_type = "max_pool2d_with_index" self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
...@@ -187,7 +187,7 @@ class TestCase7(TestMaxPoolWithIndex_Op): ...@@ -187,7 +187,7 @@ class TestCase7(TestMaxPoolWithIndex_Op):
class TestCase8(TestMaxPoolWithIndex_Op): class TestCase8(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool2d_with_index" self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
...@@ -198,7 +198,7 @@ class TestCase8(TestMaxPoolWithIndex_Op): ...@@ -198,7 +198,7 @@ class TestCase8(TestMaxPoolWithIndex_Op):
class TestCase9(TestMaxPoolWithIndex_Op): class TestCase9(TestMaxPoolWithIndex_Op):
def initTestCase(self): def init_test_case(self):
self.global_pool = True self.global_pool = True
self.op_type = "max_pool2d_with_index" self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive self.pool_forward_naive = max_pool2D_forward_naive
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册