From bd773b9c8429a64287d840eb5bd297c882b1d9d7 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 14 Nov 2017 14:20:50 +0800 Subject: [PATCH] modify for maxoutop code review --- paddle/operators/math/CMakeLists.txt | 6 +- paddle/operators/math/maxouting.cc | 25 ++++---- paddle/operators/math/maxouting.cu | 61 ++++++++---------- paddle/operators/math/maxouting.h | 22 +++---- paddle/operators/maxout_op.cc | 63 +++++++++++++------ paddle/operators/maxout_op.h | 7 +-- .../v2/framework/tests/test_maxout_op.py | 13 +--- 7 files changed, 98 insertions(+), 99 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index fb83b1478..3b4af8e43 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -8,24 +8,26 @@ if(WITH_GPU) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) - nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) + nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) - cc_library(maxouting SRCS maxouting.cc DEPS device_context) + cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(context_project SRCS context_project.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(maxouting SRCS maxouting.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index f01fa1839..a634e49f4 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -20,25 +20,27 @@ namespace math { /* * All tensors are in NCHW format. - * Ksize, strides, paddings are two elements. These two elements represent - * height and width, respectively. + * groups mustbe > 1 */ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - int groups, int num_channels, MaxOutProcess maxout_process) { + const framework::Tensor& input, + framework::Tensor * output, + int groups, + MaxOutProcess maxout_process) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = num_channels/groups; + const int output_channels = output->dims()[1]; int fea_size = input_height * input_width; + // c_size mean output one batch size int c_size = fea_size * output_channels; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { int new_bindex = c_size * i; @@ -50,7 +52,6 @@ class MaxOutFunctor { maxout_process.compute(ele, input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]); } - maxout_process.finalize(ele, (static_cast(groups))); output_data[(new_bindex+new_cindex+f)] = ele; } } @@ -68,11 +69,11 @@ public: framework::Tensor& input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, - int groups, int num_channels) { + int groups) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = num_channels / groups; + const int output_channels = output.dims()[1]; int fea_size = input_height * input_width; @@ -95,8 +96,6 @@ public: if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; stop = true; - } else { - input_grad_data[input_idx] = 0; } } } @@ -108,9 +107,9 @@ public: template class MaxOutGradFunctor; template class MaxOutGradFunctor; template class MaxOutFunctor, float>; + math::MaxOut, float>; template class MaxOutFunctor, double>; + math::MaxOut, double>; } // namespace math } // namespace operators diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index b1c0dd8fd..42acaa2c7 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -24,21 +24,20 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, T* output_data, const int channels, const int input_height, const int input_width, int groups, MaxOutProcess maxout_process) { - int size = input_height * input_width * channels / groups; - int featLen = input_height * input_width; + const int size = input_height * input_width * channels / groups; + const int feat_len = input_height * input_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int batch_idx = index / size; - int i = index % size; - int channel_idx = i / featLen; - int feat_idx = i % featLen; + int batch_offset = index % size; + int channel_idx = batch_offset / feat_len; + int feat_idx = batch_offset % feat_len; int data_idx = - (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; T ele = maxout_process.initial(); - for (int g = 0; g < groups; g++) { - maxout_process.compute(ele, input_data[data_idx + g * featLen]); + for (int g = 0; g < groups; ++g) { + maxout_process.compute(ele, input_data[data_idx + g * feat_len]); } - maxout_process.finalize(ele, (static_cast(groups))); output_data[index] = ele; } } @@ -47,21 +46,21 @@ __global__ void KernelMaxoutGrad( const int nthreads, const T* input_data, const T* output_data, const T* output_grad, T* input_grad, const int channels, const int input_height, const int input_width, int groups) { - int size = input_height * input_width * channels / groups; - int featLen = input_height * input_width; + const int size = input_height * input_width * channels / groups; + const int feat_len = input_height * input_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int batch_idx = index / size; - int i = index % size; - int channel_idx = i / featLen; - int feat_idx = i % featLen; + int batch_offset = index % size; + int channel_idx = batch_offset / feat_len; + int feat_idx = batch_offset % feat_len; int data_idx = - (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; int maxIndex = -1; bool stop = false; for (int g = 0; g < groups && !stop; g++) { - if (input_data[data_idx + g * featLen] == output_data[index]) { - maxIndex = data_idx + g * featLen; + if (input_data[data_idx + g * feat_len] == output_data[index]) { + maxIndex = data_idx + g * feat_len; stop = true; } } @@ -73,28 +72,25 @@ __global__ void KernelMaxoutGrad( } /* * All tensors are in NCHW format. - * Ksize, strides, paddings are two elements. These two elements represent - * height and width, respectively. */ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - int groups, int num_channels, + const framework::Tensor& input, framework::Tensor * output, + int groups, MaxOutProcess maxout_process) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = num_channels / groups; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - - int nthreads = batch_size * output_channels * output_height * output_width; + T* output_data = output->mutable_data(context.GetPlace()); + int nthreads = output->numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); dim3 grid(blocks, 1); @@ -110,8 +106,6 @@ class MaxOutFunctor { }; /* * All tensors are in NCHW format. - * Ksize, strides, paddings are two elements. These two elements represent - * height and width, respectively. */ template class MaxOutGradFunctor { @@ -120,7 +114,7 @@ class MaxOutGradFunctor { const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, - int groups, int num_channels) { + int groups) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -133,8 +127,7 @@ class MaxOutGradFunctor { const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad.mutable_data(context.GetPlace()); - - int nthreads = batch_size * output_channels * output_height * output_width; + int nthreads = output.numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); dim3 grid(blocks, 1); @@ -152,9 +145,9 @@ template class MaxOutGradFunctor; template class MaxOutGradFunctor; template class MaxOutFunctor, float>; + math::MaxOut, float>; template class MaxOutFunctor, double>; + math::MaxOut, double>; } // namespace math } // namespace operators diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index aeac08494..6aaa1656a 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -22,26 +22,20 @@ namespace paddle { namespace operators { namespace math { + #define FLT_MAX \ - __FLT_MAX__ // It might need to be placed in another file, but I'm still - // wondering where to put it. + __FLT_MAX__ /* - * \brief Extracting simple operations from pooling. - * Both MaxPool and AvgPool need "initial", "compute" and "finalize" + * \brief Extracting simple operations from maxout. + * need "initial", "compute" * operation. - * MaxPool initializes temp variable to the negative maximum to find the - * maximum value in the pooling field. - * AvgPool initializes temp variable to the zero to accumulate all values - * in pool pooling, and finally takes the average. - * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. */ template class MaxOut { public: DEVICE inline T initial() { return static_cast(-FLT_MAX); } DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } - DEVICE inline void finalize(T& y, const T& group) {} }; template @@ -69,11 +63,12 @@ class MaxOutGrad { * MaxPool2dGradFunctor, MaxPool3dGradFunctor. */ template + class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - int groups, int num_channels, MaxOutProcess maxout_compute); + const framework::Tensor& input, framework::Tensor * output, + int groups, MaxOutProcess maxout_compute); }; @@ -84,8 +79,7 @@ class MaxOutGradFunctor { const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups, - int num_channels); + const framework::Tensor& output_grad, int groups); }; diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc index 41b3860a8..c54a70697 100644 --- a/paddle/operators/maxout_op.cc +++ b/paddle/operators/maxout_op.cc @@ -19,17 +19,16 @@ namespace operators { using framework::Tensor; -/********first define ProtoMakerē±» ***************/ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { public: MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(Tensor) The input tensor of pooling operator. " + "(Tensor) The input tensor of maxout operator. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of feature."); AddOutput("Out", - "(Tensor) The output tensor of pooling operator." + "(Tensor) The output tensor of maxout operator." "The format of output tensor is also NCHW." "Where N is batch size, C is " "the number of channels, H and W is the height and " @@ -38,23 +37,53 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "groups", R"DOC(The group number of input layer. - )DOC") - .SetDefault(2); - AddAttr( - "num_channels", - R"DOC(The channel number of input layer. - )DOC") - .SetDefault(0); - AddComment(R"DOC(A layer to do max out on conv layer output. - - Input: output of a conv layer. + )DOC"); + AddComment(R"DOC( + - Input: NCHW. - Output: feature map size same as input. Channel is (input channel) / groups. So groups should be larger than 1, and the num of channels should be able to devided by groups. + + .. math:: + y_{si+j} = \max_k x_{gsi + sk + j} + g = groups + s = input.size / num_channels + 0 \le i < num_channels / groups + 0 \le j < s + 0 \le k < groups + + Please refer to Paper: + - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf + - Multi-digit Number Recognition from Street View \ + Imagery using Deep Convolutional Neural Networks: \ + https://arxiv.org/pdf/1312.6082v4.pdf + + The simple usage is: + + .. code-block:: python + + maxout = maxout_layer(input, + num_channels=128, + groups=4) + + :param input: The input of this layer. + :type input: LayerOutput + :param num_channels: The channel number of input layer. If None will be set + automatically from previous output. + :type num_channels: int | None + :param groups: The group number of input layer. + :type groups: int + :param name: The name of this layer. It is optional. + :type name: None | basestring. + :param layer_attr: Extra Layer attribute. + :type layer_attr: ExtraLayerAttribute + :return: LayerOutput object. + :rtype: LayerOutput + )DOC"); } }; -/******************2nd **********************************/ class MaxOutOp : public framework::OperatorWithKernel { public: @@ -67,20 +96,14 @@ class MaxOutOp : public framework::OperatorWithKernel { "Output(Out) of maxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); - int num_channels = ctx->Attrs().Get("num_channels"); // check groups > 1 PADDLE_ENFORCE_GT( groups, 1, "in maxoutop groups should be larger than 1"); - // check num_channels%groups=0 - PADDLE_ENFORCE_EQ(num_channels % groups, 0, - "the num of channels should be able" - "to devided by groups"); - int out_num_channels = num_channels / groups; - std::vector output_shape({in_x_dims[0], out_num_channels}); + std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); output_shape.push_back(in_x_dims[2]); output_shape.push_back(in_x_dims[3]); diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index 232161351..3f5897abd 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/maxouting.h" @@ -32,14 +31,13 @@ class MaxOutKernel : public framework::OpKernel { Tensor* out = context.Output("Out"); int groups = context.template Attr("groups"); - int num_channels = context.template Attr("num_channels"); paddle::operators::math::MaxOutFunctor< Place, paddle::operators::math::MaxOut, T> maxout_forward; paddle::operators::math::MaxOut maxout_process; - maxout_forward(context.device_context(), *in_x, *out, groups, num_channels, + maxout_forward(context.device_context(), *in_x, out, groups, maxout_process); } }; @@ -55,7 +53,6 @@ class MaxOutGradKernel : public framework::OpKernel { Tensor* in_x_grad = context.Output(framework::GradVarName("X")); int groups = context.template Attr("groups"); - int num_channels = context.template Attr("num_channels"); @@ -68,7 +65,7 @@ class MaxOutGradKernel : public framework::OpKernel { paddle::operators::math::MaxOutGradFunctor maxout_backward; maxout_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, groups, num_channels); + *out_grad, groups); } } }; diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py index 4ea1e3c29..406147ef2 100644 --- a/python/paddle/v2/framework/tests/test_maxout_op.py +++ b/python/paddle/v2/framework/tests/test_maxout_op.py @@ -3,22 +3,13 @@ import numpy as np from op_test import OpTest - -def maxout_forward_naive_2sweetsky(input, groups, num_channels): - s0, s1, s2, s3 = input.shape - return np.ndarray([s0, s1 / groups, groups, s2, s3], \ - buffer = input, dtype=input.dtype).max(axis=(2)) - - def maxout_forward_naive(input, groups,num_channels): s0, s1, s2, s3 = input.shape return np.ndarray([s0, s1 / groups, groups, s2, s3], \ buffer = input, dtype=input.dtype).max(axis=(2)) - - -class TestMaxOut_Op(OpTest): +class TestMaxOutOp(OpTest): def setUp(self): self.op_type = "maxout" self.init_test_case() @@ -37,7 +28,7 @@ class TestMaxOut_Op(OpTest): def test_check_grad(self): print self.inputs print self.outputs - self.check_grad(['X'], 'Out', max_relative_error=0.5) + self.check_grad(['X'], 'Out') def init_test_case(self): self.MaxOut_forward_naive = maxout_forward_naive -- GitLab