From 76fc1a82e109737d704b11d897b83b5f5138bc86 Mon Sep 17 00:00:00 2001
From: wanghaox
Date: Mon, 20 Nov 2017 14:33:28 +0800
Subject: [PATCH] for code review 4

---
 paddle/operators/math/maxouting.cc              | 10 +++-------
 .../math/{maxouting.cu => maxouting.cu.cc}      |  5 +++--
 paddle/operators/math/maxouting.h               |  2 +-
 paddle/operators/maxout_op.cc                   | 15 +++++++--------
 .../operators/{maxout_op.cu => maxout_op.cu.cc} |  1 -
 paddle/operators/maxout_op.h                    | 11 ++++-------
 python/paddle/v2/fluid/tests/test_maxout_op.py  |  5 ++---
 7 files changed, 20 insertions(+), 29 deletions(-)
 rename paddle/operators/math/{maxouting.cu => maxouting.cu.cc} (97%)
 rename paddle/operators/{maxout_op.cu => maxout_op.cu.cc} (97%)

diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc
index a4d46ccc98..c8c1974f79 100644
--- a/paddle/operators/math/maxouting.cc
+++ b/paddle/operators/math/maxouting.cc
@@ -18,10 +18,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-/*
- * All tensors are in NCHW format.
- * groups mustbe > 1
- */
+// All tensors are in NCHW format, and groups must be greater than 1.
 template <typename T>
 class MaxOutFunctor<platform::CPUPlace, T> {
  public:
@@ -44,7 +41,6 @@ class MaxOutFunctor<platform::CPUPlace, T> {
     for (int c = 0; c < output_channels; ++c) {
       int new_cindex = fea_size * c;
       for (int f = 0; f < fea_size; ++f) {
-        // T ele = maxout_process.initial();
         T ele = static_cast<T>(-FLT_MAX);
         for (int ph = 0; ph < groups; ++ph) {
           T x = input_data[(new_bindex + new_cindex) * groups
@@ -65,7 +61,7 @@ class MaxOutGradFunctor<platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
-                  framework::Tensor& input_grad,
+                  framework::Tensor* input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad,
                   int groups) {
@@ -77,7 +73,7 @@ public:
     const T* input_data = input.data<T>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
-    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
 
     for (int i = 0; i < batch_size; ++i) {
       int blen = fea_size * output_channels * i;
diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu.cc
similarity index 97%
rename from paddle/operators/math/maxouting.cu
rename to paddle/operators/math/maxouting.cu.cc
index 336a1bd8b5..3a0600fd84 100644
--- a/paddle/operators/math/maxouting.cu
+++ b/paddle/operators/math/maxouting.cu.cc
@@ -112,7 +112,8 @@ template <typename T>
 class MaxOutGradFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor& input_grad,
+                  const framework::Tensor& input,
+                  framework::Tensor* input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad,
                   int groups) {
@@ -127,7 +128,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
     const T* input_data = input.data<T>();
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
-    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
     int nthreads = output.numel();
     int blocks = (nthreads + 1024 - 1) / 1024;
     dim3 threads(1024, 1);
diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h
index 76a256add9..d4c9da38ab 100644
--- a/paddle/operators/math/maxouting.h
+++ b/paddle/operators/math/maxouting.h
@@ -38,7 +38,7 @@ class MaxOutGradFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input,
-                  framework::Tensor& input_grad,
+                  framework::Tensor* input_grad,
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, int groups);
 };
diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc
index f9277518cc..95467f2e69 100644
--- a/paddle/operators/maxout_op.cc
+++ b/paddle/operators/maxout_op.cc
@@ -34,14 +34,13 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
                     "width of feature.");
     AddAttr<int>(
         "groups",
-        R"DOC(The group number of input layer.
+        R"DOC(Specifies how many groups the input tensor will be split into
+        in the channel dimension. The number of output channels is the
+        number of input channels divided by groups.
         )DOC");
     AddComment(R"DOC(
-      - Input: NCHW.
-      - Output: The feature map size of output is the same as the input.
-        The output_channel is (input channel) / groups
-        So groups should be larger than 1, and the num of channels should be able
-        to be devided by groups.
+        Assuming the input shape is (N, Ci, H, W), the output shape is
+        (N, Co, H, W), where `Co = Ci / groups`.
       math:
         y_{si+j} = \max_k x_{gsi + sk + j}
 
@@ -65,10 +64,10 @@ class MaxOutOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp"
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp "
                    "should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of maxoutOp should not be null.");
+                   "Output(Out) of MaxoutOp should not be null.");
     auto in_x_dims = ctx->GetInputDim("X");
     int groups = ctx->Attrs().Get<int>("groups");
     // check groups > 1
diff --git a/paddle/operators/maxout_op.cu b/paddle/operators/maxout_op.cu.cc
similarity index 97%
rename from paddle/operators/maxout_op.cu
rename to paddle/operators/maxout_op.cu.cc
index 44a149b065..3e6debf699 100644
--- a/paddle/operators/maxout_op.cu
+++ b/paddle/operators/maxout_op.cu.cc
@@ -12,7 +12,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
 #include "paddle/operators/maxout_op.h"
 
 namespace ops = paddle::operators;
 
diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h
index 6c769838c3..c404cd16a9 100644
--- a/paddle/operators/maxout_op.h
+++ b/paddle/operators/maxout_op.h
@@ -31,9 +31,7 @@ class MaxOutKernel : public framework::OpKernel<T> {
     Tensor* out = context.Output<Tensor>("Out");
     int groups = context.template Attr<int>("groups");
 
-    paddle::operators::math::MaxOutFunctor<
-        Place, T>
-        maxout_forward;
+    math::MaxOutFunctor<Place, T> maxout_forward;
     maxout_forward(context.device_context(), *in_x, out, groups);
   }
 };
@@ -53,10 +51,9 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
       zero(device_ctx, in_x_grad, static_cast<T>(0.0));
-      paddle::operators::math::MaxOutGradFunctor<Place, T>
-          maxout_backward;
-      maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
-                      *out_grad, groups);
+      math::MaxOutGradFunctor<Place, T> maxout_backward;
+      maxout_backward(context.device_context(), *in_x, in_x_grad, *out,
+                      *out_grad, groups);
     }
   }
 };
diff --git a/python/paddle/v2/fluid/tests/test_maxout_op.py b/python/paddle/v2/fluid/tests/test_maxout_op.py
index a7c47108f1..1416e13feb 100644
--- a/python/paddle/v2/fluid/tests/test_maxout_op.py
+++ b/python/paddle/v2/fluid/tests/test_maxout_op.py
@@ -3,7 +3,7 @@ import numpy as np
 from op_test import OpTest
 
 
-def maxout_forward_naive(input, groups,num_channels):
+def maxout_forward_naive(input, groups):
     s0, s1, s2, s3 = input.shape
     return np.ndarray([s0, s1 / groups, groups, s2, s3], \
         buffer = input, dtype=input.dtype).max(axis=(2))
@@ -18,7 +18,7 @@ class TestMaxOutOp(OpTest):
             self.num_channels).astype("float32")
         self.inputs = {'X': input}
-        self.attrs = {'groups': self.groups, 'num_channels': self.num_channels}
+        self.attrs = {'groups': self.groups}
 
         self.outputs = {'Out': output.astype('float32')}
 
@@ -32,7 +32,6 @@ class TestMaxOutOp(OpTest):
         self.MaxOut_forward_naive = maxout_forward_naive
         self.shape = [100, 6, 2, 2]
         self.groups=2
-        self.num_channels=6
 
 
 if __name__ == '__main__':
--
GitLab
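Not part of the patch: the shape relationship described in the updated `AddComment` doc can be sanity-checked with a minimal NumPy sketch. It mirrors `maxout_forward_naive` from `test_maxout_op.py`, but uses an explicit `reshape` instead of the `np.ndarray` buffer trick; the helper name `maxout_naive` is just for illustration.

```python
import numpy as np

def maxout_naive(x, groups):
    # NCHW maxout: view (N, Ci, H, W) as (N, Ci/groups, groups, H, W)
    # and take the elementwise max over the `groups` axis, so Co = Ci / groups.
    n, ci, h, w = x.shape
    assert ci % groups == 0, "number of channels must be divisible by groups"
    return x.reshape(n, ci // groups, groups, h, w).max(axis=2)

x = np.random.random((100, 6, 2, 2)).astype("float32")  # same shape as the test
y = maxout_naive(x, groups=2)
assert y.shape == (100, 3, 2, 2)  # Co = Ci / groups = 6 / 2
```

Because a C-order reshape keeps adjacent channels together, output channel `j` is the max over input channels `j * groups` through `j * groups + groups - 1`, which is what the `y_{si+j} = \max_k x_{gsi + sk + j}` formula in the op doc expresses.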