diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc
index b733af7410f5e40179a387cb056ffe41ab9e133d..baaa86ffced70743445f939a1d6a68d947dcd334 100644
--- a/paddle/operators/math/maxouting.cc
+++ b/paddle/operators/math/maxouting.cc
@@ -22,23 +22,20 @@ namespace math {
  * All tensors are in NCHW format.
  * groups mustbe > 1
  */
-template <typename MaxOutProcess, typename T>
-class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
+template <typename T>
+class MaxOutFunctor<platform::CPUPlace, T> {
  public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor * output,
-                 int groups,
-                 MaxOutProcess maxout_process) {
+                 int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
-
    int fea_size = input_height * input_width;
-    // c_size mean output one batch size
+    // c_size means the output size of each sample
    int c_size = fea_size * output_channels;
-
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
 
@@ -47,10 +44,11 @@ class MaxOutFunctor<platform::CPUPlace, T> {
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
-          T ele = maxout_process.initial();
+          // T ele = maxout_process.initial();
+          T ele = static_cast<T>(-FLT_MAX);
          for (int ph = 0; ph < groups; ++ph) {
-            maxout_process.compute(ele,
-                input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]);
+            T x = input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f];
+            ele = ele > x ? ele : x;
          }
          output_data[(new_bindex+new_cindex+f)] = ele;
        }
@@ -74,9 +72,7 @@ public:
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
-
    int fea_size = input_height * input_width;
-
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
@@ -87,15 +83,15 @@ public:
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
-          int input_idx = 0;
-          bool stop = false;
+          int input_idx0 = (blen + clen) * groups + f;
+          bool continue_match = true;
          int output_idx = blen + clen + f;
-          for (int g = 0; g < groups && !stop; ++g) {
-            input_idx = (blen + clen) * groups + fea_size * g + f;
+          for (int g = 0; g < groups && continue_match; ++g) {
+            int input_idx = input_idx0 + fea_size * g;
            input_grad_data[input_idx] = 0;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
-              stop = true;
+              continue_match = false;
            }
          }
        }
@@ -106,10 +102,8 @@ public:
 
 template class MaxOutGradFunctor<platform::CPUPlace, float>;
 template class MaxOutGradFunctor<platform::CPUPlace, double>;
-template class MaxOutFunctor<platform::CPUPlace, math::MaxOut<float>, float>;
-template class MaxOutFunctor<platform::CPUPlace, math::MaxOut<double>, double>;
+template class MaxOutFunctor<platform::CPUPlace, float>;
+template class MaxOutFunctor<platform::CPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu
index c2da29e35685f7049ca314d4dd03f2019bb5f409..1a8fc465cc3ffef6ebf204384536290faa80a1de 100644
--- a/paddle/operators/math/maxouting.cu
+++ b/paddle/operators/math/maxouting.cu
@@ -19,27 +19,28 @@ namespace paddle {
 namespace operators {
 namespace math {
-template <typename MaxOutProcess, typename T>
+template <typename T>
 __global__ void KernelMaxOut(const int nthreads, const T* input_data,
                              const int channels,
                              const int input_height, const int input_width,
-                             int groups, T* output_data,
-                             MaxOutProcess maxout_process) {
+                             int groups, T* output_data) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int batch_idx = index / size;
-    int batch_offset = index % size;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
     int channel_idx = batch_offset / feat_len;
     int feat_idx = batch_offset % feat_len;
     int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
-    T ele = maxout_process.initial();
+    T ele = static_cast<T>(-FLT_MAX);
     for (int g = 0; g < groups; ++g) {
-      maxout_process.compute(ele, input_data[data_idx + g * feat_len]);
+      T x = input_data[data_idx + g * feat_len];
+      ele = ele > x ? ele : x;
     }
-    output_data[index] = ele;
+    output_data[i] = ele;
   }
 }
 template <typename T>
@@ -49,38 +50,38 @@ __global__ void KernelMaxoutGrad(
     const int input_height, const int input_width, int groups) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
-       index += blockDim.x * gridDim.x) {
-    int batch_idx = index / size;
-    int batch_offset = index % size;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
     int channel_idx = batch_offset / feat_len;
     int feat_idx = batch_offset % feat_len;
     int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
-    int maxIndex = -1;
-    bool stop = false;
-    for (int g = 0; g < groups && !stop; ++g) {
-      if (input_data[data_idx + g * feat_len] == output_data[index]) {
-        maxIndex = data_idx + g * feat_len;
-        stop = true;
+    int max_index = -1;
+    bool continue_match = true;
+    for (int g = 0; g < groups && continue_match; ++g) {
+      if (input_data[data_idx + g * feat_len] == output_data[i]) {
+        max_index = data_idx + g * feat_len;
+        continue_match = false;
       }
     }
-    if (maxIndex != -1) {
+    if (max_index != -1) {
       // atomic add
-      platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
+      platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]);
     }
   }
 }
 /*
  * All tensors are in NCHW format.
 */
-template <typename MaxOutProcess, typename T>
-class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
+template <typename T>
+class MaxOutFunctor<platform::GPUPlace, T> {
  public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor * output,
-                 int groups,
-                 MaxOutProcess maxout_process) {
+                 int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
@@ -97,12 +98,11 @@ class MaxOutFunctor<platform::GPUPlace, T> {
    dim3 grid(blocks, 1);
 
    KernelMaxOut<
-       MaxOutProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, input_channels,
                              input_height, input_width, groups,
-                             output_data, maxout_process);
+                             output_data);
  }
 };
 /*
@@ -145,10 +145,8 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
 
 template class MaxOutGradFunctor<platform::GPUPlace, float>;
 template class MaxOutGradFunctor<platform::GPUPlace, double>;
-template class MaxOutFunctor<platform::GPUPlace, math::MaxOut<float>, float>;
-template class MaxOutFunctor<platform::GPUPlace, math::MaxOut<double>, double>;
+template class MaxOutFunctor<platform::GPUPlace, float>;
+template class MaxOutFunctor<platform::GPUPlace, double>;
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h
index a8e91a25b542cd55a94843983207966f70d457a4..72f40d96f79849d9592103556529b8935ad0d72d 100644
--- a/paddle/operators/math/maxouting.h
+++ b/paddle/operators/math/maxouting.h
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/hostdevice.h"
@@ -22,42 +21,18 @@ namespace paddle {
 namespace operators {
 namespace math {
-
 #define FLT_MAX \
        __FLT_MAX__
-/*
- * \brief Extracting simple operations from maxout.
- *        need "initial", "compute"
- *        operation.
- */
-template <class T>
-class MaxOut {
- public:
-  DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
-  DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
-};
-
-template <class T>
-class MaxOutGrad {
- public:
-  DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
-                             T scale) {
-    dx += dy * (x == y);
-  }
-};
-
-
-template <typename Place, typename MaxOutProcess, typename T>
+template <typename Place, typename T>
 class MaxOutFunctor {
  public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor * output,
-                 int groups, MaxOutProcess maxout_compute);
+                 int groups);
 };
-
 template <typename Place, class T>
 class MaxOutGradFunctor {
  public:
@@ -67,13 +42,6 @@ class MaxOutGradFunctor {
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups);
 };
-
-
-
-
-
-
-
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc
index c54a7069799b9dfad2d2267cb256522843649aac..f9277518cc49c31861232b1eea80ae9b2fc80be3 100644
--- a/paddle/operators/maxout_op.cc
+++ b/paddle/operators/maxout_op.cc
@@ -12,7 +12,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
-
 #include "paddle/operators/maxout_op.h"
 namespace paddle {
 namespace operators {
@@ -33,18 +32,18 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
             "Where N is batch size, C is "
             "the number of channels, H and W is the height and "
             "width of feature.");
-
    AddAttr<int>(
        "groups",
        R"DOC(The group number of input layer.
        )DOC");
    AddComment(R"DOC(
        - Input: NCHW.
-        - Output: feature map size same as input. Channel is (input channel) / groups.
+        - Output: The feature map size of output is the same as the input.
+          The output_channel is (input channel) / groups
        So groups should be larger than 1, and the num of channels should be able
-        to devided by groups.
+        to be devided by groups.
 
-        .. math::
+        math:
        y_{si+j} = \max_k x_{gsi + sk + j}
        g = groups
        s = input.size / num_channels
@@ -57,29 +56,6 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
        - Multi-digit Number Recognition from Street View \
          Imagery using Deep Convolutional Neural Networks: \
          https://arxiv.org/pdf/1312.6082v4.pdf
-
-        The simple usage is:
-
-        .. code-block:: python
-
-          maxout = maxout_layer(input,
-                                num_channels=128,
-                                groups=4)
-
-        :param input: The input of this layer.
-        :type input: LayerOutput
-        :param num_channels: The channel number of input layer. If None will be set
-                             automatically from previous output.
-        :type num_channels: int | None
-        :param groups: The group number of input layer.
-        :type groups: int
-        :param name: The name of this layer. It is optional.
-        :type name: None | basestring.
-        :param layer_attr: Extra Layer attribute.
-        :type layer_attr: ExtraLayerAttribute
-        :return: LayerOutput object.
-        :rtype: LayerOutput
-
    )DOC");
  }
 };
@@ -88,7 +64,6 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 class MaxOutOp : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
-
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp"
                   "should not be null.");
@@ -96,26 +71,20 @@ class MaxOutOp : public framework::OperatorWithKernel {
                   "Output(Out) of maxoutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
-    // check groups > 1
    PADDLE_ENFORCE_GT(
        groups, 1,
-        "in maxoutop groups should be larger than 1");
-
-
+        "groups should be larger than 1 in maxoutop");
    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
    output_shape.push_back(in_x_dims[2]);
    output_shape.push_back(in_x_dims[3]);
-
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
 };
-
 class MaxOutOpGrad : public framework::OperatorWithKernel {
  public:
  using framework::OperatorWithKernel::OperatorWithKernel;
-
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
@@ -129,8 +98,6 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
            ops::MaxOutOpGrad);
-
-
 REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::CPUPlace,
                                                  float>);
 REGISTER_OP_CPU_KERNEL(maxout_grad,
diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h
index aab878af0facd474bd553b913b2378b51615601d..6c769838c35024c6d0a13c722c226dece5150897 100644
--- a/paddle/operators/maxout_op.h
+++ b/paddle/operators/maxout_op.h
@@ -29,16 +29,12 @@ class MaxOutKernel : public framework::OpKernel<T> {
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");
-
    int groups = context.template Attr<int>("groups");
-
    paddle::operators::math::MaxOutFunctor<
-        Place, paddle::operators::math::MaxOut<T>, T>
+        Place, T>
        maxout_forward;
-    paddle::operators::math::MaxOut<T> maxout_process;
-    maxout_forward(context.device_context(), *in_x, out, groups,
-                   maxout_process);
+    maxout_forward(context.device_context(), *in_x, out, groups);
  }
 };
 
@@ -51,15 +47,12 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
-
    int groups = context.template Attr<int>("groups");
-
    auto& device_ctx = context.device_context();
    math::SetConstant<Place, T> zero;
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
-
      paddle::operators::math::MaxOutGradFunctor<Place, T>
          maxout_backward;
      maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
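
Note (not part of the patch): as a reference for the change above, the following is a minimal standalone C++ sketch of the max-over-groups computation that the refactored functors now perform inline after dropping the MaxOutProcess template argument. The function name maxout_cpu_reference and the flat std::vector buffers are illustrative assumptions only, not PaddlePaddle APIs.

#include <cfloat>
#include <vector>

// Reference maxout on an NCHW buffer: each output channel is the element-wise
// max over `groups` consecutive input channels, mirroring the loop the patch
// inlines into MaxOutFunctor.
void maxout_cpu_reference(const std::vector<float>& input,
                          std::vector<float>& output, int batch_size,
                          int in_channels, int height, int width, int groups) {
  const int fea_size = height * width;
  const int out_channels = in_channels / groups;
  const int c_size = fea_size * out_channels;
  for (int i = 0; i < batch_size; ++i) {
    const int new_bindex = c_size * i;
    for (int c = 0; c < out_channels; ++c) {
      const int new_cindex = fea_size * c;
      for (int f = 0; f < fea_size; ++f) {
        float ele = -FLT_MAX;
        for (int ph = 0; ph < groups; ++ph) {
          // Same flattened index as in the patch:
          // (batch offset + output-channel offset) * groups + group offset.
          float x =
              input[(new_bindex + new_cindex) * groups + ph * fea_size + f];
          ele = ele > x ? ele : x;
        }
        output[new_bindex + new_cindex + f] = ele;
      }
    }
  }
}

For example, with groups = 2 and a 1x4x1x1 input {1, 5, 2, 3}, this sketch produces a 1x2x1x1 output {5, 3}, i.e. the max over each pair of consecutive channels, which is the behavior the CPU and CUDA kernels above implement.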