Commit bd773b9c authored by wanghaox

modify for maxout op code review

Parent ab9c71d9
@@ -8,24 +8,26 @@ if(WITH_GPU)
     nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
-    nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
+    nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
     nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
     nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
     nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
     nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
     nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
+    nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
     cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
     cc_library(softmax SRCS softmax.cc DEPS operator)
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
     cc_library(pooling SRCS pooling.cc DEPS device_context)
-    cc_library(maxouting SRCS maxouting.cc DEPS device_context)
+    cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
     cc_library(vol2col SRCS vol2col.cc DEPS device_context)
     cc_library(context_project SRCS context_project.cc DEPS device_context)
     cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
     cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
     cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
+    cc_library(maxouting SRCS maxouting.cc DEPS device_context)
 endif()
 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
...
@@ -20,25 +20,27 @@ namespace math {
 /*
  * All tensors are in NCHW format.
- * Ksize, strides, paddings are two elements. These two elements represent
- * height and width, respectively.
+ * groups must be > 1
  */
 template <typename MaxOutProcess, typename T>
 class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
  public:
  void operator()(const platform::DeviceContext& context,
-                 const framework::Tensor& input, framework::Tensor& output,
-                 int groups, int num_channels, MaxOutProcess maxout_process) {
+                 const framework::Tensor& input,
+                 framework::Tensor* output,
+                 int groups,
+                 MaxOutProcess maxout_process) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
-    const int output_channels = num_channels / groups;
+    const int output_channels = output->dims()[1];
    int fea_size = input_height * input_width;
+    // c_size is the output feature-map size of one sample in the batch
    int c_size = fea_size * output_channels;
    const T* input_data = input.data<T>();
-    T* output_data = output.mutable_data<T>(context.GetPlace());
+    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int i = 0; i < batch_size; i++) {
      int new_bindex = c_size * i;
@@ -50,7 +52,6 @@ class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
          maxout_process.compute(ele,
              input_data[(new_bindex + new_cindex) * groups + ph * fea_size + f]);
        }
-        maxout_process.finalize(ele, (static_cast<T>(groups)));
        output_data[(new_bindex + new_cindex + f)] = ele;
      }
    }
@@ -68,11 +69,11 @@ public:
                 framework::Tensor& input_grad,
                 const framework::Tensor& output,
                 const framework::Tensor& output_grad,
-                 int groups, int num_channels) {
+                 int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
-    const int output_channels = num_channels / groups;
+    const int output_channels = output.dims()[1];
    int fea_size = input_height * input_width;
@@ -95,8 +96,6 @@ public:
          if (input_data[input_idx] == output_data[output_idx]) {
            input_grad_data[input_idx] += output_grad_data[output_idx];
            stop = true;
-          } else {
-            input_grad_data[input_idx] = 0;
          }
        }
      }
@@ -108,9 +107,9 @@ public:
 template class MaxOutGradFunctor<platform::CPUPlace, float>;
 template class MaxOutGradFunctor<platform::CPUPlace, double>;
 template class MaxOutFunctor<platform::CPUPlace,
-                             paddle::operators::math::MaxOut<float>, float>;
+                             math::MaxOut<float>, float>;
 template class MaxOutFunctor<platform::CPUPlace,
-                             paddle::operators::math::MaxOut<double>, double>;
+                             math::MaxOut<double>, double>;
 }  // namespace math
 }  // namespace operators
...
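For reference, the forward pass implemented by the CPU functor above is simply a max over `groups` consecutive channels of an NCHW tensor. Below is a minimal NumPy sketch of the same computation; it is illustrative only and not part of this patch, and the helper name `maxout_forward` is made up here.

```python
import numpy as np

def maxout_forward(x, groups):
    # x is NCHW; the channel count must be divisible by groups.
    n, c, h, w = x.shape
    assert c % groups == 0
    # Group consecutive channels and take the max inside each group,
    # which is what the loop over `ph` in the CPU functor does.
    return x.reshape(n, c // groups, groups, h, w).max(axis=2)

x = np.random.rand(2, 6, 4, 5).astype("float32")
print(maxout_forward(x, groups=3).shape)  # (2, 2, 4, 5)
```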
@@ -24,21 +24,20 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
                              T* output_data, const int channels,
                              const int input_height, const int input_width,
                              int groups, MaxOutProcess maxout_process) {
-  int size = input_height * input_width * channels / groups;
-  int featLen = input_height * input_width;
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int batch_idx = index / size;
-    int i = index % size;
-    int channel_idx = i / featLen;
-    int feat_idx = i % featLen;
+    int batch_offset = index % size;
+    int channel_idx = batch_offset / feat_len;
+    int feat_idx = batch_offset % feat_len;
     int data_idx =
-        (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
     T ele = maxout_process.initial();
-    for (int g = 0; g < groups; g++) {
-      maxout_process.compute(ele, input_data[data_idx + g * featLen]);
+    for (int g = 0; g < groups; ++g) {
+      maxout_process.compute(ele, input_data[data_idx + g * feat_len]);
     }
-    maxout_process.finalize(ele, (static_cast<T>(groups)));
     output_data[index] = ele;
   }
 }
@@ -47,21 +46,21 @@ __global__ void KernelMaxoutGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_height, const int input_width, int groups) {
-  int size = input_height * input_width * channels / groups;
-  int featLen = input_height * input_width;
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int batch_idx = index / size;
-    int i = index % size;
-    int channel_idx = i / featLen;
-    int feat_idx = i % featLen;
+    int batch_offset = index % size;
+    int channel_idx = batch_offset / feat_len;
+    int feat_idx = batch_offset % feat_len;
     int data_idx =
-        (batch_idx * size + channel_idx * featLen) * groups + feat_idx;
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
     int maxIndex = -1;
     bool stop = false;
     for (int g = 0; g < groups && !stop; g++) {
-      if (input_data[data_idx + g * featLen] == output_data[index]) {
-        maxIndex = data_idx + g * featLen;
+      if (input_data[data_idx + g * feat_len] == output_data[index]) {
+        maxIndex = data_idx + g * feat_len;
         stop = true;
       }
     }
@@ -73,28 +72,25 @@ __global__ void KernelMaxoutGrad(
 }
 /*
  * All tensors are in NCHW format.
- * Ksize, strides, paddings are two elements. These two elements represent
- * height and width, respectively.
  */
 template <typename MaxOutProcess, typename T>
 class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
  public:
  void operator()(const platform::DeviceContext& context,
-                 const framework::Tensor& input, framework::Tensor& output,
-                 int groups, int num_channels,
+                 const framework::Tensor& input, framework::Tensor* output,
+                 int groups,
                  MaxOutProcess maxout_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
-    const int output_channels = num_channels / groups;
-    const int output_height = output.dims()[2];
-    const int output_width = output.dims()[3];
+    const int output_channels = output->dims()[1];
+    const int output_height = output->dims()[2];
+    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
-    T* output_data = output.mutable_data<T>(context.GetPlace());
-    int nthreads = batch_size * output_channels * output_height * output_width;
+    T* output_data = output->mutable_data<T>(context.GetPlace());
+    int nthreads = output->numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
@@ -110,8 +106,6 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
 };
 /*
  * All tensors are in NCHW format.
- * Ksize, strides, paddings are two elements. These two elements represent
- * height and width, respectively.
  */
 template <typename T>
 class MaxOutGradFunctor<platform::GPUPlace, T> {
@@ -120,7 +114,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
                 const framework::Tensor& input, framework::Tensor& input_grad,
                 const framework::Tensor& output,
                 const framework::Tensor& output_grad,
-                 int groups, int num_channels) {
+                 int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
@@ -133,8 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
-    int nthreads = batch_size * output_channels * output_height * output_width;
+    int nthreads = output.numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);
@@ -152,9 +145,9 @@ template class MaxOutGradFunctor<platform::GPUPlace, float>;
 template class MaxOutGradFunctor<platform::GPUPlace, double>;
 template class MaxOutFunctor<platform::GPUPlace,
-                             paddle::operators::math::MaxOut<float>, float>;
+                             math::MaxOut<float>, float>;
 template class MaxOutFunctor<platform::GPUPlace,
-                             paddle::operators::math::MaxOut<double>, double>;
+                             math::MaxOut<double>, double>;
 }  // namespace math
 }  // namespace operators
...
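The backward kernel above routes each output gradient to the first channel in its group whose value equals the output (that is what the `stop` flag and `maxIndex` search do). Below is a hedged NumPy sketch of that routing, illustrative only and assuming the same consecutive-channel grouping as the forward pass; the helper name `maxout_backward` is made up here.

```python
import numpy as np

def maxout_backward(x, y, dy, groups):
    # x: NCHW input, y: maxout forward output, dy: gradient w.r.t. y.
    n, c, h, w = x.shape
    xg = x.reshape(n, c // groups, groups, h, w)
    # Index of the first channel in each group that matches the output value,
    # mirroring the stop/maxIndex search in KernelMaxoutGrad.
    first_match = (xg == y[:, :, None, :, :]).argmax(axis=2)
    dxg = np.zeros_like(xg)
    # Scatter the output gradient onto the matching input channel only.
    np.put_along_axis(dxg, first_match[:, :, None, :, :],
                      dy[:, :, None, :, :], axis=2)
    return dxg.reshape(n, c, h, w)
```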
@@ -22,26 +22,20 @@ namespace paddle {
 namespace operators {
 namespace math {
 #define FLT_MAX \
-  __FLT_MAX__  // It might need to be placed in another file, but I'm still
-               // wondering where to put it.
+  __FLT_MAX__
 /*
- * \brief Extracting simple operations from pooling.
- *        Both MaxPool and AvgPool need "initial", "compute" and "finalize"
+ * \brief Extracting simple operations from maxout.
+ *        It needs the "initial" and "compute"
  *        operation.
- *        MaxPool initializes temp variable to the negative maximum to find the
- *        maximum value in the pooling field.
- *        AvgPool initializes temp variable to the zero to accumulate all values
- *        in pool pooling, and finally takes the average.
- *        MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
 */
 template <class T>
 class MaxOut {
  public:
  DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
  DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
-  DEVICE inline void finalize(T& y, const T& group) {}
 };
 template <class T>
@@ -69,11 +63,12 @@ class MaxOutGrad {
 * MaxPool2dGradFunctor, MaxPool3dGradFunctor.
 */
 template <typename Place, typename MaxOutProcess, typename T>
 class MaxOutFunctor {
  public:
  void operator()(const platform::DeviceContext& context,
-                 const framework::Tensor& input, framework::Tensor& output,
-                 int groups, int num_channels, MaxOutProcess maxout_compute);
+                 const framework::Tensor& input, framework::Tensor* output,
+                 int groups, MaxOutProcess maxout_compute);
 };
@@ -84,8 +79,7 @@ class MaxOutGradFunctor {
                 const framework::Tensor& input,
                 framework::Tensor& input_grad,
                 const framework::Tensor& output,
-                 const framework::Tensor& output_grad, int groups,
-                 int num_channels);
+                 const framework::Tensor& output_grad, int groups);
 };
...
@@ -19,17 +19,16 @@ namespace operators {
 using framework::Tensor;
-/******** first define the ProtoMaker class ***************/
 class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
-        "(Tensor) The input tensor of pooling operator. "
+        "(Tensor) The input tensor of maxout operator. "
        "The format of input tensor is NCHW. Where N is batch size, C is the "
        "number of channels, H and W is the height and width of feature.");
    AddOutput("Out",
-        "(Tensor) The output tensor of pooling operator."
+        "(Tensor) The output tensor of maxout operator."
        "The format of output tensor is also NCHW."
        "Where N is batch size, C is "
        "the number of channels, H and W is the height and "
@@ -38,23 +37,53 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
    AddAttr<int>(
        "groups",
        R"DOC(The group number of input layer.
-        )DOC")
-        .SetDefault(2);
-    AddAttr<int>(
-        "num_channels",
-        R"DOC(The channel number of input layer.
-        )DOC")
-        .SetDefault(0);
-    AddComment(R"DOC(A layer to do max out on conv layer output.
-      - Input: output of a conv layer.
+        )DOC");
+    AddComment(R"DOC(
+      - Input: NCHW.
      - Output: feature map size same as input. Channel is (input channel) / groups.
      So groups should be larger than 1, and the number of channels should be
      divisible by groups.
+      .. math::
+        y_{si+j} = \max_k x_{gsi + sk + j}
+        g = groups
+        s = input.size / num_channels
+        0 \le i < num_channels / groups
+        0 \le j < s
+        0 \le k < groups
+      Please refer to the papers:
+        - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
+        - Multi-digit Number Recognition from Street View \
+          Imagery using Deep Convolutional Neural Networks: \
+          https://arxiv.org/pdf/1312.6082v4.pdf
+      The simple usage is:
+      .. code-block:: python
+         maxout = maxout_layer(input,
+                               num_channels=128,
+                               groups=4)
+      :param input: The input of this layer.
+      :type input: LayerOutput
+      :param num_channels: The channel number of input layer. If None will be set
+                           automatically from previous output.
+      :type num_channels: int | None
+      :param groups: The group number of input layer.
+      :type groups: int
+      :param name: The name of this layer. It is optional.
+      :type name: None | basestring.
+      :param layer_attr: Extra Layer attribute.
+      :type layer_attr: ExtraLayerAttribute
+      :return: LayerOutput object.
+      :rtype: LayerOutput
    )DOC");
  }
 };
-/******************2nd **********************************/
 class MaxOutOp : public framework::OperatorWithKernel {
  public:
@@ -67,20 +96,14 @@ class MaxOutOp : public framework::OperatorWithKernel {
                   "Output(Out) of maxoutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
-    int num_channels = ctx->Attrs().Get<int>("num_channels");
    // check groups > 1
    PADDLE_ENFORCE_GT(
        groups, 1,
        "in maxoutop groups should be larger than 1");
-    // check num_channels % groups == 0
-    PADDLE_ENFORCE_EQ(num_channels % groups, 0,
-                      "the num of channels should be able"
-                      "to devided by groups");
-    int out_num_channels = num_channels / groups;
-    std::vector<int64_t> output_shape({in_x_dims[0], out_num_channels});
+    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
    output_shape.push_back(in_x_dims[2]);
    output_shape.push_back(in_x_dims[3]);
...
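The InferShape change above now derives the output channel count directly from the input, so the output shape is [N, C / groups, H, W]. The following short NumPy check of that shape rule and of the flattened-index formula quoted in the new doc comment is a sketch under the assumption of consecutive-channel grouping, not code from the patch.

```python
import numpy as np

n, c, h, w, groups = 2, 6, 3, 4, 3
x = np.random.rand(n, c, h, w).astype("float32")

y = x.reshape(n, c // groups, groups, h, w).max(axis=2)
assert y.shape == (n, c // groups, h, w)  # matches output_shape in InferShape

# Per-sample flat-index form from the doc: y[s*i + j] = max_k x[g*s*i + s*k + j]
s = h * w
xf, yf = x.reshape(n, -1), y.reshape(n, -1)
i, j = 1, 5  # an arbitrary output channel and spatial position
assert yf[0, s * i + j] == max(xf[0, groups * s * i + s * k + j]
                               for k in range(groups))
```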
@@ -14,7 +14,6 @@ limitations under the License. */
 #pragma once
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/maxouting.h"
@@ -32,14 +31,13 @@ class MaxOutKernel : public framework::OpKernel<T> {
    Tensor* out = context.Output<Tensor>("Out");
    int groups = context.template Attr<int>("groups");
-    int num_channels = context.template Attr<int>("num_channels");
    paddle::operators::math::MaxOutFunctor<
        Place, paddle::operators::math::MaxOut<T>, T>
        maxout_forward;
    paddle::operators::math::MaxOut<T> maxout_process;
-    maxout_forward(context.device_context(), *in_x, *out, groups, num_channels,
+    maxout_forward(context.device_context(), *in_x, out, groups,
                   maxout_process);
  }
 };
@@ -55,7 +53,6 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    int groups = context.template Attr<int>("groups");
-    int num_channels = context.template Attr<int>("num_channels");
@@ -68,7 +65,7 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
      paddle::operators::math::MaxOutGradFunctor<Place, T>
          maxout_backward;
      maxout_backward(context.device_context(), *in_x, *in_x_grad, *out,
-                      *out_grad, groups, num_channels);
+                      *out_grad, groups);
    }
  }
 };
...
@@ -3,22 +3,13 @@ import numpy as np
 from op_test import OpTest
-def maxout_forward_naive_2sweetsky(input, groups, num_channels):
-    s0, s1, s2, s3 = input.shape
-    return np.ndarray([s0, s1 / groups, groups, s2, s3], \
-        buffer = input, dtype=input.dtype).max(axis=(2))
 def maxout_forward_naive(input, groups, num_channels):
     s0, s1, s2, s3 = input.shape
     return np.ndarray([s0, s1 / groups, groups, s2, s3], \
         buffer = input, dtype=input.dtype).max(axis=(2))
-class TestMaxOut_Op(OpTest):
+class TestMaxOutOp(OpTest):
     def setUp(self):
         self.op_type = "maxout"
         self.init_test_case()
@@ -37,7 +28,7 @@ class TestMaxOut_Op(OpTest):
     def test_check_grad(self):
         print self.inputs
         print self.outputs
-        self.check_grad(['X'], 'Out', max_relative_error=0.5)
+        self.check_grad(['X'], 'Out')
     def init_test_case(self):
         self.MaxOut_forward_naive = maxout_forward_naive
...
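The tightened check above (dropping max_relative_error=0.5 in favor of the default tolerance) is plausible because maxout is piecewise linear, so its gradient is exact almost everywhere. As a rough, standalone finite-difference sanity check of that intuition (illustrative only, unrelated to OpTest internals, with a made-up helper name and test indices):

```python
import numpy as np

def maxout(x, groups):
    n, c, h, w = x.shape
    return x.reshape(n, c // groups, groups, h, w).max(axis=2)

x = np.random.rand(1, 4, 3, 3).astype("float64")
groups, eps = 2, 1e-6
base = maxout(x, groups).sum()

# Numerical gradient of sum(maxout(x)) w.r.t. one input element: it should be
# about 1.0 if that element is the max of its group, and 0.0 otherwise
# (assuming no near-ties within eps, which is almost sure for random floats).
xp = x.copy()
xp[0, 1, 2, 2] += eps
num_grad = (maxout(xp, groups).sum() - base) / eps
is_group_max = x[0, 1, 2, 2] == x[0, 0:2, 2, 2].max()
print(num_grad, is_group_max)
```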