From 8d1e9f0f7e593cf1a5a6ec7ef5c5d9d5ba5f160f Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Thu, 31 Oct 2019 14:34:24 +0800 Subject: [PATCH] maxout supports channel_last input (#20846) * maxout support channel_last input, test=develop * modified details of Input(X) and Attr(groups, axis) in doc, test=develop --- paddle/fluid/operators/math/maxouting.cc | 50 ++++++++---- paddle/fluid/operators/math/maxouting.cu | 79 +++++++++++-------- paddle/fluid/operators/math/maxouting.h | 6 +- paddle/fluid/operators/maxout_op.cc | 52 ++++++------ paddle/fluid/operators/maxout_op.h | 7 +- python/paddle/fluid/layers/nn.py | 20 +++-- .../fluid/tests/unittests/test_maxout_op.py | 54 ++++++++++++- 7 files changed, 183 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/operators/math/maxouting.cc b/paddle/fluid/operators/math/maxouting.cc index 730f71e96b6..45556e97d1d 100644 --- a/paddle/fluid/operators/math/maxouting.cc +++ b/paddle/fluid/operators/math/maxouting.cc @@ -18,35 +18,45 @@ namespace paddle { namespace operators { namespace math { -// All tensors are in NCHW format, and the groups must be greater than 1 +// All tensors are in NCHW or NHWC format, and the groups must be greater than 1 template class MaxOutFunctor { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* output, - int groups) { + const int groups, const int axis) { const int batch_size = input.dims()[0]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output->dims()[1]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]); + const int output_channels = output->dims()[axis]; int fea_size = input_height * input_width; // c_size means the output size of each sample int c_size = fea_size * output_channels; const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); - for (int i = 0; i < batch_size; ++i) { int new_bindex = c_size * i; for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; for (int f = 0; f < fea_size; ++f) { T ele = static_cast(-FLT_MAX); + int input_idx, output_idx; for (int ph = 0; ph < groups; ++ph) { - T x = input_data[(new_bindex + new_cindex) * groups + - ph * fea_size + f]; + if (axis == 1) { + input_idx = + (new_bindex + new_cindex) * groups + ph * fea_size + f; + } else { + input_idx = (new_bindex + f * output_channels + c) * groups + ph; + } + T x = input_data[input_idx]; ele = ele > x ? ele : x; } - output_data[(new_bindex + new_cindex + f)] = ele; + if (axis == 1) { + output_idx = new_bindex + new_cindex + f; + } else { + output_idx = new_bindex + f * output_channels + c; + } + output_data[output_idx] = ele; } } } @@ -59,11 +69,12 @@ class MaxOutGradFunctor { void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups) { + const framework::Tensor& output_grad, const int groups, + const int axis) { const int batch_size = input.dims()[0]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]); + const int output_channels = output.dims()[axis]; int fea_size = input_height * input_width; const T* input_data = input.data(); const T* output_data = output.data(); @@ -75,11 +86,18 @@ class MaxOutGradFunctor { for (int c = 0; c < output_channels; ++c) { int clen = fea_size * c; for (int f = 0; f < fea_size; ++f) { - int input_idx0 = (blen + clen) * groups + f; + int input_idx0, output_idx; bool continue_match = true; - int output_idx = blen + clen + f; + if (axis == 1) { + input_idx0 = (blen + clen) * groups + f; + output_idx = blen + clen + f; + } else { + input_idx0 = (blen + f * output_channels + c) * groups; + output_idx = blen + f * output_channels + c; + } for (int g = 0; g < groups && continue_match; ++g) { - int input_idx = input_idx0 + fea_size * g; + int idx_offset = (axis == 1 ? fea_size * g : g); + int input_idx = input_idx0 + idx_offset; if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; continue_match = false; diff --git a/paddle/fluid/operators/math/maxouting.cu b/paddle/fluid/operators/math/maxouting.cu index d9a23299a4d..8b134a29d81 100644 --- a/paddle/fluid/operators/math/maxouting.cu +++ b/paddle/fluid/operators/math/maxouting.cu @@ -22,8 +22,8 @@ namespace math { template __global__ void KernelMaxOut(const int nthreads, const T* input_data, const int channels, const int input_height, - const int input_width, int groups, - T* output_data) { + const int input_width, const int groups, + const int axis, T* output_data) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -31,13 +31,22 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, for (int i = index; i < nthreads; i += offset) { int batch_idx = i / size; int batch_offset = i % size; - int channel_idx = batch_offset / feat_len; - int feat_idx = batch_offset % feat_len; - int data_idx = - (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + int channel_idx, feat_idx, data_idx; + if (axis == 1) { + channel_idx = batch_offset / feat_len; + feat_idx = batch_offset % feat_len; + data_idx = + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + } else { + channel_idx = batch_offset % channels; + feat_idx = batch_offset / channels; + data_idx = + (batch_idx * size + feat_idx * channels + channel_idx) * groups; + } T ele = static_cast(-FLT_MAX); for (int g = 0; g < groups; ++g) { - T x = input_data[data_idx + g * feat_len]; + int idx_offset = (axis == 1 ? g * feat_len : g); + T x = input_data[data_idx + idx_offset]; ele = ele > x ? ele : x; } output_data[i] = ele; @@ -48,7 +57,7 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, const T* output_data, const T* output_grad, T* input_grad, const int channels, const int input_height, const int input_width, - int groups) { + const int groups, const int axis) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -56,15 +65,24 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, for (int i = index; i < nthreads; i += offset) { int batch_idx = i / size; int batch_offset = i % size; - int channel_idx = batch_offset / feat_len; - int feat_idx = batch_offset % feat_len; - int data_idx = - (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + int channel_idx, feat_idx, data_idx; + if (axis == 1) { + channel_idx = batch_offset / feat_len; + feat_idx = batch_offset % feat_len; + data_idx = + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; + } else { + channel_idx = batch_offset % channels; + feat_idx = batch_offset / channels; + data_idx = + (batch_idx * size + feat_idx * channels + channel_idx) * groups; + } int max_index = -1; bool continue_match = true; for (int g = 0; g < groups && continue_match; ++g) { - if (input_data[data_idx + g * feat_len] == output_data[i]) { - max_index = data_idx + g * feat_len; + int idx_offset = (axis == 1 ? g * feat_len : g); + if (input_data[data_idx + idx_offset] == output_data[i]) { + max_index = data_idx + idx_offset; continue_match = false; break; } @@ -75,21 +93,19 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data, } } /* - * All tensors are in NCHW format. + * All tensors are in NCHW or NHWC format. */ template class MaxOutFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* output, - int groups) { + const int groups, const int axis) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output->dims()[1]; - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; + const int input_channels = input.dims()[axis]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]); + const int output_channels = output->dims()[axis]; const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); @@ -100,11 +116,11 @@ class MaxOutFunctor { KernelMaxOut<<>>( nthreads, input_data, input_channels, input_height, input_width, groups, - output_data); + axis, output_data); } }; /* - * All tensors are in NCHW format. + * All tensors are in NCHW or NHWC format. */ template class MaxOutGradFunctor { @@ -112,14 +128,13 @@ class MaxOutGradFunctor { void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups) { + const framework::Tensor& output_grad, const int groups, + const int axis) { const int batch_size = input.dims()[0]; - const int input_channels = input.dims()[1]; - const int input_height = input.dims()[2]; - const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int input_channels = input.dims()[axis]; + const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]); + const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]); + const int output_channels = output.dims()[axis]; const T* input_data = input.data(); const T* output_data = output.data(); @@ -132,7 +147,7 @@ class MaxOutGradFunctor { KernelMaxoutGrad<<>>( nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, groups); + input_channels, input_height, input_width, groups, axis); } }; diff --git a/paddle/fluid/operators/math/maxouting.h b/paddle/fluid/operators/math/maxouting.h index e4d378dc232..50bddf73bc1 100644 --- a/paddle/fluid/operators/math/maxouting.h +++ b/paddle/fluid/operators/math/maxouting.h @@ -26,7 +26,8 @@ template class MaxOutFunctor { public: void operator()(const DeviceContext& context, const framework::Tensor& input, - framework::Tensor* output, int groups); + framework::Tensor* output, const int groups, + const int axis = 1); }; template @@ -35,7 +36,8 @@ class MaxOutGradFunctor { void operator()(const DeviceContext& context, const framework::Tensor& input, framework::Tensor* input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups); + const framework::Tensor& output_grad, const int groups, + const int axis = 1); }; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/maxout_op.cc b/paddle/fluid/operators/maxout_op.cc index e051db8b89d..85323b69449 100644 --- a/paddle/fluid/operators/maxout_op.cc +++ b/paddle/fluid/operators/maxout_op.cc @@ -23,25 +23,27 @@ using framework::Tensor; class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput( - "X", - "(Tensor) The input tensor of maxout operator with data type of " - "float32. The format of input tensor is NCHW. Where N is batch size," - " C is the number of channels, H and W is the height and width of " - "feature."); + AddInput("X", + "A 4-D Tensor with data type of float32 or float64. " + "The data format is NCHW or NHWC. Where N is " + "batch size, C is the number of channels, " + "H and W is the height and width of " + "feature. "); AddOutput("Out", - "(Tensor) The output tensor of maxout operator." - "The data type is float32." - "The format of output tensor is also NCHW." - "Where N is batch size, C is " - "the number of channels, H and W is the height and " - "width of feature."); + "A 4-D Tensor with same data type and data format " + "with input Tensor. "); AddAttr( "groups", - "(int)," - "Specifies how many groups the input tensor will be split" - "in the channel dimension. And the number of output channel is " - "the number of channels divided by groups."); + "Specifies how many groups the input tensor will be split into " + "at the channel dimension. And the number of output channel is " + "the number of channels divided by groups. "); + AddAttr( + "axis", + "Specifies the index of channel dimension where maxout will " + "be performed. It should be 1 when data format is NCHW, -1 or 3 " + "when data format is NHWC. " + "Default: 1. ") + .SetDefault(1); AddComment(R"DOC( MaxOut Operator. @@ -70,17 +72,19 @@ class MaxOutOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of MaxoutOpshould not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of MaxoutOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of MaxoutOpshould not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of MaxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); + int axis = ctx->Attrs().Get("axis"); // check groups > 1 - PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop"); - std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); - output_shape.push_back(in_x_dims[2]); - output_shape.push_back(in_x_dims[3]); + PADDLE_ENFORCE_GT(groups, 1, + "Attr(groups) of Op(maxout) should be larger than 1."); + std::vector output_shape( + {in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]}); + output_shape[axis] = in_x_dims[axis] / groups; ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } }; diff --git a/paddle/fluid/operators/maxout_op.h b/paddle/fluid/operators/maxout_op.h index 5b9e003cb09..ec3897e4044 100644 --- a/paddle/fluid/operators/maxout_op.h +++ b/paddle/fluid/operators/maxout_op.h @@ -30,10 +30,11 @@ class MaxOutKernel : public framework::OpKernel { const Tensor* in_x = context.Input("X"); Tensor* out = context.Output("Out"); int groups = context.template Attr("groups"); + int axis = context.template Attr("axis"); math::MaxOutFunctor maxout_forward; maxout_forward(context.template device_context(), *in_x, out, - groups); + groups, axis); } }; @@ -47,13 +48,15 @@ class MaxOutGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); int groups = context.template Attr("groups"); + int axis = context.template Attr("axis"); auto& device_ctx = context.template device_context(); math::SetConstant zero; if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); zero(device_ctx, in_x_grad, static_cast(0.0)); math::MaxOutGradFunctor maxout_backward; - maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups); + maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups, + axis); } } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 41f6486ac52..fbc8faade70 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -15106,22 +15106,23 @@ def sigmoid_cross_entropy_with_logits(x, @templatedoc() -def maxout(x, groups, name=None): +def maxout(x, groups, name=None, axis=1): """ ${comment} Args: x(${x_type}): ${x_comment} - groups(${groups_type}): ${groups_comment} + groups(int): ${groups_comment} + axis(int, optional): ${axis_comment} name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. Returns: - Variable: - - out(${out_type}): ${out_comment} + Variable: ${out_comment} + Raises: + ValueError: If `axis` is not 1, -1 or 3. Examples: .. code-block:: python @@ -15134,6 +15135,12 @@ def maxout(x, groups, name=None): out = fluid.layers.maxout(input, groups=2) """ helper = LayerHelper("maxout", **locals()) + if axis not in [1, -1, 3]: + raise ValueError( + "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received " + "Attr(axis): %s." % str(axis)) + if axis == -1: + axis = 3 if name is None: out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -15144,7 +15151,8 @@ def maxout(x, groups, name=None): helper.append_op( type="maxout", inputs={"X": x}, - attrs={"groups": groups}, + attrs={"groups": groups, + "axis": axis}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_maxout_op.py b/python/paddle/fluid/tests/unittests/test_maxout_op.py index d588b22fe26..19c517142fd 100644 --- a/python/paddle/fluid/tests/unittests/test_maxout_op.py +++ b/python/paddle/fluid/tests/unittests/test_maxout_op.py @@ -16,11 +16,16 @@ from __future__ import print_function import unittest import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core from op_test import OpTest -def maxout_forward_naive(input, groups): +def maxout_forward_naive(input, groups, channel_axis): s0, s1, s2, s3 = input.shape + if channel_axis == 3: + return np.ndarray([s0, s1, s2, s3 // groups, groups], \ + buffer = input, dtype=input.dtype).max(axis=(4)) return np.ndarray([s0, s1 // groups, groups, s2, s3], \ buffer = input, dtype=input.dtype).max(axis=(2)) @@ -30,10 +35,11 @@ class TestMaxOutOp(OpTest): self.op_type = "maxout" self.init_test_case() input = np.random.random(self.shape).astype("float32") - output = self.MaxOut_forward_naive(input, self.groups).astype("float32") + output = self.MaxOut_forward_naive(input, self.groups, + self.axis).astype("float32") self.inputs = {'X': input} - self.attrs = {'groups': self.groups} + self.attrs = {'groups': self.groups, 'axis': self.axis} self.outputs = {'Out': output.astype('float32')} @@ -47,6 +53,48 @@ class TestMaxOutOp(OpTest): self.MaxOut_forward_naive = maxout_forward_naive self.shape = [100, 6, 2, 2] self.groups = 2 + self.axis = 1 + + +class TestMaxOutOpAxis(TestMaxOutOp): + def init_test_case(self): + self.MaxOut_forward_naive = maxout_forward_naive + self.shape = [100, 2, 2, 6] # NHWC format + self.groups = 2 + self.axis = 3 + + +class TestMaxOutOpAxisAPI(OpTest): + def test_axis(self): + data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32') + data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32') + out1 = fluid.layers.maxout(data1, groups=2, axis=1) + out2 = fluid.layers.maxout(data2, groups=2, axis=-1) + data1_np = np.random.random((3, 6, 2, 2)).astype("float32") + data2_np = np.transpose(data1_np, [0, 2, 3, 1]) + + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + results = exe.run(fluid.default_main_program(), + feed={"data1": data1_np, + "data2": data2_np}, + fetch_list=[out1, out2], + return_numpy=True) + + self.assertTrue( + np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2)))) + + def test_exception(self): + input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32") + + def _attr_axis(): + out = fluid.layers.maxout(input, groups=2, axis=2) + + self.assertRaises(ValueError, _attr_axis) if __name__ == '__main__': -- GitLab