Commit bd773b9c authored by: W wanghaox

Modify the maxout op according to code review.

Parent ab9c71d9
......@@ -8,24 +8,26 @@ if(WITH_GPU)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
......@@ -20,25 +20,27 @@ namespace math {
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
* groups must be > 1
*/
template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
int groups, int num_channels, MaxOutProcess maxout_process) {
const framework::Tensor& input,
framework::Tensor * output,
int groups,
MaxOutProcess maxout_process) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = num_channels/groups;
const int output_channels = output->dims()[1];
int fea_size = input_height * input_width;
// c_size means the output size of one batch
int c_size = fea_size * output_channels;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
int new_bindex = c_size * i;
......@@ -50,7 +52,6 @@ class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
maxout_process.compute(ele,
input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]);
}
maxout_process.finalize(ele, (static_cast<T>(groups)));
output_data[(new_bindex+new_cindex+f)] = ele;
}
}
......@@ -68,11 +69,11 @@ public:
framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups, int num_channels) {
int groups) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = num_channels / groups;
const int output_channels = output.dims()[1];
int fea_size = input_height * input_width;
......@@ -95,8 +96,6 @@ public:
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
stop = true;
} else {
input_grad_data[input_idx] = 0;
}
}
}
......@@ -108,9 +107,9 @@ public:
template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace,
paddle::operators::math::MaxOut<float>, float>;
math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::CPUPlace,
paddle::operators::math::MaxOut<double>, double>;
math::MaxOut<double>, double>;
} // namespace math
} // namespace operators
......
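For reference, the forward pass implemented by the CPU functor above can be sketched as a standalone naive loop (plain C++ over flat NCHW buffers; the helper name `maxout_forward_naive_cpu` is illustrative and not part of the Paddle code): every output element is the maximum over `groups` consecutive input channels at the same spatial position.

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Naive reference for MaxOutFunctor<CPUPlace>: the output keeps H x W and has
// channels / groups output channels; each output value is the max over the
// `groups` consecutive input channels c * groups + g (NCHW, flat buffers).
void maxout_forward_naive_cpu(const std::vector<float>& input,
                              std::vector<float>* output, int batch_size,
                              int channels, int height, int width, int groups) {
  const int fea_size = height * width;
  const int out_channels = channels / groups;
  output->assign(static_cast<size_t>(batch_size) * out_channels * fea_size, 0.0f);
  for (int n = 0; n < batch_size; ++n) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        float ele = std::numeric_limits<float>::lowest();
        for (int g = 0; g < groups; ++g) {
          // Same indexing as the functor: ((n * out_channels + c) * groups + g)
          // selects the input channel, f is the spatial offset.
          const int in_idx = ((n * out_channels + c) * groups + g) * fea_size + f;
          ele = std::max(ele, input[in_idx]);
        }
        (*output)[(n * out_channels + c) * fea_size + f] = ele;
      }
    }
  }
}
```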
......@@ -24,21 +24,20 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
T* output_data, const int channels,
const int input_height, const int input_width,
int groups, MaxOutProcess maxout_process) {
int size = input_height * input_width * channels / groups;
int featLen = input_height * input_width;
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int batch_idx = index / size;
int i = index % size;
int channel_idx = i / featLen;
int feat_idx = i % featLen;
int batch_offset = index % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * featLen) * groups + feat_idx;
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
T ele = maxout_process.initial();
for (int g = 0; g < groups; g++) {
maxout_process.compute(ele, input_data[data_idx + g * featLen]);
for (int g = 0; g < groups; ++g) {
maxout_process.compute(ele, input_data[data_idx + g * feat_len]);
}
maxout_process.finalize(ele, (static_cast<T>(groups)));
output_data[index] = ele;
}
}
......@@ -47,21 +46,21 @@ __global__ void KernelMaxoutGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, int groups) {
int size = input_height * input_width * channels / groups;
int featLen = input_height * input_width;
const int size = input_height * input_width * channels / groups;
const int feat_len = input_height * input_width;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int batch_idx = index / size;
int i = index % size;
int channel_idx = i / featLen;
int feat_idx = i % featLen;
int batch_offset = index % size;
int channel_idx = batch_offset / feat_len;
int feat_idx = batch_offset % feat_len;
int data_idx =
(batch_idx * size + channel_idx * featLen) * groups + feat_idx;
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
int maxIndex = -1;
bool stop = false;
for (int g = 0; g < groups && !stop; g++) {
if (input_data[data_idx + g * featLen] == output_data[index]) {
maxIndex = data_idx + g * featLen;
if (input_data[data_idx + g * feat_len] == output_data[index]) {
maxIndex = data_idx + g * feat_len;
stop = true;
}
}
......@@ -73,28 +72,25 @@ __global__ void KernelMaxoutGrad(
}
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename MaxOutProcess, typename T>
class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
int groups, int num_channels,
const framework::Tensor& input, framework::Tensor * output,
int groups,
MaxOutProcess maxout_process) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = num_channels / groups;
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int output_channels = output->dims()[1];
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * output_height * output_width;
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = output->numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
......@@ -110,8 +106,6 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
......@@ -120,7 +114,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad,
int groups, int num_channels) {
int groups) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -133,8 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * output_height * output_width;
int nthreads = output.numel();
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
......@@ -152,9 +145,9 @@ template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;
template class MaxOutFunctor<platform::GPUPlace,
paddle::operators::math::MaxOut<float>, float>;
math::MaxOut<float>, float>;
template class MaxOutFunctor<platform::GPUPlace,
paddle::operators::math::MaxOut<double>, double>;
math::MaxOut<double>, double>;
} // namespace math
} // namespace operators
......
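The two kernels above decode a flat output index into (batch, output channel, spatial offset) and then locate the `groups` candidate input elements, which sit `feat_len` apart; the launch uses 1024 threads per block with a grid-stride loop, so `blocks = (nthreads + 1023) / 1024`. Below is a host-side mirror of that index arithmetic (a sketch; `decode_maxout_index` is an illustrative name, not Paddle code).

```cpp
#include <cstdio>

// Mirrors the index math in KernelMaxOut / KernelMaxoutGrad. `size` is the
// number of output elements per batch: (channels / groups) * H * W.
void decode_maxout_index(int index, int channels, int input_height,
                         int input_width, int groups) {
  const int feat_len = input_height * input_width;
  const int size = feat_len * channels / groups;
  const int batch_idx = index / size;               // which sample in the batch
  const int batch_offset = index % size;
  const int channel_idx = batch_offset / feat_len;  // output channel
  const int feat_idx = batch_offset % feat_len;     // spatial offset (h * W + w)
  // First candidate input element; candidate g lives g * feat_len further on.
  const int data_idx =
      (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
  std::printf("index=%d -> n=%d c_out=%d hw=%d first_input=%d\n", index,
              batch_idx, channel_idx, feat_idx, data_idx);
}
```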
......@@ -22,26 +22,20 @@ namespace paddle {
namespace operators {
namespace math {
#define FLT_MAX \
__FLT_MAX__ // It might need to be placed in another file, but I'm still
// wondering where to put it.
__FLT_MAX__
/*
* \brief Extracting simple operations from pooling.
* Both MaxPool and AvgPool need "initial", "compute" and "finalize"
 * \brief Extracting simple operations from maxout.
 * The "initial" and "compute" operations are needed.
* MaxPool initializes temp variable to the negative maximum to find the
* maximum value in the pooling field.
* AvgPool initializes temp variable to the zero to accumulate all values
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template <class T>
class MaxOut {
public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
DEVICE inline void finalize(T& y, const T& group) {}
};
template <class T>
......@@ -69,11 +63,12 @@ class MaxOutGrad {
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
template <typename Place, typename MaxOutProcess, typename T>
class MaxOutFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
int groups, int num_channels, MaxOutProcess maxout_compute);
const framework::Tensor& input, framework::Tensor * output,
int groups, MaxOutProcess maxout_compute);
};
......@@ -84,8 +79,7 @@ class MaxOutGradFunctor {
const framework::Tensor& input,
framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, int groups,
int num_channels);
const framework::Tensor& output_grad, int groups);
};
......
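The header above separates the element-wise rule (MaxOut<T>, which now only provides `initial` and `compute` after the `finalize` hook was dropped) from the functor that walks the tensor. A minimal sketch of that process-object contract follows (the struct name `MaxOutSketch` is illustrative, not Paddle's):

```cpp
#include <cassert>
#include <cfloat>

// Same contract as math::MaxOut<T>: start from the identity element of max,
// then fold in one candidate value per group member.
template <class T>
struct MaxOutSketch {
  T initial() const { return static_cast<T>(-FLT_MAX); }
  void compute(T& y, const T& x) const { y = y > x ? y : x; }
};

int main() {
  MaxOutSketch<float> process;
  const float group[] = {0.5f, -1.0f, 2.5f, 1.0f};  // one group of 4 channels
  float y = process.initial();
  for (float x : group) process.compute(y, x);
  assert(y == 2.5f);  // the maximum over the group
  return 0;
}
```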
......@@ -19,17 +19,16 @@ namespace operators {
using framework::Tensor;
/******** first define the ProtoMaker class ***************/
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"(Tensor) The input tensor of maxout operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of pooling operator."
"(Tensor) The output tensor of maxout operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
......@@ -38,23 +37,53 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>(
"groups",
R"DOC(The group number of input layer.
)DOC")
.SetDefault(2);
AddAttr<int>(
"num_channels",
R"DOC(The channel number of input layer.
)DOC")
.SetDefault(0);
AddComment(R"DOC(A layer to do max out on conv layer output.
- Input: output of a conv layer.
)DOC");
AddComment(R"DOC(
- Input: NCHW.
- Output: feature map size is the same as the input. The number of channels is
  (input channels) / groups. So groups should be larger than 1, and the number
  of input channels should be divisible by groups.
.. math::
y_{si+j} = \max_k x_{gsi + sk + j}
g = groups
s = input.size / num_channels
0 \le i < num_channels / groups
0 \le j < s
0 \le k < groups
Please refer to the papers:
- Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
- Multi-digit Number Recognition from Street View \
Imagery using Deep Convolutional Neural Networks: \
https://arxiv.org/pdf/1312.6082v4.pdf
)DOC");
}
};
/****************** 2nd: define MaxOutOp **********************************/
class MaxOutOp : public framework::OperatorWithKernel {
public:
......@@ -67,20 +96,14 @@ class MaxOutOp : public framework::OperatorWithKernel {
"Output(Out) of maxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
int num_channels = ctx->Attrs().Get<int>("num_channels");
// check groups > 1
PADDLE_ENFORCE_GT(
groups, 1,
"in maxoutop groups should be larger than 1");
// check num_channels%groups=0
PADDLE_ENFORCE_EQ(num_channels % groups, 0,
"the num of channels should be able"
"to devided by groups");
int out_num_channels = num_channels / groups;
std::vector<int64_t> output_shape({in_x_dims[0], out_num_channels});
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
output_shape.push_back(in_x_dims[2]);
output_shape.push_back(in_x_dims[3]);
......
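With the `num_channels` attribute gone, InferShape above derives the output channel count directly from the input: the output shape is {N, C / groups, H, W}. A small sketch of that rule (the helper `infer_maxout_shape` is illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// The output keeps N, H and W and divides the channel dimension by `groups`;
// groups must be > 1 and C must be divisible by groups.
std::vector<int64_t> infer_maxout_shape(const std::vector<int64_t>& in_dims,
                                        int groups) {
  assert(groups > 1);
  assert(in_dims[1] % groups == 0);
  return {in_dims[0], in_dims[1] / groups, in_dims[2], in_dims[3]};
}

int main() {
  // e.g. X: [N=2, C=6, H=5, W=5] with groups=3  ->  Out: [2, 2, 5, 5]
  const auto out = infer_maxout_shape({2, 6, 5, 5}, 3);
  assert((out == std::vector<int64_t>{2, 2, 5, 5}));
  return 0;
}
```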
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/maxouting.h"
......@@ -32,14 +31,13 @@ class MaxOutKernel : public framework::OpKernel<T> {
Tensor* out = context.Output<Tensor>("Out");
int groups = context.template Attr<int>("groups");
int num_channels = context.template Attr<int>("num_channels");
paddle::operators::math::MaxOutFunctor<
Place, paddle::operators::math::MaxOut<T>, T>
maxout_forward;
paddle::operators::math::MaxOut<T> maxout_process;
maxout_forward(context.device_context(), *in_x, *out, groups, num_channels,
maxout_forward(context.device_context(), *in_x, out, groups,
maxout_process);
}
};
......@@ -55,7 +53,6 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int groups = context.template Attr<int>("groups");
int num_channels = context.template Attr<int>("num_channels");
......@@ -68,7 +65,7 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
paddle::operators::math::MaxOutGradFunctor<Place, T>
maxout_backward;
maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, groups, num_channels);
*out_grad, groups);
}
}
};
......
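In the backward kernel above, the gradient functor routes each output gradient to the first input element in its group that equals the output value (the arg-max); all other positions stay zero. A standalone naive sketch of that rule (illustrative helper name, flat NCHW buffers, not Paddle code):

```cpp
#include <vector>

// Naive reference for MaxOutGradFunctor: d(Out) flows only to the first group
// member that produced the maximum; input_grad is zero everywhere else.
void maxout_backward_naive_cpu(const std::vector<float>& input,
                               const std::vector<float>& output,
                               const std::vector<float>& output_grad,
                               std::vector<float>* input_grad, int batch_size,
                               int channels, int height, int width, int groups) {
  const int fea_size = height * width;
  const int out_channels = channels / groups;
  input_grad->assign(input.size(), 0.0f);
  for (int n = 0; n < batch_size; ++n) {
    for (int c = 0; c < out_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        const int out_idx = (n * out_channels + c) * fea_size + f;
        bool stop = false;
        for (int g = 0; g < groups && !stop; ++g) {
          const int in_idx = ((n * out_channels + c) * groups + g) * fea_size + f;
          if (input[in_idx] == output[out_idx]) {
            (*input_grad)[in_idx] += output_grad[out_idx];
            stop = true;  // only the first matching element gets the gradient
          }
        }
      }
    }
  }
}
```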
......@@ -3,22 +3,13 @@ import numpy as np
from op_test import OpTest
def maxout_forward_naive_2sweetsky(input, groups, num_channels):
s0, s1, s2, s3 = input.shape
return np.ndarray([s0, s1 / groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))
def maxout_forward_naive(input, groups,num_channels):
s0, s1, s2, s3 = input.shape
return np.ndarray([s0, s1 / groups, groups, s2, s3], \
buffer = input, dtype=input.dtype).max(axis=(2))
class TestMaxOut_Op(OpTest):
class TestMaxOutOp(OpTest):
def setUp(self):
self.op_type = "maxout"
self.init_test_case()
......@@ -37,7 +28,7 @@ class TestMaxOut_Op(OpTest):
def test_check_grad(self):
print self.inputs
print self.outputs
self.check_grad(['X'], 'Out', max_relative_error=0.5)
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.MaxOut_forward_naive = maxout_forward_naive
......