From 84a2512b90f39854386fb03f437ac8a92a486437 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Fri, 22 Sep 2017 10:06:03 +0800
Subject: [PATCH] fix parameter names and function definitions

---
 paddle/operators/math/pooling.cc              | 24 ++---
 paddle/operators/math/pooling.cu              | 24 ++---
 paddle/operators/math/pooling.h               | 24 ++---
 paddle/operators/pool_op.cc                   | 87 +++++++++++--------
 paddle/operators/pool_op.h                    | 72 ++++++++-------
 .../v2/framework/tests/test_pool2d_op.py      |  8 +-
 .../v2/framework/tests/test_pool3d_op.py      |  8 +-
 7 files changed, 129 insertions(+), 118 deletions(-)

diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc
index 671bead1b49..5ce748ff08b 100644
--- a/paddle/operators/math/pooling.cc
+++ b/paddle/operators/math/pooling.cc
@@ -21,10 +21,10 @@ namespace math {
 template
 class Pool2dForwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& output,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
                   std::vector& ksize, std::vector& strides,
-                  std::vector& paddings, PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  std::vector& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -75,12 +75,12 @@ class Pool2dForwardFunctor {
 template
 class Pool2dBackwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector& ksize,
                   std::vector& strides, std::vector& paddings,
-                  PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -154,10 +154,10 @@ template class Pool2dBackwardFunctor<
 template
 class Pool3dForwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& output,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
                   std::vector& ksize, std::vector& strides,
-                  std::vector& paddings, PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  std::vector& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -224,12 +224,12 @@ class Pool3dForwardFunctor {
 template
 class Pool3dBackwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector& ksize,
                   std::vector& strides, std::vector& paddings,
-                  PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
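Note: the functor changes above (and the matching pooling.cu / pooling.h changes below) all make the same edit: the DeviceContext argument moves to the front of every operator() parameter list, and the PoolProcess policy object moves to the end. A minimal sketch of the old and new calling conventions, taken from the 2-D forward call site in pool_op.h later in this patch (template arguments are stripped in this rendering of the patch, so none are shown here):

    // Before this patch: the device context was passed last.
    //   pool2d_forward(*input, *output, ksize, strides, paddings, pool_process,
    //                  context.device_context());
    // After this patch: the device context is passed first, the pooling
    // policy object last.
    pool2d_forward(context.device_context(), *in_X, *out, ksize, strides,
                   paddings, pool_process);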
diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu
index ce0a01776a9..124006942c0 100644
--- a/paddle/operators/math/pooling.cu
+++ b/paddle/operators/math/pooling.cu
@@ -105,10 +105,10 @@ __global__ void KernelPool2dBackward(
 template
 class Pool2dForwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& output,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
                   std::vector& ksize, std::vector& strides,
-                  std::vector& paddings, PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  std::vector& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -148,12 +148,12 @@ class Pool2dForwardFunctor {
 template
 class Pool2dBackwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector& ksize,
                   std::vector& strides, std::vector& paddings,
-                  PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -319,10 +319,10 @@ __global__ void KernelPool3DBackward(
 template
 class Pool3dForwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& output,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
                   std::vector& ksize, std::vector& strides,
-                  std::vector& paddings, PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  std::vector& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -369,12 +369,12 @@ class Pool3dForwardFunctor {
 template
 class Pool3dBackwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector& ksize,
                   std::vector& strides, std::vector& paddings,
-                  PoolProcess pool_process,
-                  const platform::DeviceContext& context) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];

diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h
index aad5a1837b8..3f01c9dacb0 100644
--- a/paddle/operators/math/pooling.h
+++ b/paddle/operators/math/pooling.h
@@ -59,41 +59,41 @@ class avePool {
 template
 class Pool2dForwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& output,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
                   std::vector& ksize, std::vector& strides,
-                  std::vector& paddings, PoolProcess pool_process,
-                  const platform::DeviceContext& context);
+                  std::vector& paddings, PoolProcess pool_process);
 };

 template
 class Pool2dBackwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector& ksize,
                   std::vector& strides, std::vector& paddings,
-                  PoolProcess pool_process,
-                  const platform::DeviceContext& context);
+                  PoolProcess pool_process);
 };

 template
 class Pool3dForwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& output,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
                   std::vector& ksize, std::vector& strides,
-                  std::vector& paddings, PoolProcess pool_process,
-                  const platform::DeviceContext& context);
+                  std::vector& paddings, PoolProcess pool_process);
 };

 template
 class Pool3dBackwardFunctor {
  public:
-  void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector& ksize,
                   std::vector& strides, std::vector& paddings,
-                  PoolProcess pool_process,
-                  const platform::DeviceContext& context);
+                  PoolProcess pool_process);
 };

 }  // namespace math

diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc
index 1d79629d730..c23adf63bf8 100644
--- a/paddle/operators/pool_op.cc
+++ b/paddle/operators/pool_op.cc
@@ -28,18 +28,18 @@ class PoolOp : public framework::OperatorWithKernel {

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
-                            "Input(Input) of Pooling should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
-                            "Output(Output) of Pooling should not be null.");
-    // PADDLE_ENFORCE_NOT_NULL(Attr("pooling_type"),
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+                            "X(Input) of Pooling should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
+                            "Out(Output) of Pooling should not be null.");
+    // PADDLE_ENFORCE_NOT_NULL(Attr("poolingType"),
    //                         "pooling_type should not be null.");
    // PADDLE_ENFORCE_NOT_NULL(Attr>("ksize"), "ksize should
    // not be null.");
-    auto input = ctx.Input("Input");
-    auto output = ctx.Output("Output");
-    int global_pooling = Attr("global_pooling");
-    std::string pooling_type = Attr("pooling_type");
+    auto in_X = ctx.Input("X");
+    auto out = ctx.Output("Out");
+    int global_pooling = Attr("globalPooling");
+    std::string pooling_type = Attr("poolingType");
    std::vector ksize = Attr>("ksize");
    std::vector strides = Attr>("strides");
    std::vector paddings = Attr>("paddings");
@@ -50,25 +50,25 @@ class PoolOp : public framework::OperatorWithKernel {
                   "Pooling ksize should be 2-D or 3-D");

    if (global_pooling == 1) {
-      for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = input->dims()[i + 2];
+      for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = in_X->dims()[i + 2];
    }

    if (ksize.size() == 2) {
-      PADDLE_ENFORCE_EQ(input->dims().size(), 4,
+      PADDLE_ENFORCE_EQ(in_X->dims().size(), 4,
                        "Pool2DOp intput should be 4-D.");
      PADDLE_ENFORCE_EQ(strides.size(), 2, "Pool2DOp strides should be 2-D.");
      PADDLE_ENFORCE_EQ(paddings.size(), 2, "Pool2DOp paddings should be 2-D.");
    } else {
-      PADDLE_ENFORCE_EQ(input->dims().size(), 5,
+      PADDLE_ENFORCE_EQ(in_X->dims().size(), 5,
                        "Pool3DOp intput should be 5-D.");
      PADDLE_ENFORCE_EQ(strides.size(), 3, "Pool3DOp strides should be 3-D.");
      PADDLE_ENFORCE_EQ(paddings.size(), 3, "Pool3DOp paddings should be 3-D.");
    }
-    std::vector output_shape({input->dims()[0], input->dims()[1]});
+    std::vector output_shape({in_X->dims()[0], in_X->dims()[1]});
    for (size_t i = 0; i < ksize.size(); ++i) {
-      output_shape.push_back(outputSize_pool(input->dims()[i + 2], ksize[i],
+      output_shape.push_back(outputSize_pool(in_X->dims()[i + 2], ksize[i],
                                              paddings[i], strides[i]));
    }
-    output->Resize(framework::make_ddim(output_shape));
+    out->Resize(framework::make_ddim(output_shape));
  }
 };

@@ -78,9 +78,8 @@ class PoolOpGrad : public framework::OperatorWithKernel {

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto in = ctx.Input("Input");
-    auto d_in =
-        ctx.Output(framework::GradVarName("Input"));
+    auto in = ctx.Input("X");
+    auto d_in = ctx.Output(framework::GradVarName("X"));
    if (d_in) d_in->Resize(in->dims());
  }
 };

@@ -90,27 +89,36 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
  Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
-        "Input",
+        "X",
        "The input tensor of pooling operator. "
        "The format of input tensor is NCDHW. Where N is batch size, C is the "
        "number of channels, D, H and W is the depth, height and width of "
        "image.");
-    AddOutput("Output",
+    AddOutput("Out",
              "The output tensor of pooling operator."
              "The format of output tensor is also NCDHW.");
-    AddAttr("pooling_type",
-            "pooling_type of pooling operator.['max' or 'ave']");
-    AddAttr>("ksize", "strides of pooling operator.");
-    AddAttr("global_pooling", "whether to use the global_pooling.")
+    AddAttr("poolingType",
+            "poolingType of pooling operator.['max' or 'ave']");
+    AddAttr>(
+        "ksize", "pooling size(depth, height, width) of pooling operator.");
+    AddAttr("globalPooling",
+            "default 0"
+            "whether to use the globalPooling.")
        .SetDefault(0);
-    AddAttr>("strides", "strides of pooling operator.")
+    AddAttr>(
+        "strides",
+        "default {1,1,1}"
+        "strides(depth, height, width) of pooling operator.")
        .SetDefault({1, 1, 1});
-    AddAttr>("paddings", "paddings of pooling operator.")
+    AddAttr>(
+        "paddings",
+        "default {0,0,0}"
+        "paddings(depth, height, width) of pooling operator.")
        .SetDefault({0, 0, 0});
    AddComment(R"DOC(
 The pooling3d operation calculates the output based on
-the input, pooling_type and ksize, strides, paddings parameters.
+the input, poolingType and ksize, strides, paddings parameters.
 )DOC");
  }
 };
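Note: the InferShape code above relies on a helper outputSize_pool() whose definition is not part of this patch. The sketch below shows the conventional pooling output-size rule it is assumed to implement; the helper name and formula here are assumptions for illustration, not taken from the patch:

    // Assumed behaviour of outputSize_pool(): standard pooling arithmetic,
    // out = (in - ksize + 2 * padding) / stride + 1.
    #include <cstdio>

    int output_size_pool(int input_size, int ksize, int padding, int stride) {
      return (input_size - ksize + 2 * padding) / stride + 1;
    }

    int main() {
      // e.g. a 32x32 feature map, 3x3 window, padding 0, stride 1 -> 30x30
      std::printf("%d\n", output_size_pool(32, 3, 0, 1));  // prints 30
      return 0;
    }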
"The format of output tensor is also NCHW."); - AddAttr("pooling_type", - "pooling_type of pooling operator.['max' or 'ave']"); - AddAttr>("ksize", "strides of pooling operator."); - AddAttr("global_pooling", "whether to use the global_pooling.") + AddAttr("poolingType", + "poolingType of pooling operator.['max' or 'ave']"); + AddAttr>( + "ksize", "pooling size(height, width) of pooling operator."); + AddAttr("globalPooling", + "default 0" + "whether to use the globalPooling.[0 or 1]") .SetDefault(0); - AddAttr>("strides", "strides of pooling operator.") + AddAttr>("strides", + "default {1, 1}" + "strides(height, width) of pooling operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of pooling operator.") + AddAttr>("paddings", + "default {0, 0}" + "paddings(height, width) of pooling operator.") .SetDefault({0, 0}); AddComment(R"DOC( The pooling2d operation calculates the output based on -the input, pooling_type and ksize, strides, paddings parameters. +the input, poolingType and ksize, strides, paddings parameters. )DOC"); } }; diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index 16779cbb91b..2e737f0a4b5 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -28,17 +28,17 @@ template class PoolKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("Input"); - Tensor* output = context.Output("Output"); + const Tensor* in_X = context.Input("X"); + Tensor* out = context.Output("Out"); - int global_pooling = context.Attr("global_pooling"); - std::string pooling_type = context.Attr("pooling_type"); + int global_pooling = context.Attr("globalPooling"); + std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); if (global_pooling == 1) { for (size_t i = 0; i < ksize.size(); ++i) { - ksize[i] = input->dims()[i + 2]; + ksize[i] = in_X->dims()[i + 2]; } } @@ -49,16 +49,16 @@ class PoolKernel : public framework::OpKernel { Place, paddle::operators::math::pool::maxPool, T> pool2d_forward; paddle::operators::math::pool::maxPool pool_process; - pool2d_forward(*input, *output, ksize, strides, paddings, - pool_process, context.device_context()); + pool2d_forward(context.device_context(), *in_X, *out, ksize, strides, + paddings, pool_process); } else if (pooling_type == "ave") { paddle::operators::math::Pool2dForwardFunctor< Place, paddle::operators::math::pool::avePool, T> pool2d_forward; paddle::operators::math::pool::avePool pool_process; - pool2d_forward(*input, *output, ksize, strides, paddings, - pool_process, (context.device_context())); + pool2d_forward(context.device_context(), *in_X, *out, ksize, strides, + paddings, pool_process); } } break; case 3: { @@ -67,15 +67,15 @@ class PoolKernel : public framework::OpKernel { Place, paddle::operators::math::pool::maxPool, T> pool3d_forward; paddle::operators::math::pool::maxPool pool_process; - pool3d_forward(*input, *output, ksize, strides, paddings, - pool_process, context.device_context()); + pool3d_forward(context.device_context(), *in_X, *out, ksize, strides, + paddings, pool_process); } else if (pooling_type == "ave") { paddle::operators::math::Pool3dForwardFunctor< Place, paddle::operators::math::pool::avePool, T> pool3d_forward; paddle::operators::math::pool::avePool pool_process; - pool3d_forward(*input, *output, ksize, strides, paddings, - 
-                       pool_process, context.device_context());
+        pool3d_forward(context.device_context(), *in_X, *out, ksize, strides,
+                       paddings, pool_process);
      }
    } break;
   }
  }
 };

 template
 class PoolGradKernel : public framework::OpKernel {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input("Input");
-    const Tensor* output = context.Input("Output");
-    const Tensor* output_grad =
-        context.Input(framework::GradVarName("Output"));
-    Tensor* input_grad =
-        context.Output(framework::GradVarName("Input"));
-
-    int global_pooling = context.Attr("global_pooling");
-    std::string pooling_type = context.Attr("pooling_type");
+    const Tensor* in_X = context.Input("X");
+    const Tensor* out = context.Input("Out");
+    const Tensor* out_grad =
+        context.Input(framework::GradVarName("Out"));
+    Tensor* in_X_grad =
+        context.Output(framework::GradVarName("X"));
+
+    int global_pooling = context.Attr("globalPooling");
+    std::string pooling_type = context.Attr("poolingType");
    std::vector ksize = context.Attr>("ksize");
    std::vector strides = context.Attr>("strides");
    std::vector paddings = context.Attr>("paddings");
    if (global_pooling == 1) {
-      for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = input->dims()[i + 2];
+      for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = in_X->dims()[i + 2];
    }

-    if (input_grad) {
-      input_grad->mutable_data(context.GetPlace());
-      auto temp = framework::EigenVector::Flatten(*input_grad);
+    if (in_X_grad) {
+      in_X_grad->mutable_data(context.GetPlace());
+      auto temp = framework::EigenVector::Flatten(*in_X_grad);
      temp.device(context.GetEigenDevice()) =
          temp.constant(static_cast(0));

@@ -116,17 +116,15 @@ class PoolGradKernel : public framework::OpKernel {
            Place, paddle::operators::math::pool::maxPool, T>
            pool2d_backward;
        paddle::operators::math::pool::maxPool pool_process;
-        pool2d_backward(*input, *input_grad, *output, *output_grad, ksize,
-                        strides, paddings, pool_process,
-                        context.device_context());
+        pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out,
+                        *out_grad, ksize, strides, paddings, pool_process);
      } else if (pooling_type == "ave") {
        paddle::operators::math::Pool2dBackwardFunctor<
            Place, paddle::operators::math::pool::avePool, T>
            pool2d_backward;
        paddle::operators::math::pool::avePool pool_process;
-        pool2d_backward(*input, *input_grad, *output, *output_grad, ksize,
-                        strides, paddings, pool_process,
-                        context.device_context());
+        pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out,
+                        *out_grad, ksize, strides, paddings, pool_process);
      }
    } break;
    case 3: {
@@ -135,17 +133,15 @@ class PoolGradKernel : public framework::OpKernel {
            Place, paddle::operators::math::pool::maxPool, T>
            pool3d_backward;
        paddle::operators::math::pool::maxPool pool_process;
-        pool3d_backward(*input, *input_grad, *output, *output_grad, ksize,
-                        strides, paddings, pool_process,
-                        context.device_context());
+        pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out,
+                        *out_grad, ksize, strides, paddings, pool_process);
      } else if (pooling_type == "ave") {
        paddle::operators::math::Pool3dBackwardFunctor<
            Place, paddle::operators::math::pool::avePool, T>
            pool3d_backward;
        paddle::operators::math::pool::avePool pool_process;
-        pool3d_backward(*input, *input_grad, *output, *output_grad, ksize,
-                        strides, paddings, pool_process,
-                        context.device_context());
+        pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out,
+                        *out_grad, ksize, strides, paddings, pool_process);
      }
    } break;
  }
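Note: both kernels above share the globalPooling behaviour: when the attribute is 1, the configured window is ignored and ksize is overwritten with the input's spatial dimensions. A minimal, self-contained sketch of that logic, with illustrative shape values that are not taken from the patch:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> ksize = {3, 3};                // configured 2-D window
      std::vector<int> input_dims = {8, 16, 32, 32};  // NCHW input shape
      int global_pooling = 1;                         // globalPooling attribute
      if (global_pooling == 1) {
        // Same loop as in PoolKernel/PoolGradKernel: pool over the whole map.
        for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = input_dims[i + 2];
      }
      std::printf("effective ksize = {%d, %d}\n", ksize[0], ksize[1]);  // {32, 32}
      return 0;
    }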
diff --git a/python/paddle/v2/framework/tests/test_pool2d_op.py b/python/paddle/v2/framework/tests/test_pool2d_op.py
index cf327508f71..2a8fedc0379 100644
--- a/python/paddle/v2/framework/tests/test_pool2d_op.py
+++ b/python/paddle/v2/framework/tests/test_pool2d_op.py
@@ -47,23 +47,23 @@ class TestPool2d_Op(OpTest):
         input = np.random.random(self.shape).astype("float32")
         output = self.pool2D_forward_naive(input, self.ksize, self.strides,
                                            self.paddings)
-        self.inputs = {'Input': input}
+        self.inputs = {'X': input}

         self.attrs = {
             'strides': self.strides,
             'paddings': self.paddings,
             'ksize': self.ksize,
-            'pooling_type': self.pool_type,
+            'poolingType': self.pool_type,
         }

-        self.outputs = {'Output': output}
+        self.outputs = {'Out': output}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
         if self.pool_type != "max":
-            self.check_grad(set(['Input']), 'Output', max_relative_error=0.07)
+            self.check_grad(set(['X']), 'Out', max_relative_error=0.07)

     def initTestCase(self):
         self.pool_type = "ave"
diff --git a/python/paddle/v2/framework/tests/test_pool3d_op.py b/python/paddle/v2/framework/tests/test_pool3d_op.py
index cfd0ced150b..907ee0c0fe6 100644
--- a/python/paddle/v2/framework/tests/test_pool3d_op.py
+++ b/python/paddle/v2/framework/tests/test_pool3d_op.py
@@ -57,23 +57,23 @@ class TestPool3d_Op(OpTest):
         input = np.random.random(self.shape).astype("float32")
         output = self.pool3D_forward_naive(input, self.ksize, self.strides,
                                            self.paddings)
-        self.inputs = {'Input': input}
+        self.inputs = {'X': input}

         self.attrs = {
             'strides': self.strides,
             'paddings': self.paddings,
             'ksize': self.ksize,
-            'pooling_type': self.pool_type,
+            'poolingType': self.pool_type,
         }

-        self.outputs = {'Output': output}
+        self.outputs = {'Out': output}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
         if self.pool_type != "max":
-            self.check_grad(set(['Input']), 'Output', max_relative_error=0.07)
+            self.check_grad(set(['X']), 'Out', max_relative_error=0.07)

     def initTestCase(self):
         self.pool_type = "ave"
--
GitLab