Commit 8d1e9f0f authored by Zhang Ting, committed by Aurelius84

maxout supports channel_last input (#20846)

* maxout supports channel_last input, test=develop

* modified details of Input(X) and Attr(groups, axis) in doc, test=develop
Parent 9d8ec423
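Before the per-file hunks, a minimal NumPy sketch of the semantics this commit adds (shapes and names here are illustrative, not part of the patch): maxout splits the channel dimension into `groups` slices and takes an element-wise max across each group, and the channel dimension may now be axis 1 (NCHW) or the last axis (NHWC).

```python
import numpy as np

def maxout_nchw(x, groups):
    # channels grouped as [c0_g0, c0_g1, ..., c1_g0, ...]; max over each group
    n, c, h, w = x.shape
    return x.reshape(n, c // groups, groups, h, w).max(axis=2)

def maxout_nhwc(x, groups):
    # same grouping, but channels live on the last axis
    n, h, w, c = x.shape
    return x.reshape(n, h, w, c // groups, groups).max(axis=4)

x = np.random.rand(2, 6, 4, 4).astype("float32")      # NCHW sample
y1 = maxout_nchw(x, groups=2)                         # (2, 3, 4, 4)
y2 = maxout_nhwc(x.transpose(0, 2, 3, 1), groups=2)   # (2, 4, 4, 3)
assert np.allclose(y1, y2.transpose(0, 3, 1, 2))      # both layouts agree
```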
@@ -18,35 +18,45 @@ namespace paddle {
namespace operators {
namespace math {
// All tensors are in NCHW format, and the groups must be greater than 1
// All tensors are in NCHW or NHWC format, and the groups must be greater than 1
template <typename T>
class MaxOutFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
int fea_size = input_height * input_width;
// c_size means the output size of each sample
int c_size = fea_size * output_channels;
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
int new_bindex = c_size * i;
for (int c = 0; c < output_channels; ++c) {
int new_cindex = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
T ele = static_cast<T>(-FLT_MAX);
int input_idx, output_idx;
for (int ph = 0; ph < groups; ++ph) {
T x = input_data[(new_bindex + new_cindex) * groups +
ph * fea_size + f];
if (axis == 1) {
input_idx =
(new_bindex + new_cindex) * groups + ph * fea_size + f;
} else {
input_idx = (new_bindex + f * output_channels + c) * groups + ph;
}
T x = input_data[input_idx];
ele = ele > x ? ele : x;
}
output_data[(new_bindex + new_cindex + f)] = ele;
if (axis == 1) {
output_idx = new_bindex + new_cindex + f;
} else {
output_idx = new_bindex + f * output_channels + c;
}
output_data[output_idx] = ele;
}
}
}
@@ -59,11 +69,12 @@ class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups) {
const framework::Tensor& output_grad, const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
int fea_size = input_height * input_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
@@ -75,11 +86,18 @@ class MaxOutGradFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) {
int clen = fea_size * c;
for (int f = 0; f < fea_size; ++f) {
int input_idx0 = (blen + clen) * groups + f;
int input_idx0, output_idx;
bool continue_match = true;
int output_idx = blen + clen + f;
if (axis == 1) {
input_idx0 = (blen + clen) * groups + f;
output_idx = blen + clen + f;
} else {
input_idx0 = (blen + f * output_channels + c) * groups;
output_idx = blen + f * output_channels + c;
}
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
int idx_offset = (axis == 1 ? fea_size * g : g);
int input_idx = input_idx0 + idx_offset;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
......
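The NHWC branch above flattens (batch, spatial position, output channel, group) into one offset, `input_idx = (new_bindex + f * output_channels + c) * groups + ph`, so the input channel read for output channel c and group ph is `c * groups + ph`. A small self-contained check of that formula (a sketch with made-up sizes, using the same variable meanings as the functor, where `fea_size = H * W`):

```python
import numpy as np

N, H, W, C_out, groups = 2, 3, 3, 2, 2
C_in, fea_size = C_out * groups, H * W
x = np.arange(N * fea_size * C_in, dtype=np.float32)  # flattened NHWC tensor

for b in range(N):
    for f in range(fea_size):            # f = h * W + w
        for c in range(C_out):
            for g in range(groups):
                flat = (b * fea_size * C_out + f * C_out + c) * groups + g
                # input channel for (output channel c, group g) is c*groups + g
                nhwc = x.reshape(N, H, W, C_in)[b, f // W, f % W, c * groups + g]
                assert x[flat] == nhwc
```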
@@ -22,8 +22,8 @@ namespace math {
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
const int channels, const int input_height,
const int input_width, int groups,
T* output_data) {
const int input_width, const int groups,
const int axis, T* output_data) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -31,13 +31,22 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int channel_idx, feat_idx, data_idx;
if (axis == 1) {
channel_idx = batch_offset / feat_len;
feat_idx = batch_offset % feat_len;
data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
} else {
channel_idx = batch_offset % channels;
feat_idx = batch_offset / channels;
data_idx =
(batch_idx * size + feat_idx * channels + channel_idx) * groups;
}
T ele = static_cast<T>(-FLT_MAX);
for (int g = 0; g < groups; ++g) {
T x = input_data[data_idx + g * feat_len];
int idx_offset = (axis == 1 ? g * feat_len : g);
T x = input_data[data_idx + idx_offset];
ele = ele > x ? ele : x;
}
output_data[i] = ele;
@@ -48,7 +57,7 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
const T* output_data, const T* output_grad,
T* input_grad, const int channels,
const int input_height, const int input_width,
int groups) {
const int groups, const int axis) {
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -56,15 +65,24 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
for (int i = index; i < nthreads; i += offset) {
int batch_idx = i / size;
int batch_offset = i % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int channel_idx, feat_idx, data_idx;
if (axis == 1) {
channel_idx = batch_offset / feat_len;
feat_idx = batch_offset % feat_len;
data_idx =
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
} else {
channel_idx = batch_offset % channels;
feat_idx = batch_offset / channels;
data_idx =
(batch_idx * size + feat_idx * channels + channel_idx) * groups;
}
int max_index = -1;
bool continue_match = true;
for (int g = 0; g < groups && continue_match; ++g) {
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
int idx_offset = (axis == 1 ? g * feat_len : g);
if (input_data[data_idx + idx_offset] == output_data[i]) {
max_index = data_idx + idx_offset;
continue_match = false;
break;
}
@@ -75,21 +93,19 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
}
}
/*
* All tensors are in NCHW format.
* All tensors are in NCHW or NHWC format.
*/
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* output,
int groups) {
const int groups, const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output->dims()[axis];
const T* input_data = input.data<T>();
T* output_data = output->mutable_data<T>(context.GetPlace());
@@ -100,11 +116,11 @@ class MaxOutFunctor<platform::CUDADeviceContext, T> {
KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width, groups,
output_data);
axis, output_data);
}
};
/*
* All tensors are in NCHW format.
* All tensors are in NCHW or NHWC format.
*/
template <typename T>
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
@@ -112,14 +128,13 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups) {
const framework::Tensor& output_grad, const int groups,
const int axis) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int input_channels = input.dims()[axis];
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
const int output_channels = output.dims()[axis];
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
@@ -132,7 +147,7 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, groups);
input_channels, input_height, input_width, groups, axis);
}
};
......
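A detail worth noting in the CUDA kernels above: each thread handles one output element i, and in the NHWC branch the (feat_idx, channel_idx) split always recombines to feat_idx * channels + channel_idx == batch_offset, so data_idx reduces to i * groups. In other words, in a flattened NHWC tensor every run of `groups` consecutive input values folds into one output value. A NumPy sketch of that identity (sizes are illustrative):

```python
import numpy as np

N, H, W, C_out, groups = 2, 3, 3, 2, 2
x = np.random.rand(N, H, W, C_out * groups).astype("float32")

out = x.reshape(N, H, W, C_out, groups).max(axis=4)  # NHWC maxout
flat = x.ravel().reshape(-1, groups).max(axis=1)     # the i*groups + g view
assert np.allclose(out.ravel(), flat)
```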
@@ -26,7 +26,8 @@ template <typename DeviceContext, typename T>
class MaxOutFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* output, int groups);
framework::Tensor* output, const int groups,
const int axis = 1);
};
template <typename DeviceContext, class T>
@@ -35,7 +36,8 @@ class MaxOutGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor& input,
framework::Tensor* input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups);
const framework::Tensor& output_grad, const int groups,
const int axis = 1);
};
} // namespace math
} // namespace operators
......
@@ -23,25 +23,27 @@ using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of maxout operator with data type of "
"float32. The format of input tensor is NCHW. Where N is batch size,"
" C is the number of channels, H and W is the height and width of "
"feature.");
AddInput("X",
"A 4-D Tensor with data type of float32 or float64. "
"The data format is NCHW or NHWC. Where N is "
"batch size, C is the number of channels, "
"H and W is the height and width of "
"feature. ");
AddOutput("Out",
"(Tensor) The output tensor of maxout operator."
"The data type is float32."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of feature.");
"A 4-D Tensor with same data type and data format "
"with input Tensor. ");
AddAttr<int>(
"groups",
"(int),"
"Specifies how many groups the input tensor will be split"
"in the channel dimension. And the number of output channel is "
"the number of channels divided by groups.");
"Specifies how many groups the input tensor will be split into "
"at the channel dimension. And the number of output channel is "
"the number of channels divided by groups. ");
AddAttr<int>(
"axis",
"Specifies the index of channel dimension where maxout will "
"be performed. It should be 1 when data format is NCHW, -1 or 3 "
"when data format is NHWC. "
"Default: 1. ")
.SetDefault(1);
AddComment(R"DOC(
MaxOut Operator.
@@ -70,17 +72,19 @@ class MaxOutOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxoutOpshould not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MaxoutOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of MaxoutOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
int axis = ctx->Attrs().Get<int>("axis");
// check groups > 1
PADDLE_ENFORCE_GT(groups, 1, "groups should be larger than 1 in maxoutop");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
output_shape.push_back(in_x_dims[2]);
output_shape.push_back(in_x_dims[3]);
PADDLE_ENFORCE_GT(groups, 1,
"Attr(groups) of Op(maxout) should be larger than 1.");
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
......
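Shape inference above copies the input dims and divides the dim selected by axis by groups. A one-line Python equivalent (hypothetical helper name, for illustration only):

```python
def maxout_out_shape(in_dims, groups, axis=1):
    out = list(in_dims)
    out[axis] = in_dims[axis] // groups  # channel dim shrinks by `groups`
    return out

assert maxout_out_shape([3, 6, 2, 2], groups=2, axis=1) == [3, 3, 2, 2]
assert maxout_out_shape([3, 2, 2, 6], groups=2, axis=3) == [3, 2, 2, 3]
```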
@@ -30,10 +30,11 @@ class MaxOutKernel : public framework::OpKernel<T> {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
math::MaxOutFunctor<DeviceContext, T> maxout_forward;
maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,
groups);
groups, axis);
}
};
@@ -47,13 +48,15 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int groups = context.template Attr<int>("groups");
int axis = context.template Attr<int>("axis");
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
math::MaxOutGradFunctor<DeviceContext, T> maxout_backward;
maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups);
maxout_backward(device_ctx, *in_x, in_x_grad, *out, *out_grad, groups,
axis);
}
}
};
......
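The grad path routes each output gradient to the first input element in its group that equals the output value (continue_match stops the scan after the first hit). A NumPy sketch of that rule for the NCHW layout (a simplified model of the C++ loop, not the shipped implementation):

```python
import numpy as np

def maxout_grad_nchw(x, dout, groups):
    n, c_in, h, w = x.shape
    xg = x.reshape(n, c_in // groups, groups, h, w)
    hit = xg == xg.max(axis=2, keepdims=True)
    # keep only the first match in each group, mirroring continue_match
    first = hit & (np.cumsum(hit, axis=2) == 1)
    dx = first * dout[:, :, None, :, :]
    return dx.reshape(n, c_in, h, w)

x = np.random.rand(2, 6, 4, 4).astype("float32")
dout = np.ones((2, 3, 4, 4), dtype="float32")
assert maxout_grad_nchw(x, dout, 2).sum() == dout.sum()  # one hit per group
```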
@@ -15106,22 +15106,23 @@ def sigmoid_cross_entropy_with_logits(x,
@templatedoc()
def maxout(x, groups, name=None):
def maxout(x, groups, name=None, axis=1):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
groups(${groups_type}): ${groups_comment}
groups(int): ${groups_comment}
axis(int, optional): ${axis_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name does not need to be set and
is None by default.
Returns:
Variable:
out(${out_type}): ${out_comment}
Variable: ${out_comment}
Raises:
ValueError: If `axis` is not 1, -1 or 3.
Examples:
.. code-block:: python
@@ -15134,6 +15135,12 @@ def maxout(x, groups, name=None):
out = fluid.layers.maxout(input, groups=2)
"""
helper = LayerHelper("maxout", **locals())
if axis not in [1, -1, 3]:
raise ValueError(
"Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
"Attr(axis): %s." % str(axis))
if axis == -1:
axis = 3
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -15144,7 +15151,8 @@ def maxout(x, groups, name=None):
helper.append_op(
type="maxout",
inputs={"X": x},
attrs={"groups": groups},
attrs={"groups": groups,
"axis": axis},
outputs={"Out": out})
return out
......
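For reference, a hedged usage sketch of the extended Python API (tensor names and sizes are illustrative): NCHW input keeps the default axis=1, while NHWC input passes axis=-1, which the layer normalizes to 3.

```python
import paddle.fluid as fluid

x_nchw = fluid.data(name='x_nchw', shape=[3, 6, 32, 32], dtype='float32')
y_nchw = fluid.layers.maxout(x_nchw, groups=2)           # -> [3, 3, 32, 32]

x_nhwc = fluid.data(name='x_nhwc', shape=[3, 32, 32, 6], dtype='float32')
y_nhwc = fluid.layers.maxout(x_nhwc, groups=2, axis=-1)  # -> [3, 32, 32, 3]
```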
@@ -16,11 +16,16 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest
def maxout_forward_naive(input, groups):
def maxout_forward_naive(input, groups, channel_axis):
s0, s1, s2, s3 = input.shape
if channel_axis == 3:
return np.ndarray([s0, s1, s2, s3 // groups, groups], \
buffer = input, dtype=input.dtype).max(axis=(4))
return np.ndarray([s0, s1 // groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))
@@ -30,10 +35,11 @@ class TestMaxOutOp(OpTest):
self.op_type = "maxout"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
output = self.MaxOut_forward_naive(input, self.groups).astype("float32")
output = self.MaxOut_forward_naive(input, self.groups,
self.axis).astype("float32")
self.inputs = {'X': input}
self.attrs = {'groups': self.groups}
self.attrs = {'groups': self.groups, 'axis': self.axis}
self.outputs = {'Out': output.astype('float32')}
@@ -47,6 +53,48 @@ self.MaxOut_forward_naive = maxout_forward_naive
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 6, 2, 2]
self.groups = 2
self.axis = 1
class TestMaxOutOpAxis(TestMaxOutOp):
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
self.shape = [100, 2, 2, 6] # NHWC format
self.groups = 2
self.axis = 3
class TestMaxOutOpAxisAPI(OpTest):
def test_axis(self):
data1 = fluid.data(name='data1', shape=[3, 6, 2, 2], dtype='float32')
data2 = fluid.data(name='data2', shape=[3, 2, 2, 6], dtype='float32')
out1 = fluid.layers.maxout(data1, groups=2, axis=1)
out2 = fluid.layers.maxout(data2, groups=2, axis=-1)
data1_np = np.random.random((3, 6, 2, 2)).astype("float32")
data2_np = np.transpose(data1_np, [0, 2, 3, 1])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={"data1": data1_np,
"data2": data2_np},
fetch_list=[out1, out2],
return_numpy=True)
self.assertTrue(
np.allclose(results[0], np.transpose(results[1], (0, 3, 1, 2))))
def test_exception(self):
input = fluid.data(name="input", shape=[2, 4, 6, 6], dtype="float32")
def _attr_axis():
out = fluid.layers.maxout(input, groups=2, axis=2)
self.assertRaises(ValueError, _attr_axis)
if __name__ == '__main__':
......