diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f01fa183913ef19fab9076a0d66330e17399de2c
--- /dev/null
+++ b/paddle/operators/math/maxouting.cc
@@ -0,0 +1,117 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/maxouting.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * All tensors are in NCHW format. The number of input channels must be
+ * divisible by groups; each output channel is the element-wise maximum
+ * over its group of input channels.
+ */
+template <typename MaxOutProcess, typename T>
+class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  int groups, int num_channels, MaxOutProcess maxout_process) {
+    const int batch_size = input.dims()[0];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = num_channels / groups;
+
+    int fea_size = input_height * input_width;
+    int c_size = fea_size * output_channels;
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; ++i) {
+      int new_bindex = c_size * i;
+      for (int c = 0; c < output_channels; ++c) {
+        int new_cindex = fea_size * c;
+        for (int f = 0; f < fea_size; ++f) {
+          // Reduce over the `groups` input channels mapped to output
+          // channel `c`.
+          T ele = maxout_process.initial();
+          for (int ph = 0; ph < groups; ++ph) {
+            maxout_process.compute(
+                ele, input_data[(new_bindex + new_cindex) * groups +
+                                ph * fea_size + f]);
+          }
+          maxout_process.finalize(ele, static_cast<T>(groups));
+          output_data[new_bindex + new_cindex + f] = ele;
+        }
+      }
+    }
+  }
+};
+
+template <typename T>
+class MaxOutGradFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  int groups, int num_channels) {
+    const int batch_size = input.dims()[0];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = num_channels / groups;
+
+    int fea_size = input_height * input_width;
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; ++i) {
+      int blen = fea_size * output_channels * i;
+      for (int c = 0; c < output_channels; ++c) {
+        int clen = fea_size * c;
+        for (int f = 0; f < fea_size; ++f) {
+          int input_idx = 0;
+          bool stop = false;
+          int output_idx = blen + clen + f;
+          // Only the first input element equal to the group maximum
+          // receives the gradient.
+          for (int g = 0; g < groups && !stop; ++g) {
+            input_idx = (blen + clen) * groups + fea_size * g + f;
+            input_grad_data[input_idx] = 0;
+            if (input_data[input_idx] == output_data[output_idx]) {
+              input_grad_data[input_idx] += output_grad_data[output_idx];
+              stop = true;
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+template class MaxOutGradFunctor<platform::CPUPlace, float>;
+template class MaxOutGradFunctor<platform::CPUPlace, double>;
+template class MaxOutFunctor<platform::CPUPlace, math::MaxOut<float>, float>;
+template class MaxOutFunctor<platform::CPUPlace, math::MaxOut<double>, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b1c0dd8fd4b4664c96ad7fe0ae9b44b7609bf0a8
--- /dev/null
+++ b/paddle/operators/math/maxouting.cu
@@ -0,0 +1,161 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/maxouting.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename MaxOutProcess, typename T>
+__global__ void KernelMaxOut(const int nthreads, const T* input_data,
+                             T* output_data, const int channels,
+                             const int input_height, const int input_width,
+                             int groups, MaxOutProcess maxout_process) {
+  // Each thread computes one output element; `index` enumerates the output
+  // tensor in NCHW order.
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
+    int batch_idx = index / size;
+    int i = index % size;
+    int channel_idx = i / feat_len;
+    int feat_idx = i % feat_len;
+    int data_idx =
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+    T ele = maxout_process.initial();
+    for (int g = 0; g < groups; ++g) {
+      maxout_process.compute(ele, input_data[data_idx + g * feat_len]);
+    }
+    maxout_process.finalize(ele, static_cast<T>(groups));
+    output_data[index] = ele;
+  }
+}
+
+template <typename T>
+__global__ void KernelMaxoutGrad(
+    const int nthreads, const T* input_data, const T* output_data,
+    const T* output_grad, T* input_grad, const int channels,
+    const int input_height, const int input_width, int groups) {
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
+    int batch_idx = index / size;
+    int i = index % size;
+    int channel_idx = i / feat_len;
+    int feat_idx = i % feat_len;
+    int data_idx =
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+    int max_index = -1;
+    bool stop = false;
+    for (int g = 0; g < groups && !stop; ++g) {
+      if (input_data[data_idx + g * feat_len] == output_data[index]) {
+        max_index = data_idx + g * feat_len;
+        stop = true;
+      }
+    }
+    if (max_index != -1) {
+      // atomic add
+      platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]);
+    }
+  }
+}
+
+/*
+ * All tensors are in NCHW format. The number of input channels must be
+ * divisible by groups.
+ */
+template <typename MaxOutProcess, typename T>
+class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  int groups, int num_channels,
+                  MaxOutProcess maxout_process) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = num_channels / groups;
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_height * output_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxOut<
+        MaxOutProcess,
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(nthreads, input_data, output_data,
+                              input_channels, input_height, input_width,
+                              groups, maxout_process);
+  }
+};
+
+/*
+ * All tensors are in NCHW format. The number of input channels must be
+ * divisible by groups.
+ */
+template <typename T>
+class MaxOutGradFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  int groups, int num_channels) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output.dims()[1];
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_height * output_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxoutGrad<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(nthreads, input_data, output_data,
+                              output_grad_data, input_grad_data,
+                              input_channels, input_height, input_width,
+                              groups);
+  }
+};
+
+template class MaxOutGradFunctor<platform::GPUPlace, float>;
+template class MaxOutGradFunctor<platform::GPUPlace, double>;
+
+template class MaxOutFunctor<platform::GPUPlace, math::MaxOut<float>, float>;
+template class MaxOutFunctor<platform::GPUPlace, math::MaxOut<double>, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h
new file mode 100644
index 0000000000000000000000000000000000000000..aeac084944d6afa49b3186902dc0271430e438e9
--- /dev/null
+++ b/paddle/operators/math/maxouting.h
@@ -0,0 +1,99 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+#define FLT_MAX \
+  __FLT_MAX__  // TODO: this definition might belong in a shared header; it is
+               // kept here for now.
+
+/*
+ * \brief The maxout computation is expressed through three simple steps:
+ * "initial", "compute" and "finalize".
+ * MaxOut initializes the temporary value to the negative maximum so that the
+ * maximum over each group can be accumulated element by element; "finalize"
+ * is a no-op for maxout.
+ * MaxOutGrad is the corresponding gradient operation.
+ */
+template <typename T>
+class MaxOut {
+ public:
+  DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
+  DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
+  DEVICE inline void finalize(T& y, const T& group) {}
+};
+
+template <typename T>
+class MaxOutGrad {
+ public:
+  DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
+                             T scale) {
+    // The gradient flows only to the input element that equals the output.
+    dx += dy * (x == y);
+  }
+};
+
+/*
+ * \brief Computing the maxout result and its gradient.
+ *
+ * All tensors are in NCHW format, where N is the batch size, C is the number
+ * of channels, and H and W are the height and width of the feature map.
+ *
+ * If a group contains several elements equal to the maximum, only the first
+ * one receives the gradient, as in max pooling.
+ */
+template <typename Place, typename MaxOutProcess, typename T>
+class MaxOutFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  int groups, int num_channels, MaxOutProcess maxout_compute);
+};
+
+template <typename Place, typename T>
+class MaxOutGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad, int groups,
+                  int num_channels);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..41b3860a861de248fbd4224d308ea175b5913649
--- /dev/null
+++ b/paddle/operators/maxout_op.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include "paddle/operators/maxout_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  MaxOutOpMaker(framework::OpProto* proto,
+                framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(Tensor) The input tensor of maxout operator. "
+             "The format of input tensor is NCHW, where N is batch size, C is "
+             "the number of channels, and H and W are the height and width of "
+             "the feature map.");
+    AddOutput("Out",
+              "(Tensor) The output tensor of maxout operator. "
+              "The format of output tensor is also NCHW, where N is batch "
+              "size, C is the number of channels, and H and W are the height "
+              "and width of the feature map.");
+    AddAttr<int>("groups",
+                 R"DOC(The group number of the input layer.
+        )DOC")
+        .SetDefault(2);
+    AddAttr<int>("num_channels",
+                 R"DOC(The channel number of the input layer.
+        )DOC")
+        .SetDefault(0);
+    AddComment(R"DOC(A layer that computes maxout on the output of a
+        convolution layer.
+        - Input: the output of a convolution layer.
+        - Output: feature map of the same spatial size as the input, with
+          (input channels) / groups channels.
+        groups must be larger than 1, and the number of channels must be
+        divisible by groups.
+        )DOC");
+  }
+};
+
+class MaxOutOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of MaxoutOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of MaxoutOp should not be null.");
+    auto in_x_dims = ctx->GetInputDim("X");
+    int groups = ctx->Attrs().Get<int>("groups");
+    int num_channels = ctx->Attrs().Get<int>("num_channels");
+
+    // groups must be larger than 1.
+    PADDLE_ENFORCE_GT(groups, 1,
+                      "In maxout op, groups should be larger than 1.");
+    // num_channels must be divisible by groups.
+    PADDLE_ENFORCE_EQ(num_channels % groups, 0,
+                      "The number of channels should be divisible by groups.");
+
+    int out_num_channels = num_channels / groups;
+
+    std::vector<int64_t> output_shape({in_x_dims[0], out_num_channels});
+    output_shape.push_back(in_x_dims[2]);
+    output_shape.push_back(in_x_dims[3]);
+
+    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  }
+};
+
+class MaxOutOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should not be null.");
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
+            ops::MaxOutOpGrad);
+
+REGISTER_OP_CPU_KERNEL(maxout,
+                       ops::MaxOutKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    maxout_grad, ops::MaxOutGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/maxout_op.cu b/paddle/operators/maxout_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..44a149b065d890af75997138fe602b1496b6527a
--- /dev/null
+++ b/paddle/operators/maxout_op.cu
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/maxout_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(maxout,
+                       ops::MaxOutKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    maxout_grad, ops::MaxOutGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..23216135129e9f42014d97a855030f927ead3716
--- /dev/null
+++ b/paddle/operators/maxout_op.h
@@ -0,0 +1,77 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/maxouting.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class MaxOutKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* in_x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    int groups = context.template Attr<int>("groups");
+    int num_channels = context.template Attr<int>("num_channels");
+
+    paddle::operators::math::MaxOutFunctor<
+        Place, paddle::operators::math::MaxOut<T>, T>
+        maxout_forward;
+    paddle::operators::math::MaxOut<T> maxout_process;
+    maxout_forward(context.device_context(), *in_x, *out, groups, num_channels,
+                   maxout_process);
+  }
+};
+
+template <typename Place, typename T>
+class MaxOutGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* in_x = context.Input<Tensor>("X");
+    const Tensor* out = context.Input<Tensor>("Out");
+    const Tensor* out_grad =
+        context.Input<Tensor>(framework::GradVarName("Out"));
+    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+    int groups = context.template Attr<int>("groups");
+    int num_channels = context.template Attr<int>("num_channels");
+
+    if (in_x_grad) {
+      in_x_grad->mutable_data<T>(context.GetPlace());
+      // Zero the input gradient before scattering the output gradient back.
+      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
+      temp.device(context.GetEigenDevice<Place>()) =
+          temp.constant(static_cast<T>(0));
+
+      paddle::operators::math::MaxOutGradFunctor<Place, T> maxout_backward;
+      maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
+                      *out_grad, groups, num_channels);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ea1e3c29c643ed0e6bdc00c8e41839d19794581
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_maxout_op.py
@@ -0,0 +1,52 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def maxout_forward_naive(input, groups, num_channels):
+    # Reshape to [N, C / groups, groups, H, W] and take the maximum over the
+    # group axis.
+    s0, s1, s2, s3 = input.shape
+    return np.ndarray([s0, s1 // groups, groups, s2, s3],
+                      buffer=input, dtype=input.dtype).max(axis=2)
+
+
+class TestMaxOutOp(OpTest):
+    def setUp(self):
+        self.op_type = "maxout"
+        self.init_test_case()
+        input = np.random.random(self.shape).astype("float32")
+        output = self.MaxOut_forward_naive(input, self.groups,
+                                           self.num_channels).astype("float32")
+
+        self.inputs = {'X': input}
+        self.attrs = {'groups': self.groups, 'num_channels': self.num_channels}
+        self.outputs = {'Out': output.astype('float32')}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.5)
+
+    def init_test_case(self):
+        self.MaxOut_forward_naive = maxout_forward_naive
+        self.shape = [100, 6, 2, 2]
+        self.groups = 2
+        self.num_channels = 6
+
+
+if __name__ == '__main__':
+    unittest.main()
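
For reference, below is a small standalone numpy sketch (not part of the patch) of what the maxout op computes: with groups=2, consecutive input channels 0 and 1 form output channel 0, channels 2 and 3 form output channel 1, and each output element is the element-wise maximum over its group. The helper name maxout_naive is hypothetical and used only for this illustration; it mirrors the test's maxout_forward_naive.

    import numpy as np

    def maxout_naive(x, groups):
        # x has NCHW layout; C must be divisible by groups.
        n, c, h, w = x.shape
        return x.reshape(n, c // groups, groups, h, w).max(axis=2)

    x = np.arange(8, dtype=np.float32).reshape(1, 4, 1, 2)
    # Channels 0 and 1 form output channel 0; channels 2 and 3 form channel 1.
    print(maxout_naive(x, groups=2))  # shape (1, 2, 1, 2): [[[[2. 3.]] [[6. 7.]]]]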