diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index d0fe5b4635174fa0f74658509c4e8ca58a1d056a..059a6bba84cfb0c1f6cbbba3c88d589b52dc5592 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -184,6 +184,7 @@ set(DEPS_OPS sequence_softmax_op sum_op pool_op + maxout_op pool_with_index_op conv_op conv_transpose_op @@ -210,6 +211,7 @@ op_library(sgd_op DEPS selected_rows_functor) op_library(adagrad_op DEPS selected_rows_functor) op_library(conv_op DEPS vol2col) op_library(pool_op DEPS pooling) +op_library(maxout_op DEPS maxouting) op_library(pool_with_index_op DEPS pooling) op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 002b68fecf4f1e294387357f0346d9926a2b2b5a..3017f133afc5d4dcd484c78b44591a876ab4d667 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -14,6 +14,7 @@ if(WITH_GPU) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -26,6 +27,7 @@ else() cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(maxouting SRCS 
maxouting.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc new file mode 100644 index 0000000000000000000000000000000000000000..e5168ce7afd4139475afa6edd5999b9974407c9b --- /dev/null +++ b/paddle/operators/math/maxouting.cc @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/maxouting.h" + +namespace paddle { +namespace operators { +namespace math { + +// All tensors are in NCHW format, and the groups must be greater than 1 +template +class MaxOutFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor * output, + int groups) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + int fea_size = input_height * input_width; + // c_size means the output size of each sample + int c_size = fea_size * output_channels; + const T* input_data = input.data(); + T* output_data = output->mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; ++i) { + int new_bindex = c_size * i; + for (int c = 0; c < output_channels; ++c) { + int new_cindex = fea_size * c; + for (int f = 0; f < fea_size; ++f) { + T ele = static_cast(-FLT_MAX); 
+ for (int ph = 0; ph < groups; ++ph) { + T x = input_data[(new_bindex + new_cindex) * groups + + ph * fea_size + f]; + ele = ele > x ? ele : x; + } + output_data[(new_bindex+new_cindex+f)] = ele; + } + } + } + } +}; + + + +template +class MaxOutGradFunctor { +public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor * input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, + int groups) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + int fea_size = input_height * input_width; + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; ++i) { + int blen = fea_size * output_channels * i; + for (int c = 0; c < output_channels; ++c) { + int clen = fea_size * c; + for (int f = 0; f < fea_size; ++f) { + int input_idx0 = (blen + clen) * groups + f; + bool continue_match = true; + int output_idx = blen + clen + f; + for (int g = 0; g < groups && continue_match; ++g) { + int input_idx = input_idx0 + fea_size * g; + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += output_grad_data[output_idx]; + continue_match = false; + } + } + } + } + } + } +}; + +template class MaxOutGradFunctor; +template class MaxOutGradFunctor; +template class MaxOutFunctor; +template class MaxOutFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu new file mode 100644 index 0000000000000000000000000000000000000000..7c698577b8a8258a58ba9a2b6c675457b2458a5b --- /dev/null +++ b/paddle/operators/math/maxouting.cu @@ -0,0 +1,154 @@ +/* Copyright (c) 2016 
PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/maxouting.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// Forward maxout kernel.
+//
+// Each loop iteration produces one output element i (0 <= i < nthreads) as
+// the maximum over `groups` input values that lie feat_len elements apart.
+// `channels` is the number of INPUT channels; `size` is therefore the number
+// of output elements per sample. Threads advance by blockDim.x * gridDim.x,
+// so any launch size covers all nthreads elements.
+template <typename T>
+__global__ void KernelMaxOut(const int nthreads, const T* input_data,
+                             const int channels, const int input_height,
+                             const int input_width, int groups,
+                             T* output_data) {
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
+    int channel_idx = batch_offset / feat_len;
+    int feat_idx = batch_offset % feat_len;
+    // Index of the first candidate of this output element's group.
+    int data_idx =
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+    // Seed with the first candidate rather than -FLT_MAX, which is not a
+    // valid lower bound for -inf or for doubles below -FLT_MAX.
+    T ele = input_data[data_idx];
+    for (int g = 1; g < groups; ++g) {
+      T x = input_data[data_idx + g * feat_len];
+      ele = ele > x ? ele : x;
+    }
+    output_data[i] = ele;
+  }
+}
+
+// Backward maxout kernel.
+//
+// For output element i, finds the first input of its group whose value
+// equals output_data[i] and adds output_grad[i] to the matching input_grad
+// slot. Groups of different output elements do not overlap, so the plain
+// (non-atomic) += does not collide across iterations. input_grad is only
+// accumulated into and must be zero-filled before the launch.
+template <typename T>
+__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
+                                 const T* output_data, const T* output_grad,
+                                 T* input_grad, const int channels,
+                                 const int input_height, const int input_width,
+                                 int groups) {
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
+    int channel_idx = batch_offset / feat_len;
+    int feat_idx = batch_offset % feat_len;
+    int data_idx =
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+    int max_index = -1;
+    for (int g = 0; g < groups; ++g) {
+      if (input_data[data_idx + g * feat_len] == output_data[i]) {
+        max_index = data_idx + g * feat_len;
+        break;
+      }
+    }
+    if (max_index != -1) {
+      // Use the loop variable i, not the thread's starting index: they only
+      // coincide on a thread's first iteration, so output_grad[index] would
+      // read the wrong gradient whenever a thread handles several elements.
+      input_grad[max_index] += output_grad[i];
+    }
+  }
+}
+/*
+ * All tensors are in NCHW format.
+ */
+// GPU forward pass of maxout: launches KernelMaxOut with one logical thread
+// per output element, in blocks of 1024 threads.
+// NOTE(review): `context` is reinterpreted as a CUDADeviceContext to obtain
+// the stream; callers are assumed to pass one for this specialization.
+template <typename T>
+class MaxOutFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor* output,
+                  int groups) {
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+
+    const T* input_data = input.data<T>();
+    T* output_data = output->mutable_data<T>(context.GetPlace());
+    // One logical thread per output element, rounded up to whole blocks.
+    int nthreads = output->numel();
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxOut<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(nthreads, input_data, input_channels,
+                              input_height, input_width, groups,
+                              output_data);
+  }
+};
+/*
+ * All tensors are in NCHW format.
+ *
+ * GPU backward pass of maxout: launches KernelMaxoutGrad with one logical
+ * thread per output element. The kernel only accumulates into input_grad, so
+ * the caller must zero-fill input_grad before invoking this functor.
+ */
+template <typename T>
+class MaxOutGradFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor* input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad, int groups) {
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
+    int nthreads = output.numel();
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxoutGrad<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(
+        nthreads, input_data, output_data, output_grad_data, input_grad_data,
+        input_channels, input_height, input_width, groups);
+  }
+};
+
+template class MaxOutGradFunctor<platform::GPUPlace, float>;
+template class MaxOutGradFunctor<platform::GPUPlace, double>;
+
+template class MaxOutFunctor<platform::GPUPlace, float>;
+template class MaxOutFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h
new file mode 100644
index 0000000000000000000000000000000000000000..d4c9da38ab8f8d88ed461d805ae64a015db968c4
--- /dev/null
+++ b/paddle/operators/math/maxouting.h
@@ -0,0 +1,47 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +#define FLT_MAX \ + __FLT_MAX__ + +template + +class MaxOutFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor * output, + int groups); +}; + +template +class MaxOutGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor * input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, int groups); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..95467f2e69093906980d075b6a41c5d2934dd5a2 --- /dev/null +++ b/paddle/operators/maxout_op.cc @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/operators/maxout_op.h" +namespace paddle { +namespace operators { + +using framework::Tensor; + +class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor) The input tensor of maxout operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of maxout operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); + AddAttr( + "groups", + R"DOC("Specifies how many groups the input tensor will be split" + "in the channel dimension. And the number of output channel is " + "the number of channels divided by groups.." + )DOC"); + AddComment(R"DOC( + Assumed the input shape is (N, Ci, H, W). + The output shape is (N, Co, H, W). Then `Co = Ci / groups`. 
+ + math: + y_{si+j} = \max_k x_{gsi + sk + j} + g = groups + s = input.size / num_channels + 0 \le i < num_channels / groups + 0 \le j < s + 0 \le k < groups + + Please refer to Paper: + - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf + - Multi-digit Number Recognition from Street View \ + Imagery using Deep Convolutional Neural Networks: \ + https://arxiv.org/pdf/1312.6082v4.pdf + )DOC"); + } +}; + + +class MaxOutOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MaxoutOp should not be null."); + auto in_x_dims = ctx->GetInputDim("X"); + int groups = ctx->Attrs().Get("groups"); + // check groups > 1 + PADDLE_ENFORCE_GT( + groups, 1, + "groups should be larger than 1 in maxoutop"); + std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); + output_shape.push_back(in_x_dims[2]); + output_shape.push_back(in_x_dims[3]); + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class MaxOutOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad, + ops::MaxOutOpGrad); +REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel); +REGISTER_OP_CPU_KERNEL(maxout_grad, + ops::MaxOutGradKernel); diff --git 
a/paddle/operators/maxout_op.cu.cc b/paddle/operators/maxout_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5823fba6848a0d42a743c90d7d683e3e4ae4422 --- /dev/null +++ b/paddle/operators/maxout_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/maxout_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(maxout, + ops::MaxOutKernel, + ops::MaxOutKernel); +REGISTER_OP_GPU_KERNEL(maxout_grad, + ops::MaxOutGradKernel, + ops::MaxOutGradKernel); diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c404cd16a9b2372ea4c6a17eb5ac82cf8f3bf27c --- /dev/null +++ b/paddle/operators/maxout_op.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/maxouting.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class MaxOutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + int groups = context.template Attr("groups"); + + math::MaxOutFunctor maxout_forward; + maxout_forward(context.device_context(), *in_x, out, groups); + } +}; + +template +class MaxOutGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + const Tensor* out = context.Input("Out"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + int groups = context.template Attr("groups"); + auto& device_ctx = context.device_context(); + math::SetConstant zero; + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + zero(device_ctx, in_x_grad, static_cast(0.0)); + math::MaxOutGradFunctor maxout_backward; + maxout_backward(context.device_context(), *in_x, in_x_grad, *out, + *out_grad, groups); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_maxout_op.py b/python/paddle/v2/fluid/tests/test_maxout_op.py new file mode 100644 index 0000000000000000000000000000000000000000..05e42f315833cab5bc5272cbd2173ea8012ff7f5 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_maxout_op.py @@ -0,0 +1,39 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def maxout_forward_naive(input, groups): + s0, s1, s2, s3 = input.shape + return np.ndarray([s0, s1 / groups, groups, s2, s3], \ + buffer = input, dtype=input.dtype).max(axis=(2)) 
+ + +class TestMaxOutOp(OpTest): + def setUp(self): + self.op_type = "maxout" + self.init_test_case() + input = np.random.random(self.shape).astype("float32") + output = self.MaxOut_forward_naive(input, self.groups).astype("float32") + + self.inputs = {'X': input} + self.attrs = {'groups': self.groups} + + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.MaxOut_forward_naive = maxout_forward_naive + self.shape = [100, 6, 2, 2] + self.groups=2 + + + + +if __name__ == '__main__': + unittest.main()